# OrbitWard/worker/tests/test_scheduler.py
# 2026-05-26 21:24:54 -06:00 · 253 lines · 9.9 KiB · Python

import base64
import hashlib
import unittest
from datetime import UTC, datetime, timedelta
from cryptography.fernet import Fernet
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.collectors.snmp import SnmpCheckResult, SnmpMetricValue
from app.collectors.website import WebsiteCheckResult
from app.config import settings
from app.models import AlertRule, Base, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
from app.scheduler import Scheduler
def encrypt_secret(value: str) -> str:
    """Return ``value`` encrypted as a Fernet token (UTF-8 text).

    The Fernet key is the URL-safe base64 encoding of the SHA-256 digest of
    ``settings.orbitward_secret_key``.
    NOTE(review): presumably this mirrors the key derivation the application
    uses to decrypt stored secrets -- confirm against the app code.
    """
    secret_bytes = settings.orbitward_secret_key.encode("utf-8")
    fernet_key = base64.urlsafe_b64encode(hashlib.sha256(secret_bytes).digest())
    token = Fernet(fernet_key).encrypt(value.encode("utf-8"))
    return token.decode("utf-8")
class RecordingScheduler(Scheduler):
    """Scheduler test double: replays canned check results and records webhooks.

    ``results`` holds the check results still to be handed out, in order;
    ``posts`` collects one dict per webhook call instead of sending anything.
    """

    def __init__(self, results: list[WebsiteCheckResult] | None = None) -> None:
        super().__init__()
        # Copy the caller's list so consuming results never mutates it.
        self.results = list(results) if results else []
        self.posts: list[dict[str, str]] = []

    async def _collect_monitor_result(self, db: Session, monitor: Monitor) -> WebsiteCheckResult:
        """Hand out the next queued result; ``db`` and ``monitor`` are ignored."""
        next_result = self.results.pop(0)
        return next_result

    async def _post_webhook(self, url: str, message: str, username: str) -> None:
        """Record the webhook call in ``posts`` rather than performing it."""
        recorded = {"url": url, "message": message, "username": username}
        self.posts.append(recorded)
