Add SNMP device discovery API

This commit is contained in:
Keith Smith
2026-05-23 20:19:38 -06:00
parent 0cbc6b6ea8
commit a38438e7f1
7 changed files with 635 additions and 17 deletions
+328
View File
@@ -0,0 +1,328 @@
from dataclasses import dataclass
import random
import socket
from typing import Any
class SnmpDiscoveryError(Exception):
pass
@dataclass(frozen=True)
class SnmpCredential:
community: str
port: int = 161
timeout_seconds: int = 5
retries: int = 1
@dataclass(frozen=True)
class DiscoveredSnmpInterface:
index: int
name: str
description: str | None
admin_status: str | None
oper_status: str | None
speed_bps: int | None
@dataclass(frozen=True)
class DiscoveredSnmpDevice:
host: str
device_name: str | None
description: str | None
uptime_seconds: int | None
interfaces: list[DiscoveredSnmpInterface]
SYS_DESCR = (1, 3, 6, 1, 2, 1, 1, 1, 0)
SYS_UPTIME = (1, 3, 6, 1, 2, 1, 1, 3, 0)
SYS_NAME = (1, 3, 6, 1, 2, 1, 1, 5, 0)
IF_DESCR = (1, 3, 6, 1, 2, 1, 2, 2, 1, 2)
IF_SPEED = (1, 3, 6, 1, 2, 1, 2, 2, 1, 5)
IF_ADMIN_STATUS = (1, 3, 6, 1, 2, 1, 2, 2, 1, 7)
IF_OPER_STATUS = (1, 3, 6, 1, 2, 1, 2, 2, 1, 8)
IF_NAME = (1, 3, 6, 1, 2, 1, 31, 1, 1, 1, 1)
STATUS_LABELS = {
1: "up",
2: "down",
3: "testing",
4: "unknown",
5: "dormant",
6: "not present",
7: "lower layer down",
}
def discover_snmp_device(host: str, credential: SnmpCredential) -> DiscoveredSnmpDevice:
client = SnmpV2Client(host, credential)
system = client.get_many([SYS_NAME, SYS_DESCR, SYS_UPTIME])
interfaces = _discover_interfaces(client)
return DiscoveredSnmpDevice(
host=host,
device_name=_string_value(system.get(SYS_NAME)),
description=_string_value(system.get(SYS_DESCR)),
uptime_seconds=_timeticks_to_seconds(system.get(SYS_UPTIME)),
interfaces=interfaces,
)
def _discover_interfaces(client: "SnmpV2Client") -> list[DiscoveredSnmpInterface]:
names = client.walk(IF_NAME)
descriptions = client.walk(IF_DESCR)
admin_statuses = client.walk(IF_ADMIN_STATUS)
oper_statuses = client.walk(IF_OPER_STATUS)
speeds = client.walk(IF_SPEED)
indexes = sorted(
{
*_indexed_values(names).keys(),
*_indexed_values(descriptions).keys(),
*_indexed_values(admin_statuses).keys(),
*_indexed_values(oper_statuses).keys(),
*_indexed_values(speeds).keys(),
}
)
name_by_index = _indexed_values(names)
description_by_index = _indexed_values(descriptions)
admin_by_index = _indexed_values(admin_statuses)
oper_by_index = _indexed_values(oper_statuses)
speed_by_index = _indexed_values(speeds)
interfaces: list[DiscoveredSnmpInterface] = []
for index in indexes:
name = _string_value(name_by_index.get(index)) or _string_value(description_by_index.get(index)) or f"Interface {index}"
interfaces.append(
DiscoveredSnmpInterface(
index=index,
name=name,
description=_string_value(description_by_index.get(index)),
admin_status=_status_label(admin_by_index.get(index)),
oper_status=_status_label(oper_by_index.get(index)),
speed_bps=_int_value(speed_by_index.get(index)),
)
)
return interfaces
def _indexed_values(values: dict[tuple[int, ...], Any]) -> dict[int, Any]:
indexed: dict[int, Any] = {}
for oid, value in values.items():
if oid:
indexed[oid[-1]] = value
return indexed
def _string_value(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, bytes):
return value.decode("utf-8", errors="replace")
return str(value)
def _int_value(value: Any) -> int | None:
if isinstance(value, int):
return value
return None
def _status_label(value: Any) -> str | None:
if not isinstance(value, int):
return None
return STATUS_LABELS.get(value, f"status {value}")
def _timeticks_to_seconds(value: Any) -> int | None:
if not isinstance(value, int):
return None
return int(value / 100)
class SnmpV2Client:
def __init__(self, host: str, credential: SnmpCredential) -> None:
self.host = host
self.credential = credential
def get_many(self, oids: list[tuple[int, ...]]) -> dict[tuple[int, ...], Any]:
return dict(self._request(0xA0, oids))
def walk(self, base_oid: tuple[int, ...], max_items: int = 128) -> dict[tuple[int, ...], Any]:
values: dict[tuple[int, ...], Any] = {}
next_oid = base_oid
for _ in range(max_items):
response = self._request(0xA1, [next_oid])
if not response:
break
returned_oid, value = response[0]
if not _oid_starts_with(returned_oid, base_oid):
break
values[returned_oid] = value
next_oid = returned_oid
return values
def _request(self, pdu_tag: int, oids: list[tuple[int, ...]]) -> list[tuple[tuple[int, ...], Any]]:
request_id = random.randint(1, 2_147_483_647)
packet = _encode_message(pdu_tag, request_id, self.credential.community, oids)
last_error: OSError | None = None
for _ in range(self.credential.retries + 1):
try:
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.settimeout(self.credential.timeout_seconds)
sock.sendto(packet, (self.host, self.credential.port))
response, _ = sock.recvfrom(65535)
return _decode_response(response, request_id)
except OSError as exc:
last_error = exc
raise SnmpDiscoveryError(f"SNMP request failed for {self.host}") from last_error
def _oid_starts_with(oid: tuple[int, ...], base_oid: tuple[int, ...]) -> bool:
return oid[: len(base_oid)] == base_oid
def _encode_message(pdu_tag: int, request_id: int, community: str, oids: list[tuple[int, ...]]) -> bytes:
varbinds = b"".join(_sequence(_encode_oid(oid) + _tlv(0x05, b"")) for oid in oids)
pdu = _tlv(
pdu_tag,
_encode_integer(request_id)
+ _encode_integer(0)
+ _encode_integer(0)
+ _sequence(varbinds),
)
return _sequence(_encode_integer(1) + _tlv(0x04, community.encode("utf-8")) + pdu)
def _sequence(value: bytes) -> bytes:
return _tlv(0x30, value)
def _tlv(tag: int, value: bytes) -> bytes:
return bytes([tag]) + _encode_length(len(value)) + value
def _encode_length(length: int) -> bytes:
if length < 128:
return bytes([length])
encoded = length.to_bytes((length.bit_length() + 7) // 8, "big")
return bytes([0x80 | len(encoded)]) + encoded
def _encode_integer(value: int) -> bytes:
if value == 0:
return _tlv(0x02, b"\x00")
encoded = value.to_bytes((value.bit_length() + 7) // 8, "big")
if encoded[0] & 0x80:
encoded = b"\x00" + encoded
return _tlv(0x02, encoded)
def _encode_oid(oid: tuple[int, ...]) -> bytes:
if len(oid) < 2:
raise ValueError("OID must have at least two parts")
body = bytes([oid[0] * 40 + oid[1]])
for part in oid[2:]:
body += _encode_base128(part)
return _tlv(0x06, body)
def _encode_base128(value: int) -> bytes:
chunks = [value & 0x7F]
value >>= 7
while value:
chunks.insert(0, 0x80 | (value & 0x7F))
value >>= 7
return bytes(chunks)
def _decode_response(data: bytes, expected_request_id: int) -> list[tuple[tuple[int, ...], Any]]:
tag, message_value, _ = _read_tlv(data, 0)
if tag != 0x30:
raise SnmpDiscoveryError("SNMP response was not a sequence")
offset = 0
_, _, offset = _read_tlv(message_value, offset)
_, _, offset = _read_tlv(message_value, offset)
pdu_tag, pdu_value, offset = _read_tlv(message_value, offset)
if pdu_tag != 0xA2:
raise SnmpDiscoveryError("SNMP response was not a GetResponse")
pdu_offset = 0
_, request_id_value, pdu_offset = _read_tlv(pdu_value, pdu_offset)
if _decode_integer(request_id_value) != expected_request_id:
raise SnmpDiscoveryError("SNMP response request id did not match")
_, error_status_value, pdu_offset = _read_tlv(pdu_value, pdu_offset)
error_status = _decode_integer(error_status_value)
_, _, pdu_offset = _read_tlv(pdu_value, pdu_offset)
if error_status:
raise SnmpDiscoveryError(f"SNMP agent returned error status {error_status}")
varbind_list_tag, varbind_list_value, _ = _read_tlv(pdu_value, pdu_offset)
if varbind_list_tag != 0x30:
raise SnmpDiscoveryError("SNMP response did not include a varbind list")
responses: list[tuple[tuple[int, ...], Any]] = []
varbind_offset = 0
while varbind_offset < len(varbind_list_value):
varbind_tag, varbind_value, varbind_offset = _read_tlv(varbind_list_value, varbind_offset)
if varbind_tag != 0x30:
raise SnmpDiscoveryError("SNMP response included an invalid varbind")
oid_tag, oid_value, value_offset = _read_tlv(varbind_value, 0)
if oid_tag != 0x06:
raise SnmpDiscoveryError("SNMP varbind did not include an object identifier")
value_tag, value_value, _ = _read_tlv(varbind_value, value_offset)
responses.append((_decode_oid(oid_value), _decode_value(value_tag, value_value)))
return responses
def _read_tlv(data: bytes, offset: int) -> tuple[int, bytes, int]:
if offset >= len(data):
raise SnmpDiscoveryError("SNMP response ended unexpectedly")
tag = data[offset]
length, offset = _read_length(data, offset + 1)
end = offset + length
if end > len(data):
raise SnmpDiscoveryError("SNMP response length exceeded available data")
return tag, data[offset:end], end
def _read_length(data: bytes, offset: int) -> tuple[int, int]:
first = data[offset]
offset += 1
if first < 128:
return first, offset
byte_count = first & 0x7F
if byte_count == 0:
raise SnmpDiscoveryError("SNMP response used indefinite length")
return int.from_bytes(data[offset : offset + byte_count], "big"), offset + byte_count
def _decode_integer(value: bytes) -> int:
if not value:
return 0
return int.from_bytes(value, "big", signed=bool(value[0] & 0x80))
def _decode_oid(value: bytes) -> tuple[int, ...]:
if not value:
raise SnmpDiscoveryError("SNMP response included an empty object identifier")
oid = [value[0] // 40, value[0] % 40]
number = 0
for byte in value[1:]:
number = (number << 7) | (byte & 0x7F)
if not byte & 0x80:
oid.append(number)
number = 0
return tuple(oid)
def _decode_value(tag: int, value: bytes) -> Any:
if tag in {0x02, 0x41, 0x42, 0x43, 0x46}:
return int.from_bytes(value, "big")
if tag == 0x04:
return value.decode("utf-8", errors="replace")
if tag == 0x06:
return ".".join(str(part) for part in _decode_oid(value))
if tag in {0x05, 0x80, 0x81, 0x82}:
return None
return value