"""Scheduler tests: alert evaluation, incident notifications, and SNMP monitor collection."""
import base64
import hashlib
import unittest
from datetime import UTC, datetime, timedelta
from unittest import mock

from cryptography.fernet import Fernet
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.collectors.snmp import SnmpCheckResult, SnmpMetricValue
from app.collectors.website import WebsiteCheckResult
from app.config import settings
from app.models import AlertRule, Base, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
from app.scheduler import Scheduler
|
|
|
|
|
|
def encrypt_secret(value: str) -> str:
    """Encrypt *value* with a Fernet key derived from the application secret.

    The key is the URL-safe base64 encoding of the SHA-256 digest of
    ``settings.orbitward_secret_key``. The resulting Fernet token is returned
    as a UTF-8 ``str``.
    """
    secret_bytes = settings.orbitward_secret_key.encode("utf-8")
    fernet_key = base64.urlsafe_b64encode(hashlib.sha256(secret_bytes).digest())
    token = Fernet(fernet_key).encrypt(value.encode("utf-8"))
    return token.decode("utf-8")
|
|
|
|
|
|
class RecordingScheduler(Scheduler):
    """Scheduler test double: replays canned check results and records webhook posts.

    ``results`` is consumed front to back, one entry per collection call.
    ``posts`` accumulates one dict per webhook call instead of doing any I/O.
    """

    def __init__(self, results: list[WebsiteCheckResult] | None = None) -> None:
        super().__init__()
        # Copy so the caller's list is never mutated by the replay.
        self.results = list(results) if results else []
        self.posts: list[dict[str, str]] = []

    async def _collect_monitor_result(self, db: Session, monitor: Monitor) -> WebsiteCheckResult:
        """Return the next canned result; ``db`` and ``monitor`` are ignored."""
        next_result = self.results.pop(0)
        return next_result

    async def _post_webhook(self, url: str, message: str, username: str) -> None:
        """Record the webhook call in ``self.posts`` rather than sending it."""
        recorded = dict(url=url, message=message, username=username)
        self.posts.append(recorded)