class SchedulerTestCase(unittest.IsolatedAsyncioTestCase):
    """Scheduler tests, each run against a fresh in-memory SQLite database."""

    def setUp(self) -> None:
        """Create the schema on an in-memory SQLite engine and open a session."""
        # StaticPool reuses one connection, so every session created from this
        # engine talks to the same in-memory database.
        self.engine = create_engine(
            "sqlite://",
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
        )
        Base.metadata.create_all(bind=self.engine)
        self.session_factory = sessionmaker(bind=self.engine, autoflush=False, autocommit=False)
        self.db: Session = self.session_factory()

    def tearDown(self) -> None:
        """Close the session, drop every table and dispose of the engine."""
        self.db.close()
        Base.metadata.drop_all(bind=self.engine)
        self.engine.dispose()

    def create_monitor_with_rule(
        self,
        *,
        failure_threshold: int = 2,
        cooldown_seconds: int = 0,
    ) -> tuple[Monitor, AlertRule]:
        """Add and flush an HTTP monitor plus an enabled critical alert rule.

        Args:
            failure_threshold: Value stored on the rule's ``failure_threshold``.
            cooldown_seconds: Value stored on the rule's ``cooldown_seconds``.

        Returns:
            The flushed ``(monitor, rule)`` pair.
        """
        monitor = Monitor(
            name="Example Site",
            monitor_type="http",
            target="https://example.com",
            config={"expected_status": 200, "timeout_seconds": 5},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(monitor)
        # Flush so monitor.id is populated before the rule references it.
        self.db.flush()
        rule = AlertRule(
            monitor_id=monitor.id,
            name="Example Site failure",
            severity="critical",
            condition={"type": "status_not_up"},
            failure_threshold=failure_threshold,
            cooldown_seconds=cooldown_seconds,
            is_enabled=True,
        )
        self.db.add(rule)
        self.db.flush()
        return monitor, rule

    async def test_alert_evaluation_opens_incident_after_failure_threshold(self) -> None:
        """No incident after one failure; an open one after the second (threshold 2)."""
        monitor, rule = self.create_monitor_with_rule(failure_threshold=2)
        scheduler = RecordingScheduler(
            [
                WebsiteCheckResult(status="down", response_time_ms=100, message="HTTP 500"),
                WebsiteCheckResult(status="down", response_time_ms=110, message="HTTP 500 again"),
            ]
        )

        await scheduler._run_monitor(self.db, monitor)
        assert self.db.scalars(select(Incident)).all() == []

        await scheduler._run_monitor(self.db, monitor)
        incident = self.db.scalar(select(Incident))
        assert incident is not None
        assert incident.monitor_id == monitor.id
        assert incident.alert_rule_id == rule.id
        assert incident.status == "open"
        assert incident.severity == "critical"
        assert incident.details["last_message"] == "HTTP 500 again"
        assert incident.details["failure_threshold"] == 2

    async def test_recovery_resolves_open_incident_and_sends_notifications_once(self) -> None:
        """A failure opens and notifies once; recovery resolves and notifies once more."""
        monitor, _rule = self.create_monitor_with_rule(failure_threshold=1)
        channel = NotificationChannel(
            name="Ops Webhook",
            channel_type="generic_webhook",
            settings={"username": "OrbitWard"},
            encrypted_secret=encrypt_secret("https://hooks.example.test/orbitward"),
            is_enabled=True,
        )
        self.db.add(channel)
        self.db.flush()
        scheduler = RecordingScheduler(
            [
                WebsiteCheckResult(status="down", response_time_ms=100, message="HTTP 500"),
                WebsiteCheckResult(status="up", response_time_ms=80, message="Website check passed"),
            ]
        )

        # First run: the "down" result opens the incident and posts one webhook.
        await scheduler._run_monitor(self.db, monitor)
        incident = self.db.scalar(select(Incident))
        assert incident is not None
        assert incident.status == "open"
        assert len(scheduler.posts) == 1
        assert scheduler.posts[0]["url"] == "https://hooks.example.test/orbitward"
        assert scheduler.posts[0]["username"] == "OrbitWard"
        assert incident.details["notification_history"][0]["event"] == "opened"

        # Re-sending the same "opened" event must not post a second webhook.
        await scheduler._send_incident_notifications(self.db, incident, monitor, "opened", datetime.now(UTC))
        assert len(scheduler.posts) == 1

        # Second run: the "up" result resolves the incident and posts once more.
        await scheduler._run_monitor(self.db, monitor)
        assert incident.status == "resolved"
        assert incident.resolved_at is not None
        assert incident.details["recovery_message"] == "Website check passed"
        assert len(scheduler.posts) == 2
        assert incident.details["notification_history"][1]["event"] == "resolved"

    async def test_alert_cooldown_suppresses_new_incident_after_recent_resolution(self) -> None:
        """An incident resolved 30s ago blocks a new one under a 300s cooldown."""
        monitor, rule = self.create_monitor_with_rule(failure_threshold=1, cooldown_seconds=300)
        now = datetime.now(UTC)
        self.db.add(
            Incident(
                monitor_id=monitor.id,
                alert_rule_id=rule.id,
                title="Example Site is failing",
                severity="critical",
                status="resolved",
                opened_at=now - timedelta(seconds=60),
                resolved_at=now - timedelta(seconds=30),
                details={},
            )
        )
        self.db.add(
            CheckResult(
                monitor_id=monitor.id,
                status="down",
                response_time_ms=100,
                message="HTTP 500",
                observed_at=now,
            )
        )
        monitor.status = "down"
        self.db.flush()

        scheduler = RecordingScheduler()
        await scheduler._evaluate_rule(self.db, monitor, rule, now, "HTTP 500")

        open_incidents = self.db.scalars(select(Incident).where(Incident.status == "open")).all()
        assert open_incidents == []

    async def test_scheduler_includes_snmp_monitors_as_due(self) -> None:
        """A newly added SNMP monitor is returned by ``_load_due_monitors``."""
        scheduler = RecordingScheduler()
        snmp_monitor = Monitor(
            name="Core Switch uplink status",
            monitor_type="snmp",
            target="192.0.2.10",
            config={"credential_profile_id": 1, "item_id": "interface.1.status", "item_type": "interface_status"},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(snmp_monitor)
        self.db.flush()

        due = scheduler._load_due_monitors(self.db)
        assert snmp_monitor in due

    async def test_scheduler_records_snmp_metrics(self) -> None:
        """Metrics carried by an SNMP result are stored as ``Metric`` rows."""
        monitor = Monitor(
            name="Core Switch uplink traffic",
            monitor_type="snmp",
            target="192.0.2.10",
            config={"credential_profile_id": 1, "item_id": "interface.1.traffic", "item_type": "interface_traffic"},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(monitor)
        self.db.flush()

        class MetricScheduler(Scheduler):
            """Scheduler whose collection step always yields two traffic metrics."""

            async def _collect_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
                return SnmpCheckResult(
                    status="up",
                    response_time_ms=12,
                    message="Interface traffic counters collected",
                    metrics=[
                        SnmpMetricValue(name="in_octets", value=1000, unit="bytes"),
                        SnmpMetricValue(name="out_octets", value=2000, unit="bytes"),
                    ],
                )

        await MetricScheduler()._run_monitor(self.db, monitor)

        assert monitor.status == "up"
        metrics = self.db.scalars(
            select(Metric).where(Metric.monitor_id == monitor.id).order_by(Metric.name)
        ).all()
        assert [(metric.name, metric.value, metric.unit) for metric in metrics] == [
            ("in_octets", 1000.0, "bytes"),
            ("out_octets", 2000.0, "bytes"),
        ]

    async def test_snmp_monitor_uses_saved_profile_secret(self) -> None:
        """The SNMP check config takes community, port, timeout and retries from the profile."""
        from unittest import mock

        import app.scheduler as scheduler_module

        profile = Credential(
            name="Core Switch",
            credential_type="snmp",
            encrypted_secret=encrypt_secret("private-community"),
            extra={"port": 1161, "timeout_seconds": 3, "retries": 2},
        )
        self.db.add(profile)
        # Flush first so the monitor references the credential's real primary
        # key instead of assuming the first inserted row gets id 1.
        self.db.flush()
        monitor = Monitor(
            name="Core Switch uptime",
            monitor_type="snmp",
            target="192.0.2.10",
            config={"credential_profile_id": profile.id, "item_id": "device.uptime", "item_type": "device_uptime"},
            interval_seconds=60,
            status="unknown",
        )
        self.db.add(monitor)
        self.db.flush()

        calls = []

        async def fake_run_snmp_check(config):
            # Capture the config the scheduler built; no real SNMP traffic.
            calls.append(config)
            return SnmpCheckResult(status="up", response_time_ms=10, message="Device uptime is 60 seconds")

        # patch.object restores the real collector on exit, even on failure.
        with mock.patch.object(scheduler_module, "run_snmp_check", fake_run_snmp_check):
            result = await Scheduler()._collect_snmp_monitor_result(self.db, monitor)

        assert result.status == "up"
        assert len(calls) == 1
        assert calls[0].host == "192.0.2.10"
        assert calls[0].community == "private-community"
        assert calls[0].port == 1161
        assert calls[0].timeout_seconds == 3
        assert calls[0].retries == 2