Add SNMP monitor collection
This commit is contained in:
@@ -0,0 +1,354 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import random
|
||||
import socket
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
|
||||
class SnmpCheckError(Exception):
    """Raised when an SNMP request fails or its response cannot be decoded."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SnmpMetricValue:
|
||||
name: str
|
||||
value: float
|
||||
unit: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SnmpCheckConfig:
    """Everything needed to run a single SNMP check against one device."""

    # Address the UDP request is sent to.
    host: str
    # Community string placed in the request.
    community: str
    # Item identifier; interface items use the form "interface.<index>.<suffix>".
    item_id: str
    # Selects what is collected: "device_uptime", "interface_status",
    # "interface_traffic" or "interface_errors".
    item_type: str
    # UDP port of the SNMP agent.
    port: int = 161
    # Socket timeout applied to each attempt.
    timeout_seconds: float = 5.0
    # Additional attempts after the first one fails.
    retries: int = 1
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SnmpCheckResult:
    """Outcome of one SNMP check: a status, a timing, a message and metrics."""

    # "up" or "down".
    status: str
    # Elapsed time of the check in milliseconds, or None when unavailable.
    response_time_ms: int | None
    # Human-readable summary of what was observed.
    message: str
    # Samples collected by the check; empty when nothing was collected.
    metrics: list[SnmpMetricValue] = field(default_factory=list)
|
||||
|
||||
|
||||
# Object identifiers used by the checks, written as integer tuples.
# The interface columns are built from their shared table-entry prefixes; the
# interface index is appended at query time.
_MIB2 = (1, 3, 6, 1, 2, 1)
_IF_ENTRY = (*_MIB2, 2, 2, 1)  # ifTable row
_IF_X_ENTRY = (*_MIB2, 31, 1, 1, 1)  # ifXTable row (64-bit counters)

SYS_UPTIME = (*_MIB2, 1, 3, 0)

IF_ADMIN_STATUS = (*_IF_ENTRY, 7)
IF_OPER_STATUS = (*_IF_ENTRY, 8)
IF_IN_OCTETS = (*_IF_ENTRY, 10)
IF_IN_DISCARDS = (*_IF_ENTRY, 13)
IF_IN_ERRORS = (*_IF_ENTRY, 14)
IF_OUT_OCTETS = (*_IF_ENTRY, 16)
IF_OUT_DISCARDS = (*_IF_ENTRY, 19)
IF_OUT_ERRORS = (*_IF_ENTRY, 20)

IF_HC_IN_OCTETS = (*_IF_X_ENTRY, 6)
IF_HC_OUT_OCTETS = (*_IF_X_ENTRY, 10)
|
||||
|
||||
# Display labels for interface admin/operational status codes, numbered from 1.
STATUS_LABELS = dict(
    enumerate(
        (
            "up",
            "down",
            "testing",
            "unknown",
            "dormant",
            "not present",
            "lower layer down",
        ),
        start=1,
    )
)
|
||||
|
||||
|
||||
async def run_snmp_check(config: SnmpCheckConfig) -> SnmpCheckResult:
    """Run the blocking SNMP check in a worker thread.

    Socket errors and ``SnmpCheckError`` are converted into a "down" result
    whose message carries the error text; other exceptions propagate.
    """
    try:
        outcome = await asyncio.to_thread(_run_snmp_check_sync, config)
    except (OSError, SnmpCheckError) as exc:
        failure = f"SNMP check failed: {exc}"
        return SnmpCheckResult(status="down", response_time_ms=None, message=failure)
    return outcome
|
||||
|
||||
|
||||
def _run_snmp_check_sync(config: SnmpCheckConfig) -> SnmpCheckResult:
    """Run one blocking SNMP check and translate the reply into a result.

    Dispatches on ``config.item_type``:

    * ``device_uptime`` reads the uptime object and needs no interface index.
    * ``interface_status``, ``interface_traffic`` and ``interface_errors``
      read per-interface columns; the index comes from ``config.item_id``.

    Any other item type, or an interface item whose ``item_id`` cannot be
    parsed, yields a "down" result without a request being sent.

    Raises:
        SnmpCheckError: Propagated from the client when the request fails or
            the response cannot be decoded.
    """
    started = perf_counter()
    client = SnmpV2Client(config.host, config.community, config.port, config.timeout_seconds, config.retries)

    if config.item_type == "device_uptime":
        value = _int_value(client.get_many([SYS_UPTIME]).get(SYS_UPTIME))
        response_time_ms = int((perf_counter() - started) * 1000)
        if value is None:
            return SnmpCheckResult(status="down", response_time_ms=response_time_ms, message="Device uptime was not reported")
        # The reported value is in hundredths of a second.
        uptime_seconds = int(value / 100)
        return SnmpCheckResult(
            status="up",
            response_time_ms=response_time_ms,
            message=f"Device uptime is {uptime_seconds} seconds",
            metrics=[SnmpMetricValue(name="uptime_seconds", value=float(uptime_seconds), unit="seconds")],
        )

    interface_index = _interface_index(config.item_id)
    if interface_index is None:
        return SnmpCheckResult(status="down", response_time_ms=0, message="SNMP interface item was not valid")

    if config.item_type == "interface_status":
        oids = [_with_index(IF_ADMIN_STATUS, interface_index), _with_index(IF_OPER_STATUS, interface_index)]
        values = client.get_many(oids)
        response_time_ms = int((perf_counter() - started) * 1000)
        admin_value = _int_value(values.get(oids[0]))
        oper_value = _int_value(values.get(oids[1]))
        if admin_value is None or oper_value is None:
            return SnmpCheckResult(status="down", response_time_ms=response_time_ms, message="Interface status was not reported")
        admin_status = STATUS_LABELS.get(admin_value, f"status {admin_value}")
        oper_status = STATUS_LABELS.get(oper_value, f"status {oper_value}")
        # Only "admin up" together with "operational up" counts as up.
        status = "up" if admin_value == 1 and oper_value == 1 else "down"
        return SnmpCheckResult(
            status=status,
            response_time_ms=response_time_ms,
            message=f"Interface admin {admin_status}, operational {oper_status}",
            metrics=[
                SnmpMetricValue(name="admin_status", value=float(admin_value)),
                SnmpMetricValue(name="oper_status", value=float(oper_value)),
            ],
        )

    if config.item_type == "interface_traffic":
        oids = [
            _with_index(IF_HC_IN_OCTETS, interface_index),
            _with_index(IF_HC_OUT_OCTETS, interface_index),
            _with_index(IF_IN_OCTETS, interface_index),
            _with_index(IF_OUT_OCTETS, interface_index),
        ]
        values = client.get_many(oids)
        response_time_ms = int((perf_counter() - started) * 1000)
        # Prefer the 64-bit counters and fall back to the 32-bit ones only
        # when the 64-bit value is missing.  The comparison is against None
        # rather than truthiness: a counter that legitimately reads 0 is
        # falsy and must not be treated as "not reported".
        in_octets = _int_value(values.get(oids[0]))
        if in_octets is None:
            in_octets = _int_value(values.get(oids[2]))
        out_octets = _int_value(values.get(oids[1]))
        if out_octets is None:
            out_octets = _int_value(values.get(oids[3]))
        if in_octets is None and out_octets is None:
            return SnmpCheckResult(status="down", response_time_ms=response_time_ms, message="Interface traffic counters were not reported")
        metrics = []
        if in_octets is not None:
            metrics.append(SnmpMetricValue(name="in_octets", value=float(in_octets), unit="bytes"))
        if out_octets is not None:
            metrics.append(SnmpMetricValue(name="out_octets", value=float(out_octets), unit="bytes"))
        return SnmpCheckResult(
            status="up",
            response_time_ms=response_time_ms,
            message="Interface traffic counters collected",
            metrics=metrics,
        )

    if config.item_type == "interface_errors":
        oids = [
            _with_index(IF_IN_ERRORS, interface_index),
            _with_index(IF_OUT_ERRORS, interface_index),
            _with_index(IF_IN_DISCARDS, interface_index),
            _with_index(IF_OUT_DISCARDS, interface_index),
        ]
        values = client.get_many(oids)
        response_time_ms = int((perf_counter() - started) * 1000)
        metric_values = [
            ("in_errors", _int_value(values.get(oids[0])), "count"),
            ("out_errors", _int_value(values.get(oids[1])), "count"),
            ("in_discards", _int_value(values.get(oids[2])), "count"),
            ("out_discards", _int_value(values.get(oids[3])), "count"),
        ]
        # Counters the agent did not report are simply left out.
        metrics = [SnmpMetricValue(name=name, value=float(value), unit=unit) for name, value, unit in metric_values if value is not None]
        if not metrics:
            return SnmpCheckResult(status="down", response_time_ms=response_time_ms, message="Interface error counters were not reported")
        return SnmpCheckResult(
            status="up",
            response_time_ms=response_time_ms,
            message="Interface error and discard counters collected",
            metrics=metrics,
        )

    return SnmpCheckResult(status="down", response_time_ms=0, message=f"Unsupported SNMP item type: {config.item_type}")
|
||||
|
||||
|
||||
def _interface_index(item_id: str) -> int | None:
|
||||
parts = item_id.split(".")
|
||||
if len(parts) < 3 or parts[0] != "interface":
|
||||
return None
|
||||
try:
|
||||
return int(parts[1])
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _with_index(oid: tuple[int, ...], index: int) -> tuple[int, ...]:
|
||||
return (*oid, index)
|
||||
|
||||
|
||||
def _int_value(value: Any) -> int | None:
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class SnmpV2Client:
    """Small blocking SNMP client that sends one request per call over UDP."""

    def __init__(self, host: str, community: str, port: int, timeout_seconds: float, retries: int) -> None:
        """Store the connection parameters; no socket is opened here."""
        self.host, self.port = host, port
        self.community = community
        self.timeout_seconds = timeout_seconds
        self.retries = retries

    def get_many(self, oids: list[tuple[int, ...]]) -> dict[tuple[int, ...], Any]:
        """Fetch *oids* in a single GET request and map each returned OID to its value."""
        pairs = self._request(0xA0, oids)
        return dict(pairs)

    def _request(self, pdu_tag: int, oids: list[tuple[int, ...]]) -> list[tuple[tuple[int, ...], Any]]:
        """Send one request and return the decoded ``(oid, value)`` pairs.

        The same packet is sent up to ``retries + 1`` times, each attempt on
        a fresh socket with ``timeout_seconds`` as its timeout.

        Raises:
            SnmpCheckError: When every attempt fails with an ``OSError``
                (chained to the last one), or when the response cannot be
                decoded.
        """
        request_id = random.randint(1, 2_147_483_647)
        packet = _encode_message(pdu_tag, request_id, self.community, oids)
        destination = (self.host, self.port)
        last_error: OSError | None = None
        attempts = self.retries + 1
        for _attempt in range(attempts):
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                    sock.settimeout(self.timeout_seconds)
                    sock.sendto(packet, destination)
                    response, _sender = sock.recvfrom(65535)
                    return _decode_response(response, request_id)
            except OSError as exc:
                # Remember the failure and try again if attempts remain.
                last_error = exc
        raise SnmpCheckError(f"SNMP request failed for {self.host}") from last_error
|
||||
|
||||
|
||||
def _encode_message(pdu_tag: int, request_id: int, community: str, oids: list[tuple[int, ...]]) -> bytes:
    """Build a complete BER-encoded SNMP request message.

    The message is a sequence of the version integer 1, the community string
    and a PDU tagged *pdu_tag*.  The PDU holds the request id, two zero
    integers and one varbind per OID, each with a NULL placeholder value.
    """
    null_value = _tlv(0x05, b"")
    varbind_list = _sequence(b"".join(_sequence(_encode_oid(oid) + null_value) for oid in oids))
    pdu_fields = (
        _encode_integer(request_id),
        _encode_integer(0),
        _encode_integer(0),
        varbind_list,
    )
    pdu = _tlv(pdu_tag, b"".join(pdu_fields))
    header = _encode_integer(1) + _tlv(0x04, community.encode("utf-8"))
    return _sequence(header + pdu)
|
||||
|
||||
|
||||
def _sequence(value: bytes) -> bytes:
    """Wrap *value* in a SEQUENCE (tag 0x30) element."""
    sequence_tag = 0x30
    return _tlv(sequence_tag, value)
|
||||
|
||||
|
||||
def _tlv(tag: int, value: bytes) -> bytes:
    """Encode one element as its tag byte, its length and then *value*."""
    header = bytes((tag,)) + _encode_length(len(value))
    return header + value
|
||||
|
||||
|
||||
def _encode_length(length: int) -> bytes:
|
||||
if length < 128:
|
||||
return bytes([length])
|
||||
encoded = length.to_bytes((length.bit_length() + 7) // 8, "big")
|
||||
return bytes([0x80 | len(encoded)]) + encoded
|
||||
|
||||
|
||||
def _encode_integer(value: int) -> bytes:
    """Encode *value* as an INTEGER (tag 0x02) element.

    The content is the shortest two's-complement big-endian form, so
    negative values are supported as well as zero and positive ones.
    Non-negative values encode exactly as before: a leading zero byte is
    present whenever the top bit of the magnitude would otherwise be set.
    """
    # Bits needed for the magnitude; one more bit is needed for the sign,
    # which is what the "+ 1" byte rounding below accounts for.
    magnitude_bits = (~value if value < 0 else value).bit_length()
    size = magnitude_bits // 8 + 1
    return _tlv(0x02, value.to_bytes(size, "big", signed=True))
|
||||
|
||||
|
||||
def _encode_oid(oid: tuple[int, ...]) -> bytes:
    """Encode *oid* as an OBJECT IDENTIFIER (tag 0x06) element.

    The first two components are combined into one sub-identifier
    (``first * 40 + second``); every sub-identifier, including that combined
    one, is written in base-128 so values of 128 and above are encoded
    correctly instead of being forced into a single raw byte.

    Raises:
        ValueError: If *oid* has fewer than two components or either of the
            first two components is negative.
    """
    if len(oid) < 2:
        raise ValueError("OID must have at least two parts")
    first, second = oid[0], oid[1]
    if first < 0 or second < 0:
        raise ValueError("OID parts must not be negative")
    body = _encode_base128(first * 40 + second)
    body += b"".join(_encode_base128(part) for part in oid[2:])
    return _tlv(0x06, body)
|
||||
|
||||
|
||||
def _encode_base128(value: int) -> bytes:
|
||||
chunks = [value & 0x7F]
|
||||
value >>= 7
|
||||
while value:
|
||||
chunks.insert(0, 0x80 | (value & 0x7F))
|
||||
value >>= 7
|
||||
return bytes(chunks)
|
||||
|
||||
|
||||
def _decode_response(data: bytes, expected_request_id: int) -> list[tuple[tuple[int, ...], Any]]:
    """Parse an encoded SNMP response into ``(oid, value)`` pairs.

    Raises:
        SnmpCheckError: If the message is not a sequence, the PDU is not
            tagged 0xA2, the request id differs from *expected_request_id*,
            the agent reported a non-zero error status, or the varbind list
            is malformed or truncated.
    """
    outer_tag, message, _ = _read_tlv(data, 0)
    if outer_tag != 0x30:
        raise SnmpCheckError("SNMP response was not a sequence")

    # The first two message fields are stepped over without being inspected;
    # the third one is the PDU.
    message_pos = 0
    for _skipped in range(2):
        _, _, message_pos = _read_tlv(message, message_pos)
    pdu_tag, pdu, _ = _read_tlv(message, message_pos)
    if pdu_tag != 0xA2:
        raise SnmpCheckError("SNMP response was not a GetResponse")

    _, raw_request_id, pdu_pos = _read_tlv(pdu, 0)
    if _decode_integer(raw_request_id) != expected_request_id:
        raise SnmpCheckError("SNMP response request id did not match")

    _, raw_error_status, pdu_pos = _read_tlv(pdu, pdu_pos)
    error_status = _decode_integer(raw_error_status)
    # The field after the error status is read only to move past it.
    _, _, pdu_pos = _read_tlv(pdu, pdu_pos)
    if error_status:
        raise SnmpCheckError(f"SNMP agent returned error status {error_status}")

    list_tag, varbind_list, _ = _read_tlv(pdu, pdu_pos)
    if list_tag != 0x30:
        raise SnmpCheckError("SNMP response did not include a varbind list")

    pairs: list[tuple[tuple[int, ...], Any]] = []
    position = 0
    list_end = len(varbind_list)
    while position < list_end:
        entry_tag, entry, position = _read_tlv(varbind_list, position)
        if entry_tag != 0x30:
            raise SnmpCheckError("SNMP response included an invalid varbind")
        oid_tag, raw_oid, value_start = _read_tlv(entry, 0)
        if oid_tag != 0x06:
            raise SnmpCheckError("SNMP varbind did not include an object identifier")
        value_tag, raw_value, _ = _read_tlv(entry, value_start)
        pairs.append((_decode_oid(raw_oid), _decode_value(value_tag, raw_value)))
    return pairs
|
||||
|
||||
|
||||
def _read_tlv(data: bytes, offset: int) -> tuple[int, bytes, int]:
    """Read one element starting at *offset*.

    Returns:
        ``(tag, content, next_offset)`` where *next_offset* is the position
        just past the element.

    Raises:
        SnmpCheckError: If *offset* is at or past the end of *data*, or the
            declared length runs past the end of *data*.
    """
    available = len(data)
    if not offset < available:
        raise SnmpCheckError("SNMP response ended unexpectedly")
    tag = data[offset]
    length, start = _read_length(data, offset + 1)
    stop = start + length
    if stop > available:
        raise SnmpCheckError("SNMP response length exceeded available data")
    return tag, data[start:stop], stop
|
||||
|
||||
|
||||
def _read_length(data: bytes, offset: int) -> tuple[int, int]:
|
||||
first = data[offset]
|
||||
offset += 1
|
||||
if first < 128:
|
||||
return first, offset
|
||||
byte_count = first & 0x7F
|
||||
if byte_count == 0:
|
||||
raise SnmpCheckError("SNMP response used indefinite length")
|
||||
return int.from_bytes(data[offset : offset + byte_count], "big"), offset + byte_count
|
||||
|
||||
|
||||
def _decode_integer(value: bytes) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
return int.from_bytes(value, "big", signed=bool(value[0] & 0x80))
|
||||
|
||||
|
||||
def _decode_oid(value: bytes) -> tuple[int, ...]:
|
||||
if not value:
|
||||
raise SnmpCheckError("SNMP response included an empty object identifier")
|
||||
oid = [value[0] // 40, value[0] % 40]
|
||||
number = 0
|
||||
for byte in value[1:]:
|
||||
number = (number << 7) | (byte & 0x7F)
|
||||
if not byte & 0x80:
|
||||
oid.append(number)
|
||||
number = 0
|
||||
return tuple(oid)
|
||||
|
||||
|
||||
def _decode_value(tag: int, value: bytes) -> Any:
|
||||
if tag in {0x02, 0x41, 0x42, 0x43, 0x46}:
|
||||
return int.from_bytes(value, "big")
|
||||
if tag == 0x04:
|
||||
return value.decode("utf-8", errors="replace")
|
||||
if tag == 0x06:
|
||||
return ".".join(str(part) for part in _decode_oid(value))
|
||||
if tag in {0x05, 0x80, 0x81, 0x82}:
|
||||
return None
|
||||
return value
|
||||
+22
-1
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, JSON, String, Text, func
|
||||
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, JSON, String, Text, func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
@@ -43,6 +43,27 @@ class CheckResult(Base):
|
||||
observed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class Metric(Base):
    """A single numeric sample recorded for a monitor."""

    __tablename__ = "metrics"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Owning monitor; rows are removed when the monitor is deleted.
    monitor_id: Mapped[int] = mapped_column(ForeignKey("monitors.id", ondelete="CASCADE"))
    # Metric name, limited to 120 characters.
    name: Mapped[str] = mapped_column(String(120))
    value: Mapped[float] = mapped_column(Float)
    # Optional unit label, limited to 32 characters.
    unit: Mapped[str | None] = mapped_column(String(32), nullable=True)
    # Defaults to the database server's current time when not supplied.
    observed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class Credential(Base):
    """A stored credential profile with an optional encrypted secret."""

    __tablename__ = "credentials"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String(160))
    # Kind of credential, e.g. "snmp".
    credential_type: Mapped[str] = mapped_column(String(64))
    # Secret in encrypted form; may be absent.
    encrypted_secret: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Free-form settings stored in the "metadata" column; the attribute is
    # named differently from the column it maps to.
    extra: Mapped[dict] = mapped_column("metadata", JSON, default=dict)
|
||||
|
||||
|
||||
class AlertRule(Base):
|
||||
__tablename__ = "alert_rules"
|
||||
|
||||
|
||||
+45
-4
@@ -7,11 +7,12 @@ from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
import httpx
|
||||
|
||||
from app.collectors.snmp import SnmpCheckConfig, SnmpCheckResult, run_snmp_check
|
||||
from app.collectors.website import WebsiteCheckConfig, run_website_check
|
||||
from app.collectors.network import PingCheckConfig, TcpCheckConfig, run_ping_check, run_tcp_check
|
||||
from app.config import settings
|
||||
from app.db import session_scope
|
||||
from app.models import AlertRule, Asset, CheckResult, Incident, Monitor, NotificationChannel
|
||||
from app.models import AlertRule, Asset, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
|
||||
from app.secrets import decrypt_secret
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -47,7 +48,7 @@ class Scheduler:
|
||||
def _load_due_monitors(self, db: Session) -> list[Monitor]:
|
||||
now = datetime.now(UTC)
|
||||
monitors = db.scalars(
|
||||
select(Monitor).where(Monitor.monitor_type.in_(["http", "ping", "tcp"])).order_by(Monitor.id).limit(50)
|
||||
select(Monitor).where(Monitor.monitor_type.in_(["http", "ping", "tcp", "snmp"])).order_by(Monitor.id).limit(50)
|
||||
).all()
|
||||
due: list[Monitor] = []
|
||||
for monitor in monitors:
|
||||
@@ -60,7 +61,7 @@ class Scheduler:
|
||||
return due
|
||||
|
||||
async def _run_monitor(self, db: Session, monitor: Monitor) -> None:
|
||||
result = await self._collect_monitor_result(monitor)
|
||||
result = await self._collect_monitor_result(db, monitor)
|
||||
now = datetime.now(UTC)
|
||||
|
||||
monitor.status = result.status
|
||||
@@ -76,6 +77,18 @@ class Scheduler:
|
||||
)
|
||||
db.flush()
|
||||
|
||||
for metric in getattr(result, "metrics", []):
|
||||
db.add(
|
||||
Metric(
|
||||
monitor_id=monitor.id,
|
||||
name=metric.name,
|
||||
value=metric.value,
|
||||
unit=metric.unit,
|
||||
observed_at=now,
|
||||
)
|
||||
)
|
||||
db.flush()
|
||||
|
||||
if monitor.asset_id is not None:
|
||||
asset = db.get(Asset, monitor.asset_id)
|
||||
if asset is not None:
|
||||
@@ -87,7 +100,7 @@ class Scheduler:
|
||||
|
||||
logger.info("Checked %s: %s (%s ms)", monitor.name, result.status, result.response_time_ms)
|
||||
|
||||
async def _collect_monitor_result(self, monitor: Monitor):
|
||||
async def _collect_monitor_result(self, db: Session, monitor: Monitor):
|
||||
if monitor.monitor_type == "http":
|
||||
config = WebsiteCheckConfig(
|
||||
url=monitor.target,
|
||||
@@ -115,8 +128,36 @@ class Scheduler:
|
||||
)
|
||||
return await run_tcp_check(config)
|
||||
|
||||
if monitor.monitor_type == "snmp":
|
||||
return await self._collect_snmp_monitor_result(db, monitor)
|
||||
|
||||
raise ValueError(f"Unsupported monitor type: {monitor.monitor_type}")
|
||||
|
||||
async def _collect_snmp_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
    """Run the SNMP check for *monitor* using its saved credential profile.

    The monitor's config names the credential profile and the item to
    collect; the profile supplies the community string and, through its
    extra settings, optional ``port``, ``timeout_seconds`` and ``retries``.

    Returns:
        The check result, or a "down" result without any request being sent
        when the profile id is missing, the profile does not exist or is
        not an SNMP profile, or no community string can be obtained.
    """
    monitor_config = monitor.config or {}
    profile_id = monitor_config.get("credential_profile_id")
    if not isinstance(profile_id, int):
        return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile is not configured")

    profile = db.get(Credential, profile_id)
    if profile is None or profile.credential_type != "snmp":
        return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile was not found")

    community = decrypt_secret(profile.encrypted_secret)
    if not community:
        return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile has no usable community string")

    extra = dict(profile.extra or {})
    # ``retries`` is handled separately from the other settings because 0 is
    # a meaningful value (a single attempt) and must not be replaced by the
    # default the way an unset or empty value is.
    raw_retries = extra.get("retries")
    retries = 1 if raw_retries in (None, "") else max(0, int(raw_retries))
    config = SnmpCheckConfig(
        host=monitor.target,
        community=community,
        item_id=str(monitor_config.get("item_id") or ""),
        item_type=str(monitor_config.get("item_type") or ""),
        port=int(extra.get("port") or 161),
        timeout_seconds=float(extra.get("timeout_seconds") or 5),
        retries=retries,
    )
    return await run_snmp_check(config)
|
||||
|
||||
async def _evaluate_rule(self, db: Session, monitor: Monitor, rule: AlertRule, now: datetime, message: str) -> None:
|
||||
open_incident = db.scalar(
|
||||
select(Incident).where(
|
||||
|
||||
@@ -8,9 +8,10 @@ from sqlalchemy import create_engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.collectors.snmp import SnmpCheckResult, SnmpMetricValue
|
||||
from app.collectors.website import WebsiteCheckResult
|
||||
from app.config import settings
|
||||
from app.models import AlertRule, Base, CheckResult, Incident, Monitor, NotificationChannel
|
||||
from app.models import AlertRule, Base, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
@@ -25,7 +26,7 @@ class RecordingScheduler(Scheduler):
|
||||
self.results = list(results or [])
|
||||
self.posts: list[dict[str, str]] = []
|
||||
|
||||
async def _collect_monitor_result(self, monitor: Monitor) -> WebsiteCheckResult:
|
||||
async def _collect_monitor_result(self, db: Session, monitor: Monitor) -> WebsiteCheckResult:
|
||||
return self.results.pop(0)
|
||||
|
||||
async def _post_webhook(self, url: str, message: str, username: str) -> None:
|
||||
@@ -159,3 +160,93 @@ class SchedulerTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
open_incidents = self.db.scalars(select(Incident).where(Incident.status == "open")).all()
|
||||
assert open_incidents == []
|
||||
|
||||
async def test_scheduler_includes_snmp_monitors_as_due(self) -> None:
    """A freshly added SNMP monitor is returned by ``_load_due_monitors``."""
    scheduler = RecordingScheduler()
    monitor_config = {
        "credential_profile_id": 1,
        "item_id": "interface.1.status",
        "item_type": "interface_status",
    }
    snmp_monitor = Monitor(
        name="Core Switch uplink status",
        monitor_type="snmp",
        target="192.0.2.10",
        config=monitor_config,
        interval_seconds=60,
        status="unknown",
    )
    self.db.add(snmp_monitor)
    self.db.flush()

    due_monitors = scheduler._load_due_monitors(self.db)

    assert snmp_monitor in due_monitors
|
||||
|
||||
async def test_scheduler_records_snmp_metrics(self) -> None:
    """Metrics carried by a check result are stored as ``Metric`` rows."""
    monitor = Monitor(
        name="Core Switch uplink traffic",
        monitor_type="snmp",
        target="192.0.2.10",
        config={"credential_profile_id": 1, "item_id": "interface.1.traffic", "item_type": "interface_traffic"},
        interval_seconds=60,
        status="unknown",
    )
    self.db.add(monitor)
    self.db.flush()

    class MetricScheduler(Scheduler):
        """Scheduler whose collection step returns a fixed SNMP result."""

        async def _collect_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
            samples = [
                SnmpMetricValue(name="in_octets", value=1000, unit="bytes"),
                SnmpMetricValue(name="out_octets", value=2000, unit="bytes"),
            ]
            return SnmpCheckResult(
                status="up",
                response_time_ms=12,
                message="Interface traffic counters collected",
                metrics=samples,
            )

    await MetricScheduler()._run_monitor(self.db, monitor)

    assert monitor.status == "up"
    query = select(Metric).where(Metric.monitor_id == monitor.id).order_by(Metric.name)
    stored = [(row.name, row.value, row.unit) for row in self.db.scalars(query).all()]
    expected = [
        ("in_octets", 1000.0, "bytes"),
        ("out_octets", 2000.0, "bytes"),
    ]
    assert stored == expected
|
||||
|
||||
async def test_snmp_monitor_uses_saved_profile_secret(self) -> None:
    """The SNMP check receives the decrypted community and profile settings."""
    profile = Credential(
        name="Core Switch",
        credential_type="snmp",
        encrypted_secret=encrypt_secret("private-community"),
        extra={"port": 1161, "timeout_seconds": 3, "retries": 2},
    )
    # Flush the profile first so the monitor can reference its real primary
    # key instead of assuming the database hands out id 1.
    self.db.add(profile)
    self.db.flush()
    monitor = Monitor(
        name="Core Switch uptime",
        monitor_type="snmp",
        target="192.0.2.10",
        config={"credential_profile_id": profile.id, "item_id": "device.uptime", "item_type": "device_uptime"},
        interval_seconds=60,
        status="unknown",
    )
    self.db.add(monitor)
    self.db.flush()
    calls = []

    async def fake_run_snmp_check(config):
        calls.append(config)
        return SnmpCheckResult(status="up", response_time_ms=10, message="Device uptime is 60 seconds")

    import app.scheduler as scheduler_module

    # Swap the collector for the recording fake and always restore it.
    original = scheduler_module.run_snmp_check
    scheduler_module.run_snmp_check = fake_run_snmp_check
    try:
        result = await Scheduler()._collect_snmp_monitor_result(self.db, monitor)
    finally:
        scheduler_module.run_snmp_check = original

    assert result.status == "up"
    assert len(calls) == 1
    assert calls[0].host == "192.0.2.10"
    assert calls[0].community == "private-community"
    assert calls[0].port == 1161
    assert calls[0].timeout_seconds == 3
    assert calls[0].retries == 2
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.collectors.snmp import (
|
||||
IF_ADMIN_STATUS,
|
||||
IF_HC_IN_OCTETS,
|
||||
IF_HC_OUT_OCTETS,
|
||||
IF_IN_DISCARDS,
|
||||
IF_IN_ERRORS,
|
||||
IF_OPER_STATUS,
|
||||
IF_OUT_DISCARDS,
|
||||
IF_OUT_ERRORS,
|
||||
SYS_UPTIME,
|
||||
SnmpCheckConfig,
|
||||
_with_index,
|
||||
run_snmp_check,
|
||||
)
|
||||
|
||||
|
||||
class SnmpCollectorTestCase(unittest.IsolatedAsyncioTestCase):
    """Exercise ``run_snmp_check`` for each item type with a stubbed client."""

    HOST = "192.0.2.10"
    COMMUNITY = "private-community"

    async def _check(self, item_id, item_type, reported):
        """Run one check while the client's ``get_many`` returns *reported*."""
        with patch("app.collectors.snmp.SnmpV2Client") as client_class:
            client_class.return_value.get_many.return_value = reported

            return await run_snmp_check(
                SnmpCheckConfig(
                    host=self.HOST,
                    community=self.COMMUNITY,
                    item_id=item_id,
                    item_type=item_type,
                )
            )

    @staticmethod
    def _rows(result):
        """Flatten a result's metrics into ``(name, value, unit)`` tuples."""
        return [(metric.name, metric.value, metric.unit) for metric in result.metrics]

    async def test_collects_device_uptime(self) -> None:
        result = await self._check("device.uptime", "device_uptime", {SYS_UPTIME: 123_400})

        assert result.status == "up"
        assert result.message == "Device uptime is 1234 seconds"
        assert self._rows(result) == [("uptime_seconds", 1234.0, "seconds")]

    async def test_collects_interface_status(self) -> None:
        reported = {
            _with_index(IF_ADMIN_STATUS, 7): 1,
            _with_index(IF_OPER_STATUS, 7): 2,
        }

        result = await self._check("interface.7.status", "interface_status", reported)

        assert result.status == "down"
        assert result.message == "Interface admin up, operational down"
        assert self._rows(result) == [
            ("admin_status", 1.0, None),
            ("oper_status", 2.0, None),
        ]

    async def test_collects_interface_traffic_from_high_capacity_counters(self) -> None:
        reported = {
            _with_index(IF_HC_IN_OCTETS, 3): 123,
            _with_index(IF_HC_OUT_OCTETS, 3): 456,
        }

        result = await self._check("interface.3.traffic", "interface_traffic", reported)

        assert result.status == "up"
        assert self._rows(result) == [
            ("in_octets", 123.0, "bytes"),
            ("out_octets", 456.0, "bytes"),
        ]

    async def test_collects_interface_errors_and_discards(self) -> None:
        columns = (IF_IN_ERRORS, IF_OUT_ERRORS, IF_IN_DISCARDS, IF_OUT_DISCARDS)
        reported = {_with_index(column, 5): count for count, column in enumerate(columns, start=1)}

        result = await self._check("interface.5.errors", "interface_errors", reported)

        assert result.status == "up"
        assert self._rows(result) == [
            ("in_errors", 1.0, "count"),
            ("out_errors", 2.0, "count"),
            ("in_discards", 3.0, "count"),
            ("out_discards", 4.0, "count"),
        ]
|
||||
|
||||
|
||||
# Run the test cases when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user