Add SNMP device discovery API

This commit is contained in:
Keith Smith
2026-05-23 20:19:38 -06:00
parent 0cbc6b6ea8
commit a38438e7f1
7 changed files with 635 additions and 17 deletions
+113
View File
@@ -0,0 +1,113 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app.api.credentials import SNMP_CREDENTIAL_TYPE
from app.auth.dependencies import require_role
from app.core.secrets import decrypt_secret
from app.db.session import get_db
from app.models import Credential, User
from app.schemas.core import (
SnmpDiscoveredInterfaceRead,
SnmpDiscoveryItemRead,
SnmpDiscoveryRead,
SnmpDiscoveryRequest,
)
from app.services.snmp import DiscoveredSnmpDevice, SnmpCredential, SnmpDiscoveryError, discover_snmp_device
router = APIRouter(prefix="/discovery", tags=["discovery"])
@router.post("/snmp", response_model=SnmpDiscoveryRead)
def discover_snmp(
payload: SnmpDiscoveryRequest,
_: User = Depends(require_role("admin")),
db: Session = Depends(get_db),
) -> SnmpDiscoveryRead:
profile = db.get(Credential, payload.credential_profile_id)
if profile is None or profile.credential_type != SNMP_CREDENTIAL_TYPE:
raise HTTPException(status_code=404, detail="SNMP credential profile not found")
extra = dict(profile.extra or {})
version = str(extra.get("version") or "2c")
if version != "2c":
raise HTTPException(status_code=400, detail="SNMP discovery currently supports SNMPv2c profiles")
community = decrypt_secret(profile.encrypted_secret)
if not community:
raise HTTPException(status_code=400, detail="SNMP credential profile has no usable community string")
credential = SnmpCredential(
community=community,
port=int(extra.get("port") or 161),
timeout_seconds=int(extra.get("timeout_seconds") or 5),
retries=int(extra.get("retries") or 1),
)
try:
discovered = discover_snmp_device(payload.host, credential)
except SnmpDiscoveryError as exc:
raise HTTPException(status_code=502, detail="SNMP discovery failed") from exc
return _discovery_to_read(payload.credential_profile_id, discovered)
def _discovery_to_read(credential_profile_id: int, discovered: DiscoveredSnmpDevice) -> SnmpDiscoveryRead:
interfaces = [
SnmpDiscoveredInterfaceRead(
index=interface.index,
name=interface.name,
description=interface.description,
admin_status=interface.admin_status,
oper_status=interface.oper_status,
speed_bps=interface.speed_bps,
)
for interface in discovered.interfaces
]
return SnmpDiscoveryRead(
host=discovered.host,
credential_profile_id=credential_profile_id,
device_name=discovered.device_name,
description=discovered.description,
uptime_seconds=discovered.uptime_seconds,
interfaces=interfaces,
monitorable_items=_monitorable_items(discovered),
)
def _monitorable_items(discovered: DiscoveredSnmpDevice) -> list[SnmpDiscoveryItemRead]:
items = [
SnmpDiscoveryItemRead(
item_id="device.uptime",
item_type="device_uptime",
group="Device Health",
label="Device uptime",
unit="seconds",
)
]
for interface in discovered.interfaces:
group = f"Interface {interface.name}"
item_prefix = f"interface.{interface.index}"
items.extend(
[
SnmpDiscoveryItemRead(
item_id=f"{item_prefix}.status",
item_type="interface_status",
group=group,
label=f"{interface.name} status",
),
SnmpDiscoveryItemRead(
item_id=f"{item_prefix}.traffic",
item_type="interface_traffic",
group=group,
label=f"{interface.name} traffic",
unit="bps",
),
SnmpDiscoveryItemRead(
item_id=f"{item_prefix}.errors",
item_type="interface_errors",
group=group,
label=f"{interface.name} errors and discards",
unit="count",
),
]
)
return items
+2 -1
View File
@@ -6,7 +6,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.exc import SQLAlchemyError
from app.api import alerts, assets, auth, credentials, health, monitors, notifications
from app.api import alerts, assets, auth, credentials, discovery, health, monitors, notifications
from app.core.config import settings
from app.db.session import SessionLocal
from app.services.bootstrap import ensure_initial_admin
@@ -46,3 +46,4 @@ app.include_router(monitors.router)
app.include_router(alerts.router)
app.include_router(notifications.router)
app.include_router(credentials.router)
app.include_router(discovery.router)
+33 -2
View File
@@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
@@ -211,3 +210,35 @@ class SnmpCredentialProfileRead(BaseModel):
has_secret: bool
created_at: datetime
updated_at: datetime
class SnmpDiscoveryRequest(BaseModel):
host: str = Field(min_length=1, max_length=255)
credential_profile_id: int
class SnmpDiscoveredInterfaceRead(BaseModel):
index: int
name: str
description: str | None
admin_status: str | None
oper_status: str | None
speed_bps: int | None
class SnmpDiscoveryItemRead(BaseModel):
item_id: str
item_type: str
group: str
label: str
unit: str | None = None
class SnmpDiscoveryRead(BaseModel):
host: str
credential_profile_id: int
device_name: str | None
description: str | None
uptime_seconds: int | None
interfaces: list[SnmpDiscoveredInterfaceRead]
monitorable_items: list[SnmpDiscoveryItemRead]
+328
View File
@@ -0,0 +1,328 @@
from dataclasses import dataclass
import random
import socket
from typing import Any
class SnmpDiscoveryError(Exception):
pass
@dataclass(frozen=True)
class SnmpCredential:
community: str
port: int = 161
timeout_seconds: int = 5
retries: int = 1
@dataclass(frozen=True)
class DiscoveredSnmpInterface:
index: int
name: str
description: str | None
admin_status: str | None
oper_status: str | None
speed_bps: int | None
@dataclass(frozen=True)
class DiscoveredSnmpDevice:
host: str
device_name: str | None
description: str | None
uptime_seconds: int | None
interfaces: list[DiscoveredSnmpInterface]
SYS_DESCR = (1, 3, 6, 1, 2, 1, 1, 1, 0)
SYS_UPTIME = (1, 3, 6, 1, 2, 1, 1, 3, 0)
SYS_NAME = (1, 3, 6, 1, 2, 1, 1, 5, 0)
IF_DESCR = (1, 3, 6, 1, 2, 1, 2, 2, 1, 2)
IF_SPEED = (1, 3, 6, 1, 2, 1, 2, 2, 1, 5)
IF_ADMIN_STATUS = (1, 3, 6, 1, 2, 1, 2, 2, 1, 7)
IF_OPER_STATUS = (1, 3, 6, 1, 2, 1, 2, 2, 1, 8)
IF_NAME = (1, 3, 6, 1, 2, 1, 31, 1, 1, 1, 1)
STATUS_LABELS = {
1: "up",
2: "down",
3: "testing",
4: "unknown",
5: "dormant",
6: "not present",
7: "lower layer down",
}
def discover_snmp_device(host: str, credential: SnmpCredential) -> DiscoveredSnmpDevice:
client = SnmpV2Client(host, credential)
system = client.get_many([SYS_NAME, SYS_DESCR, SYS_UPTIME])
interfaces = _discover_interfaces(client)
return DiscoveredSnmpDevice(
host=host,
device_name=_string_value(system.get(SYS_NAME)),
description=_string_value(system.get(SYS_DESCR)),
uptime_seconds=_timeticks_to_seconds(system.get(SYS_UPTIME)),
interfaces=interfaces,
)
def _discover_interfaces(client: "SnmpV2Client") -> list[DiscoveredSnmpInterface]:
names = client.walk(IF_NAME)
descriptions = client.walk(IF_DESCR)
admin_statuses = client.walk(IF_ADMIN_STATUS)
oper_statuses = client.walk(IF_OPER_STATUS)
speeds = client.walk(IF_SPEED)
indexes = sorted(
{
*_indexed_values(names).keys(),
*_indexed_values(descriptions).keys(),
*_indexed_values(admin_statuses).keys(),
*_indexed_values(oper_statuses).keys(),
*_indexed_values(speeds).keys(),
}
)
name_by_index = _indexed_values(names)
description_by_index = _indexed_values(descriptions)
admin_by_index = _indexed_values(admin_statuses)
oper_by_index = _indexed_values(oper_statuses)
speed_by_index = _indexed_values(speeds)
interfaces: list[DiscoveredSnmpInterface] = []
for index in indexes:
name = _string_value(name_by_index.get(index)) or _string_value(description_by_index.get(index)) or f"Interface {index}"
interfaces.append(
DiscoveredSnmpInterface(
index=index,
name=name,
description=_string_value(description_by_index.get(index)),
admin_status=_status_label(admin_by_index.get(index)),
oper_status=_status_label(oper_by_index.get(index)),
speed_bps=_int_value(speed_by_index.get(index)),
)
)
return interfaces
def _indexed_values(values: dict[tuple[int, ...], Any]) -> dict[int, Any]:
indexed: dict[int, Any] = {}
for oid, value in values.items():
if oid:
indexed[oid[-1]] = value
return indexed
def _string_value(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def _int_value(value: Any) -> int | None:
if isinstance(value, int):
return value
return None
def _status_label(value: Any) -> str | None:
if not isinstance(value, int):
return None
return STATUS_LABELS.get(value, f"status {value}")
def _timeticks_to_seconds(value: Any) -> int | None:
if not isinstance(value, int):
return None
return int(value / 100)
class SnmpV2Client:
def __init__(self, host: str, credential: SnmpCredential) -> None:
self.host = host
self.credential = credential
def get_many(self, oids: list[tuple[int, ...]]) -> dict[tuple[int, ...], Any]:
return dict(self._request(0xA0, oids))
def walk(self, base_oid: tuple[int, ...], max_items: int = 128) -> dict[tuple[int, ...], Any]:
values: dict[tuple[int, ...], Any] = {}
next_oid = base_oid
for _ in range(max_items):
response = self._request(0xA1, [next_oid])
if not response:
break
returned_oid, value = response[0]
if not _oid_starts_with(returned_oid, base_oid):
break
values[returned_oid] = value
next_oid = returned_oid
return values
def _request(self, pdu_tag: int, oids: list[tuple[int, ...]]) -> list[tuple[tuple[int, ...], Any]]:
request_id = random.randint(1, 2_147_483_647)
packet = _encode_message(pdu_tag, request_id, self.credential.community, oids)
last_error: OSError | None = None
for _ in range(self.credential.retries + 1):
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.settimeout(self.credential.timeout_seconds)
sock.sendto(packet, (self.host, self.credential.port))
response, _ = sock.recvfrom(65535)
return _decode_response(response, request_id)
except OSError as exc:
last_error = exc
raise SnmpDiscoveryError(f"SNMP request failed for {self.host}") from last_error
def _oid_starts_with(oid: tuple[int, ...], base_oid: tuple[int, ...]) -> bool:
return oid[: len(base_oid)] == base_oid
def _encode_message(pdu_tag: int, request_id: int, community: str, oids: list[tuple[int, ...]]) -> bytes:
varbinds = b"".join(_sequence(_encode_oid(oid) + _tlv(0x05, b"")) for oid in oids)
pdu = _tlv(
pdu_tag,
_encode_integer(request_id)
+ _encode_integer(0)
+ _encode_integer(0)
+ _sequence(varbinds),
)
return _sequence(_encode_integer(1) + _tlv(0x04, community.encode("utf-8")) + pdu)
def _sequence(value: bytes) -> bytes:
return _tlv(0x30, value)
def _tlv(tag: int, value: bytes) -> bytes:
return bytes([tag]) + _encode_length(len(value)) + value
def _encode_length(length: int) -> bytes:
if length < 128:
return bytes([length])
encoded = length.to_bytes((length.bit_length() + 7) // 8, "big")
return bytes([0x80 | len(encoded)]) + encoded
def _encode_integer(value: int) -> bytes:
if value == 0:
return _tlv(0x02, b"\x00")
encoded = value.to_bytes((value.bit_length() + 7) // 8, "big")
if encoded[0] & 0x80:
encoded = b"\x00" + encoded
return _tlv(0x02, encoded)
def _encode_oid(oid: tuple[int, ...]) -> bytes:
if len(oid) < 2:
raise ValueError("OID must have at least two parts")
body = bytes([oid[0] * 40 + oid[1]])
for part in oid[2:]:
body += _encode_base128(part)
return _tlv(0x06, body)
def _encode_base128(value: int) -> bytes:
chunks = [value & 0x7F]
value >>= 7
while value:
chunks.insert(0, 0x80 | (value & 0x7F))
value >>= 7
return bytes(chunks)
def _decode_response(data: bytes, expected_request_id: int) -> list[tuple[tuple[int, ...], Any]]:
tag, message_value, _ = _read_tlv(data, 0)
if tag != 0x30:
raise SnmpDiscoveryError("SNMP response was not a sequence")
offset = 0
_, _, offset = _read_tlv(message_value, offset)
_, _, offset = _read_tlv(message_value, offset)
pdu_tag, pdu_value, offset = _read_tlv(message_value, offset)
if pdu_tag != 0xA2:
raise SnmpDiscoveryError("SNMP response was not a GetResponse")
pdu_offset = 0
_, request_id_value, pdu_offset = _read_tlv(pdu_value, pdu_offset)
if _decode_integer(request_id_value) != expected_request_id:
raise SnmpDiscoveryError("SNMP response request id did not match")
_, error_status_value, pdu_offset = _read_tlv(pdu_value, pdu_offset)
error_status = _decode_integer(error_status_value)
_, _, pdu_offset = _read_tlv(pdu_value, pdu_offset)
if error_status:
raise SnmpDiscoveryError(f"SNMP agent returned error status {error_status}")
varbind_list_tag, varbind_list_value, _ = _read_tlv(pdu_value, pdu_offset)
if varbind_list_tag != 0x30:
raise SnmpDiscoveryError("SNMP response did not include a varbind list")
responses: list[tuple[tuple[int, ...], Any]] = []
varbind_offset = 0
while varbind_offset < len(varbind_list_value):
varbind_tag, varbind_value, varbind_offset = _read_tlv(varbind_list_value, varbind_offset)
if varbind_tag != 0x30:
raise SnmpDiscoveryError("SNMP response included an invalid varbind")
oid_tag, oid_value, value_offset = _read_tlv(varbind_value, 0)
if oid_tag != 0x06:
raise SnmpDiscoveryError("SNMP varbind did not include an object identifier")
value_tag, value_value, _ = _read_tlv(varbind_value, value_offset)
responses.append((_decode_oid(oid_value), _decode_value(value_tag, value_value)))
return responses
def _read_tlv(data: bytes, offset: int) -> tuple[int, bytes, int]:
if offset >= len(data):
raise SnmpDiscoveryError("SNMP response ended unexpectedly")
tag = data[offset]
length, offset = _read_length(data, offset + 1)
end = offset + length
if end > len(data):
raise SnmpDiscoveryError("SNMP response length exceeded available data")
return tag, data[offset:end], end
def _read_length(data: bytes, offset: int) -> tuple[int, int]:
first = data[offset]
offset += 1
if first < 128:
return first, offset
byte_count = first & 0x7F
if byte_count == 0:
raise SnmpDiscoveryError("SNMP response used indefinite length")
return int.from_bytes(data[offset : offset + byte_count], "big"), offset + byte_count
def _decode_integer(value: bytes) -> int:
if not value:
return 0
return int.from_bytes(value, "big", signed=bool(value[0] & 0x80))
def _decode_oid(value: bytes) -> tuple[int, ...]:
if not value:
raise SnmpDiscoveryError("SNMP response included an empty object identifier")
oid = [value[0] // 40, value[0] % 40]
number = 0
for byte in value[1:]:
number = (number << 7) | (byte & 0x7F)
if not byte & 0x80:
oid.append(number)
number = 0
return tuple(oid)
def _decode_value(tag: int, value: bytes) -> Any:
if tag in {0x02, 0x41, 0x42, 0x43, 0x46}:
return int.from_bytes(value, "big")
if tag == 0x04:
return value.decode("utf-8", errors="replace")
if tag == 0x06:
return ".".join(str(part) for part in _decode_oid(value))
if tag in {0x05, 0x80, 0x81, 0x82}:
return None
return value