Add SNMP device discovery API
This commit is contained in:
@@ -0,0 +1,113 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.credentials import SNMP_CREDENTIAL_TYPE
|
||||
from app.auth.dependencies import require_role
|
||||
from app.core.secrets import decrypt_secret
|
||||
from app.db.session import get_db
|
||||
from app.models import Credential, User
|
||||
from app.schemas.core import (
|
||||
SnmpDiscoveredInterfaceRead,
|
||||
SnmpDiscoveryItemRead,
|
||||
SnmpDiscoveryRead,
|
||||
SnmpDiscoveryRequest,
|
||||
)
|
||||
from app.services.snmp import DiscoveredSnmpDevice, SnmpCredential, SnmpDiscoveryError, discover_snmp_device
|
||||
|
||||
router = APIRouter(prefix="/discovery", tags=["discovery"])
|
||||
|
||||
|
||||
@router.post("/snmp", response_model=SnmpDiscoveryRead)
|
||||
def discover_snmp(
|
||||
payload: SnmpDiscoveryRequest,
|
||||
_: User = Depends(require_role("admin")),
|
||||
db: Session = Depends(get_db),
|
||||
) -> SnmpDiscoveryRead:
|
||||
profile = db.get(Credential, payload.credential_profile_id)
|
||||
if profile is None or profile.credential_type != SNMP_CREDENTIAL_TYPE:
|
||||
raise HTTPException(status_code=404, detail="SNMP credential profile not found")
|
||||
|
||||
extra = dict(profile.extra or {})
|
||||
version = str(extra.get("version") or "2c")
|
||||
if version != "2c":
|
||||
raise HTTPException(status_code=400, detail="SNMP discovery currently supports SNMPv2c profiles")
|
||||
|
||||
community = decrypt_secret(profile.encrypted_secret)
|
||||
if not community:
|
||||
raise HTTPException(status_code=400, detail="SNMP credential profile has no usable community string")
|
||||
|
||||
credential = SnmpCredential(
|
||||
community=community,
|
||||
port=int(extra.get("port") or 161),
|
||||
timeout_seconds=int(extra.get("timeout_seconds") or 5),
|
||||
retries=int(extra.get("retries") or 1),
|
||||
)
|
||||
try:
|
||||
discovered = discover_snmp_device(payload.host, credential)
|
||||
except SnmpDiscoveryError as exc:
|
||||
raise HTTPException(status_code=502, detail="SNMP discovery failed") from exc
|
||||
|
||||
return _discovery_to_read(payload.credential_profile_id, discovered)
|
||||
|
||||
|
||||
def _discovery_to_read(credential_profile_id: int, discovered: DiscoveredSnmpDevice) -> SnmpDiscoveryRead:
|
||||
interfaces = [
|
||||
SnmpDiscoveredInterfaceRead(
|
||||
index=interface.index,
|
||||
name=interface.name,
|
||||
description=interface.description,
|
||||
admin_status=interface.admin_status,
|
||||
oper_status=interface.oper_status,
|
||||
speed_bps=interface.speed_bps,
|
||||
)
|
||||
for interface in discovered.interfaces
|
||||
]
|
||||
return SnmpDiscoveryRead(
|
||||
host=discovered.host,
|
||||
credential_profile_id=credential_profile_id,
|
||||
device_name=discovered.device_name,
|
||||
description=discovered.description,
|
||||
uptime_seconds=discovered.uptime_seconds,
|
||||
interfaces=interfaces,
|
||||
monitorable_items=_monitorable_items(discovered),
|
||||
)
|
||||
|
||||
|
||||
def _monitorable_items(discovered: DiscoveredSnmpDevice) -> list[SnmpDiscoveryItemRead]:
|
||||
items = [
|
||||
SnmpDiscoveryItemRead(
|
||||
item_id="device.uptime",
|
||||
item_type="device_uptime",
|
||||
group="Device Health",
|
||||
label="Device uptime",
|
||||
unit="seconds",
|
||||
)
|
||||
]
|
||||
for interface in discovered.interfaces:
|
||||
group = f"Interface {interface.name}"
|
||||
item_prefix = f"interface.{interface.index}"
|
||||
items.extend(
|
||||
[
|
||||
SnmpDiscoveryItemRead(
|
||||
item_id=f"{item_prefix}.status",
|
||||
item_type="interface_status",
|
||||
group=group,
|
||||
label=f"{interface.name} status",
|
||||
),
|
||||
SnmpDiscoveryItemRead(
|
||||
item_id=f"{item_prefix}.traffic",
|
||||
item_type="interface_traffic",
|
||||
group=group,
|
||||
label=f"{interface.name} traffic",
|
||||
unit="bps",
|
||||
),
|
||||
SnmpDiscoveryItemRead(
|
||||
item_id=f"{item_prefix}.errors",
|
||||
item_type="interface_errors",
|
||||
group=group,
|
||||
label=f"{interface.name} errors and discards",
|
||||
unit="count",
|
||||
),
|
||||
]
|
||||
)
|
||||
return items
|
||||
+2
-1
@@ -6,7 +6,7 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from app.api import alerts, assets, auth, credentials, health, monitors, notifications
|
||||
from app.api import alerts, assets, auth, credentials, discovery, health, monitors, notifications
|
||||
from app.core.config import settings
|
||||
from app.db.session import SessionLocal
|
||||
from app.services.bootstrap import ensure_initial_admin
|
||||
@@ -46,3 +46,4 @@ app.include_router(monitors.router)
|
||||
app.include_router(alerts.router)
|
||||
app.include_router(notifications.router)
|
||||
app.include_router(credentials.router)
|
||||
app.include_router(discovery.router)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -211,3 +210,35 @@ class SnmpCredentialProfileRead(BaseModel):
|
||||
has_secret: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class SnmpDiscoveryRequest(BaseModel):
|
||||
host: str = Field(min_length=1, max_length=255)
|
||||
credential_profile_id: int
|
||||
|
||||
|
||||
class SnmpDiscoveredInterfaceRead(BaseModel):
|
||||
index: int
|
||||
name: str
|
||||
description: str | None
|
||||
admin_status: str | None
|
||||
oper_status: str | None
|
||||
speed_bps: int | None
|
||||
|
||||
|
||||
class SnmpDiscoveryItemRead(BaseModel):
|
||||
item_id: str
|
||||
item_type: str
|
||||
group: str
|
||||
label: str
|
||||
unit: str | None = None
|
||||
|
||||
|
||||
class SnmpDiscoveryRead(BaseModel):
|
||||
host: str
|
||||
credential_profile_id: int
|
||||
device_name: str | None
|
||||
description: str | None
|
||||
uptime_seconds: int | None
|
||||
interfaces: list[SnmpDiscoveredInterfaceRead]
|
||||
monitorable_items: list[SnmpDiscoveryItemRead]
|
||||
|
||||
@@ -0,0 +1,328 @@
|
||||
from dataclasses import dataclass
|
||||
import random
|
||||
import socket
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SnmpDiscoveryError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnmpCredential:
|
||||
community: str
|
||||
port: int = 161
|
||||
timeout_seconds: int = 5
|
||||
retries: int = 1
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DiscoveredSnmpInterface:
|
||||
index: int
|
||||
name: str
|
||||
description: str | None
|
||||
admin_status: str | None
|
||||
oper_status: str | None
|
||||
speed_bps: int | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DiscoveredSnmpDevice:
|
||||
host: str
|
||||
device_name: str | None
|
||||
description: str | None
|
||||
uptime_seconds: int | None
|
||||
interfaces: list[DiscoveredSnmpInterface]
|
||||
|
||||
|
||||
SYS_DESCR = (1, 3, 6, 1, 2, 1, 1, 1, 0)
|
||||
SYS_UPTIME = (1, 3, 6, 1, 2, 1, 1, 3, 0)
|
||||
SYS_NAME = (1, 3, 6, 1, 2, 1, 1, 5, 0)
|
||||
IF_DESCR = (1, 3, 6, 1, 2, 1, 2, 2, 1, 2)
|
||||
IF_SPEED = (1, 3, 6, 1, 2, 1, 2, 2, 1, 5)
|
||||
IF_ADMIN_STATUS = (1, 3, 6, 1, 2, 1, 2, 2, 1, 7)
|
||||
IF_OPER_STATUS = (1, 3, 6, 1, 2, 1, 2, 2, 1, 8)
|
||||
IF_NAME = (1, 3, 6, 1, 2, 1, 31, 1, 1, 1, 1)
|
||||
|
||||
STATUS_LABELS = {
|
||||
1: "up",
|
||||
2: "down",
|
||||
3: "testing",
|
||||
4: "unknown",
|
||||
5: "dormant",
|
||||
6: "not present",
|
||||
7: "lower layer down",
|
||||
}
|
||||
|
||||
|
||||
def discover_snmp_device(host: str, credential: SnmpCredential) -> DiscoveredSnmpDevice:
|
||||
client = SnmpV2Client(host, credential)
|
||||
system = client.get_many([SYS_NAME, SYS_DESCR, SYS_UPTIME])
|
||||
interfaces = _discover_interfaces(client)
|
||||
return DiscoveredSnmpDevice(
|
||||
host=host,
|
||||
device_name=_string_value(system.get(SYS_NAME)),
|
||||
description=_string_value(system.get(SYS_DESCR)),
|
||||
uptime_seconds=_timeticks_to_seconds(system.get(SYS_UPTIME)),
|
||||
interfaces=interfaces,
|
||||
)
|
||||
|
||||
|
||||
def _discover_interfaces(client: "SnmpV2Client") -> list[DiscoveredSnmpInterface]:
|
||||
names = client.walk(IF_NAME)
|
||||
descriptions = client.walk(IF_DESCR)
|
||||
admin_statuses = client.walk(IF_ADMIN_STATUS)
|
||||
oper_statuses = client.walk(IF_OPER_STATUS)
|
||||
speeds = client.walk(IF_SPEED)
|
||||
|
||||
indexes = sorted(
|
||||
{
|
||||
*_indexed_values(names).keys(),
|
||||
*_indexed_values(descriptions).keys(),
|
||||
*_indexed_values(admin_statuses).keys(),
|
||||
*_indexed_values(oper_statuses).keys(),
|
||||
*_indexed_values(speeds).keys(),
|
||||
}
|
||||
)
|
||||
name_by_index = _indexed_values(names)
|
||||
description_by_index = _indexed_values(descriptions)
|
||||
admin_by_index = _indexed_values(admin_statuses)
|
||||
oper_by_index = _indexed_values(oper_statuses)
|
||||
speed_by_index = _indexed_values(speeds)
|
||||
|
||||
interfaces: list[DiscoveredSnmpInterface] = []
|
||||
for index in indexes:
|
||||
name = _string_value(name_by_index.get(index)) or _string_value(description_by_index.get(index)) or f"Interface {index}"
|
||||
interfaces.append(
|
||||
DiscoveredSnmpInterface(
|
||||
index=index,
|
||||
name=name,
|
||||
description=_string_value(description_by_index.get(index)),
|
||||
admin_status=_status_label(admin_by_index.get(index)),
|
||||
oper_status=_status_label(oper_by_index.get(index)),
|
||||
speed_bps=_int_value(speed_by_index.get(index)),
|
||||
)
|
||||
)
|
||||
return interfaces
|
||||
|
||||
|
||||
def _indexed_values(values: dict[tuple[int, ...], Any]) -> dict[int, Any]:
|
||||
indexed: dict[int, Any] = {}
|
||||
for oid, value in values.items():
|
||||
if oid:
|
||||
indexed[oid[-1]] = value
|
||||
return indexed
|
||||
|
||||
|
||||
def _string_value(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8", errors="replace")
|
||||
return str(value)
|
||||
|
||||
|
||||
def _int_value(value: Any) -> int | None:
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _status_label(value: Any) -> str | None:
|
||||
if not isinstance(value, int):
|
||||
return None
|
||||
return STATUS_LABELS.get(value, f"status {value}")
|
||||
|
||||
|
||||
def _timeticks_to_seconds(value: Any) -> int | None:
|
||||
if not isinstance(value, int):
|
||||
return None
|
||||
return int(value / 100)
|
||||
|
||||
|
||||
class SnmpV2Client:
|
||||
def __init__(self, host: str, credential: SnmpCredential) -> None:
|
||||
self.host = host
|
||||
self.credential = credential
|
||||
|
||||
def get_many(self, oids: list[tuple[int, ...]]) -> dict[tuple[int, ...], Any]:
|
||||
return dict(self._request(0xA0, oids))
|
||||
|
||||
def walk(self, base_oid: tuple[int, ...], max_items: int = 128) -> dict[tuple[int, ...], Any]:
|
||||
values: dict[tuple[int, ...], Any] = {}
|
||||
next_oid = base_oid
|
||||
for _ in range(max_items):
|
||||
response = self._request(0xA1, [next_oid])
|
||||
if not response:
|
||||
break
|
||||
returned_oid, value = response[0]
|
||||
if not _oid_starts_with(returned_oid, base_oid):
|
||||
break
|
||||
values[returned_oid] = value
|
||||
next_oid = returned_oid
|
||||
return values
|
||||
|
||||
def _request(self, pdu_tag: int, oids: list[tuple[int, ...]]) -> list[tuple[tuple[int, ...], Any]]:
|
||||
request_id = random.randint(1, 2_147_483_647)
|
||||
packet = _encode_message(pdu_tag, request_id, self.credential.community, oids)
|
||||
last_error: OSError | None = None
|
||||
for _ in range(self.credential.retries + 1):
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.settimeout(self.credential.timeout_seconds)
|
||||
sock.sendto(packet, (self.host, self.credential.port))
|
||||
response, _ = sock.recvfrom(65535)
|
||||
return _decode_response(response, request_id)
|
||||
except OSError as exc:
|
||||
last_error = exc
|
||||
raise SnmpDiscoveryError(f"SNMP request failed for {self.host}") from last_error
|
||||
|
||||
|
||||
def _oid_starts_with(oid: tuple[int, ...], base_oid: tuple[int, ...]) -> bool:
|
||||
return oid[: len(base_oid)] == base_oid
|
||||
|
||||
|
||||
def _encode_message(pdu_tag: int, request_id: int, community: str, oids: list[tuple[int, ...]]) -> bytes:
|
||||
varbinds = b"".join(_sequence(_encode_oid(oid) + _tlv(0x05, b"")) for oid in oids)
|
||||
pdu = _tlv(
|
||||
pdu_tag,
|
||||
_encode_integer(request_id)
|
||||
+ _encode_integer(0)
|
||||
+ _encode_integer(0)
|
||||
+ _sequence(varbinds),
|
||||
)
|
||||
return _sequence(_encode_integer(1) + _tlv(0x04, community.encode("utf-8")) + pdu)
|
||||
|
||||
|
||||
def _sequence(value: bytes) -> bytes:
|
||||
return _tlv(0x30, value)
|
||||
|
||||
|
||||
def _tlv(tag: int, value: bytes) -> bytes:
|
||||
return bytes([tag]) + _encode_length(len(value)) + value
|
||||
|
||||
|
||||
def _encode_length(length: int) -> bytes:
|
||||
if length < 128:
|
||||
return bytes([length])
|
||||
encoded = length.to_bytes((length.bit_length() + 7) // 8, "big")
|
||||
return bytes([0x80 | len(encoded)]) + encoded
|
||||
|
||||
|
||||
def _encode_integer(value: int) -> bytes:
|
||||
if value == 0:
|
||||
return _tlv(0x02, b"\x00")
|
||||
encoded = value.to_bytes((value.bit_length() + 7) // 8, "big")
|
||||
if encoded[0] & 0x80:
|
||||
encoded = b"\x00" + encoded
|
||||
return _tlv(0x02, encoded)
|
||||
|
||||
|
||||
def _encode_oid(oid: tuple[int, ...]) -> bytes:
|
||||
if len(oid) < 2:
|
||||
raise ValueError("OID must have at least two parts")
|
||||
body = bytes([oid[0] * 40 + oid[1]])
|
||||
for part in oid[2:]:
|
||||
body += _encode_base128(part)
|
||||
return _tlv(0x06, body)
|
||||
|
||||
|
||||
def _encode_base128(value: int) -> bytes:
|
||||
chunks = [value & 0x7F]
|
||||
value >>= 7
|
||||
while value:
|
||||
chunks.insert(0, 0x80 | (value & 0x7F))
|
||||
value >>= 7
|
||||
return bytes(chunks)
|
||||
|
||||
|
||||
def _decode_response(data: bytes, expected_request_id: int) -> list[tuple[tuple[int, ...], Any]]:
|
||||
tag, message_value, _ = _read_tlv(data, 0)
|
||||
if tag != 0x30:
|
||||
raise SnmpDiscoveryError("SNMP response was not a sequence")
|
||||
|
||||
offset = 0
|
||||
_, _, offset = _read_tlv(message_value, offset)
|
||||
_, _, offset = _read_tlv(message_value, offset)
|
||||
pdu_tag, pdu_value, offset = _read_tlv(message_value, offset)
|
||||
if pdu_tag != 0xA2:
|
||||
raise SnmpDiscoveryError("SNMP response was not a GetResponse")
|
||||
|
||||
pdu_offset = 0
|
||||
_, request_id_value, pdu_offset = _read_tlv(pdu_value, pdu_offset)
|
||||
if _decode_integer(request_id_value) != expected_request_id:
|
||||
raise SnmpDiscoveryError("SNMP response request id did not match")
|
||||
_, error_status_value, pdu_offset = _read_tlv(pdu_value, pdu_offset)
|
||||
error_status = _decode_integer(error_status_value)
|
||||
_, _, pdu_offset = _read_tlv(pdu_value, pdu_offset)
|
||||
if error_status:
|
||||
raise SnmpDiscoveryError(f"SNMP agent returned error status {error_status}")
|
||||
varbind_list_tag, varbind_list_value, _ = _read_tlv(pdu_value, pdu_offset)
|
||||
if varbind_list_tag != 0x30:
|
||||
raise SnmpDiscoveryError("SNMP response did not include a varbind list")
|
||||
|
||||
responses: list[tuple[tuple[int, ...], Any]] = []
|
||||
varbind_offset = 0
|
||||
while varbind_offset < len(varbind_list_value):
|
||||
varbind_tag, varbind_value, varbind_offset = _read_tlv(varbind_list_value, varbind_offset)
|
||||
if varbind_tag != 0x30:
|
||||
raise SnmpDiscoveryError("SNMP response included an invalid varbind")
|
||||
oid_tag, oid_value, value_offset = _read_tlv(varbind_value, 0)
|
||||
if oid_tag != 0x06:
|
||||
raise SnmpDiscoveryError("SNMP varbind did not include an object identifier")
|
||||
value_tag, value_value, _ = _read_tlv(varbind_value, value_offset)
|
||||
responses.append((_decode_oid(oid_value), _decode_value(value_tag, value_value)))
|
||||
return responses
|
||||
|
||||
|
||||
def _read_tlv(data: bytes, offset: int) -> tuple[int, bytes, int]:
|
||||
if offset >= len(data):
|
||||
raise SnmpDiscoveryError("SNMP response ended unexpectedly")
|
||||
tag = data[offset]
|
||||
length, offset = _read_length(data, offset + 1)
|
||||
end = offset + length
|
||||
if end > len(data):
|
||||
raise SnmpDiscoveryError("SNMP response length exceeded available data")
|
||||
return tag, data[offset:end], end
|
||||
|
||||
|
||||
def _read_length(data: bytes, offset: int) -> tuple[int, int]:
|
||||
first = data[offset]
|
||||
offset += 1
|
||||
if first < 128:
|
||||
return first, offset
|
||||
byte_count = first & 0x7F
|
||||
if byte_count == 0:
|
||||
raise SnmpDiscoveryError("SNMP response used indefinite length")
|
||||
return int.from_bytes(data[offset : offset + byte_count], "big"), offset + byte_count
|
||||
|
||||
|
||||
def _decode_integer(value: bytes) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
return int.from_bytes(value, "big", signed=bool(value[0] & 0x80))
|
||||
|
||||
|
||||
def _decode_oid(value: bytes) -> tuple[int, ...]:
|
||||
if not value:
|
||||
raise SnmpDiscoveryError("SNMP response included an empty object identifier")
|
||||
oid = [value[0] // 40, value[0] % 40]
|
||||
number = 0
|
||||
for byte in value[1:]:
|
||||
number = (number << 7) | (byte & 0x7F)
|
||||
if not byte & 0x80:
|
||||
oid.append(number)
|
||||
number = 0
|
||||
return tuple(oid)
|
||||
|
||||
|
||||
def _decode_value(tag: int, value: bytes) -> Any:
|
||||
if tag in {0x02, 0x41, 0x42, 0x43, 0x46}:
|
||||
return int.from_bytes(value, "big")
|
||||
if tag == 0x04:
|
||||
return value.decode("utf-8", errors="replace")
|
||||
if tag == 0x06:
|
||||
return ".".join(str(part) for part in _decode_oid(value))
|
||||
if tag in {0x05, 0x80, 0x81, 0x82}:
|
||||
return None
|
||||
return value
|
||||
Reference in New Issue
Block a user