|
|
|
|
|
|
class SchedulerTestCase(unittest.IsolatedAsyncioTestCase):
    """Exercise ``Scheduler`` alerting, notification and SNMP paths against a throwaway SQLite database."""

    def setUp(self) -> None:
        """Create a fresh schema on a ``sqlite://`` engine and open one session for the test."""
        # StaticPool + check_same_thread=False is the usual SQLAlchemy recipe for
        # sharing a single in-memory SQLite connection across the whole test.
        self.engine = create_engine(
            "sqlite://",
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
        )
        Base.metadata.create_all(bind=self.engine)
        self.session_factory = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
        self.db: Session = self.session_factory()

    def tearDown(self) -> None:
        """Close the session, drop the schema and dispose of the engine."""
        self.db.close()
        Base.metadata.drop_all(bind=self.engine)
        self.engine.dispose()

    def create_monitor_with_rule(self, *, failure_threshold: int = 2, cooldown_seconds: int = 0) -> tuple[Monitor, AlertRule]:
        """Insert an HTTP monitor plus an enabled ``status_not_up`` alert rule and return both.

        Both rows are flushed (not committed) so their primary keys are populated.
        """
        monitor = Monitor(
            name="Example Site",
            monitor_type="http",
            target="https://example.com",
            config={"expected_status": 200, "timeout_seconds": 5},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(monitor)
        # Flush first: the rule needs monitor.id as its foreign key.
        self.db.flush()

        rule = AlertRule(
            monitor_id=monitor.id,
            name="Example Site failure",
            severity="critical",
            condition={"type": "status_not_up"},
            failure_threshold=failure_threshold,
            cooldown_seconds=cooldown_seconds,
            is_enabled=True,
        )
        self.db.add(rule)
        self.db.flush()
        return monitor, rule

    async def test_alert_evaluation_opens_incident_after_failure_threshold(self) -> None:
        """With a threshold of 2, the first failure opens nothing and the second opens an incident."""
        monitor, rule = self.create_monitor_with_rule(failure_threshold=2)
        scheduler = RecordingScheduler(
            [
                WebsiteCheckResult(status="down", response_time_ms=100, message="HTTP 500"),
                WebsiteCheckResult(status="down", response_time_ms=110, message="HTTP 500 again"),
            ]
        )

        await scheduler._run_monitor(self.db, monitor)
        assert self.db.scalars(select(Incident)).all() == []

        await scheduler._run_monitor(self.db, monitor)

        incident = self.db.scalar(select(Incident))
        assert incident is not None
        assert incident.monitor_id == monitor.id
        assert incident.alert_rule_id == rule.id
        assert incident.status == "open"
        assert incident.severity == "critical"
        assert incident.details["last_message"] == "HTTP 500 again"
        assert incident.details["failure_threshold"] == 2

    async def test_recovery_resolves_open_incident_and_sends_notifications_once(self) -> None:
        """A failure opens and notifies once; a repeat "opened" send is suppressed; recovery resolves and notifies."""
        monitor, _rule = self.create_monitor_with_rule(failure_threshold=1)
        channel = NotificationChannel(
            name="Ops Webhook",
            channel_type="generic_webhook",
            settings={"username": "OrbitWard"},
            encrypted_secret=encrypt_secret("https://hooks.example.test/orbitward"),
            is_enabled=True,
        )
        self.db.add(channel)
        self.db.flush()
        scheduler = RecordingScheduler(
            [
                WebsiteCheckResult(status="down", response_time_ms=100, message="HTTP 500"),
                WebsiteCheckResult(status="up", response_time_ms=80, message="Website check passed"),
            ]
        )

        await scheduler._run_monitor(self.db, monitor)

        incident = self.db.scalar(select(Incident))
        assert incident is not None
        assert incident.status == "open"
        assert len(scheduler.posts) == 1
        assert scheduler.posts[0]["url"] == "https://hooks.example.test/orbitward"
        assert scheduler.posts[0]["username"] == "OrbitWard"
        assert incident.details["notification_history"][0]["event"] == "opened"

        # Sending the same "opened" event again must not produce a second post.
        await scheduler._send_incident_notifications(self.db, incident, monitor, "opened", datetime.now(UTC))
        assert len(scheduler.posts) == 1

        await scheduler._run_monitor(self.db, monitor)

        assert incident.status == "resolved"
        assert incident.resolved_at is not None
        assert incident.details["recovery_message"] == "Website check passed"
        assert len(scheduler.posts) == 2
        assert incident.details["notification_history"][1]["event"] == "resolved"

    async def test_alert_cooldown_suppresses_new_incident_after_recent_resolution(self) -> None:
        """An incident resolved 30 s ago with a 300 s cooldown keeps a new failure from opening another."""
        monitor, rule = self.create_monitor_with_rule(failure_threshold=1, cooldown_seconds=300)
        now = datetime.now(UTC)
        self.db.add(
            Incident(
                monitor_id=monitor.id,
                alert_rule_id=rule.id,
                title="Example Site is failing",
                severity="critical",
                status="resolved",
                opened_at=now - timedelta(seconds=60),
                resolved_at=now - timedelta(seconds=30),
                details={},
            )
        )
        self.db.add(CheckResult(monitor_id=monitor.id, status="down", response_time_ms=100, message="HTTP 500", observed_at=now))
        monitor.status = "down"
        self.db.flush()
        scheduler = RecordingScheduler()

        await scheduler._evaluate_rule(self.db, monitor, rule, now, "HTTP 500")

        open_incidents = self.db.scalars(select(Incident).where(Incident.status == "open")).all()
        assert open_incidents == []

    async def test_scheduler_includes_snmp_monitors_as_due(self) -> None:
        """A newly inserted SNMP monitor is returned by ``_load_due_monitors``."""
        scheduler = RecordingScheduler()
        snmp_monitor = Monitor(
            name="Core Switch uplink status",
            monitor_type="snmp",
            target="192.0.2.10",
            config={"credential_profile_id": 1, "item_id": "interface.1.status", "item_type": "interface_status"},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(snmp_monitor)
        self.db.flush()

        due = scheduler._load_due_monitors(self.db)

        assert snmp_monitor in due

    async def test_scheduler_records_snmp_metrics(self) -> None:
        """Metrics carried on an SNMP check result are stored as ``Metric`` rows for the monitor."""
        monitor = Monitor(
            name="Core Switch uplink traffic",
            monitor_type="snmp",
            target="192.0.2.10",
            config={"credential_profile_id": 1, "item_id": "interface.1.traffic", "item_type": "interface_traffic"},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(monitor)
        self.db.flush()

        class MetricScheduler(Scheduler):
            """Scheduler whose collection step always returns a fixed two-metric SNMP result."""

            async def _collect_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
                return SnmpCheckResult(
                    status="up",
                    response_time_ms=12,
                    message="Interface traffic counters collected",
                    metrics=[
                        SnmpMetricValue(name="in_octets", value=1000, unit="bytes"),
                        SnmpMetricValue(name="out_octets", value=2000, unit="bytes"),
                    ],
                )

        await MetricScheduler()._run_monitor(self.db, monitor)

        assert monitor.status == "up"
        # Order by name so the comparison does not depend on insertion order.
        metrics = self.db.scalars(select(Metric).where(Metric.monitor_id == monitor.id).order_by(Metric.name)).all()
        assert [(metric.name, metric.value, metric.unit) for metric in metrics] == [
            ("in_octets", 1000.0, "bytes"),
            ("out_octets", 2000.0, "bytes"),
        ]

    async def test_snmp_monitor_uses_saved_profile_secret(self) -> None:
        """SNMP collection passes the decrypted community and the profile's extras to ``run_snmp_check``."""
        profile = Credential(
            name="Core Switch",
            credential_type="snmp",
            encrypted_secret=encrypt_secret("private-community"),
            extra={"port": 1161, "timeout_seconds": 3, "retries": 2},
        )
        self.db.add(profile)
        # Flush the profile first so the monitor can reference its real primary
        # key instead of assuming the autoincrement value is 1.
        self.db.flush()
        monitor = Monitor(
            name="Core Switch uptime",
            monitor_type="snmp",
            target="192.0.2.10",
            config={"credential_profile_id": profile.id, "item_id": "device.uptime", "item_type": "device_uptime"},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(monitor)
        self.db.flush()
        calls = []

        async def fake_run_snmp_check(config):
            """Capture the config the scheduler builds and return a canned "up" result."""
            calls.append(config)
            return SnmpCheckResult(status="up", response_time_ms=10, message="Device uptime is 60 seconds")

        import app.scheduler as scheduler_module

        # patch.object restores the original attribute on exit, even if the call raises.
        with mock.patch.object(scheduler_module, "run_snmp_check", fake_run_snmp_check):
            result = await Scheduler()._collect_snmp_monitor_result(self.db, monitor)

        assert result.status == "up"
        assert len(calls) == 1
        assert calls[0].host == "192.0.2.10"
        assert calls[0].community == "private-community"
        assert calls[0].port == 1161
        assert calls[0].timeout_seconds == 3
        assert calls[0].retries == 2