Files
OrbitWard/worker/app/scheduler.py
T
2026-05-26 21:24:54 -06:00

301 lines
12 KiB
Python

import asyncio
import logging
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace

import httpx
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from app.collectors.network import PingCheckConfig, TcpCheckConfig, run_ping_check, run_tcp_check
from app.collectors.snmp import SnmpCheckConfig, SnmpCheckResult, run_snmp_check
from app.collectors.website import WebsiteCheckConfig, run_website_check
from app.config import settings
from app.db import session_scope
from app.models import AlertRule, Asset, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
from app.secrets import decrypt_secret
# Module-level logger, named after this module, used by the Scheduler below.
logger = logging.getLogger(__name__)
class Scheduler:
def __init__(self, poll_interval_seconds: int = 10) -> None:
self.poll_interval_seconds = poll_interval_seconds
self._stopped = asyncio.Event()
async def run(self) -> None:
logger.info("OrbitWard worker started for %s", settings.orbitward_env)
while not self._stopped.is_set():
await self.tick()
try:
await asyncio.wait_for(self._stopped.wait(), timeout=self.poll_interval_seconds)
except TimeoutError:
continue
async def tick(self) -> None:
try:
with session_scope() as db:
due_monitors = self._load_due_monitors(db)
for monitor in due_monitors:
await self._run_monitor(db, monitor)
db.commit()
except SQLAlchemyError:
logger.exception("Worker tick failed while talking to the database")
def stop(self) -> None:
self._stopped.set()
def _load_due_monitors(self, db: Session) -> list[Monitor]:
now = datetime.now(UTC)
monitors = db.scalars(
select(Monitor).where(Monitor.monitor_type.in_(["http", "ping", "tcp", "snmp"])).order_by(Monitor.id).limit(50)
).all()
due: list[Monitor] = []
for monitor in monitors:
if monitor.last_checked_at is None:
due.append(monitor)
continue
next_due_at = monitor.last_checked_at + timedelta(seconds=monitor.interval_seconds)
if next_due_at <= now:
due.append(monitor)
return due
async def _run_monitor(self, db: Session, monitor: Monitor) -> None:
result = await self._collect_monitor_result(db, monitor)
now = datetime.now(UTC)
monitor.status = result.status
monitor.last_checked_at = now
db.add(
CheckResult(
monitor_id=monitor.id,
status=result.status,
response_time_ms=result.response_time_ms,
message=result.message,
observed_at=now,
)
)
db.flush()
for metric in getattr(result, "metrics", []):
db.add(
Metric(
monitor_id=monitor.id,
name=metric.name,
value=metric.value,
unit=metric.unit,
observed_at=now,
)
)
db.flush()
if monitor.asset_id is not None:
asset = db.get(Asset, monitor.asset_id)
if asset is not None:
asset.status = result.status
rules = db.scalars(select(AlertRule).where(AlertRule.monitor_id == monitor.id, AlertRule.is_enabled.is_(True))).all()
for rule in rules:
await self._evaluate_rule(db, monitor, rule, now, result.message)
logger.info("Checked %s: %s (%s ms)", monitor.name, result.status, result.response_time_ms)
async def _collect_monitor_result(self, db: Session, monitor: Monitor):
if monitor.monitor_type == "http":
config = WebsiteCheckConfig(
url=monitor.target,
expected_status=int(monitor.config.get("expected_status", 200)),
expected_text=monitor.config.get("expected_text") or None,
unexpected_text=monitor.config.get("unexpected_text") or None,
timeout_seconds=float(monitor.config.get("timeout_seconds", 10)),
check_tls_expiry=bool(monitor.config.get("check_tls_expiry", False)),
tls_warning_days=int(monitor.config.get("tls_warning_days", 30)),
)
return await run_website_check(config)
if monitor.monitor_type == "ping":
config = PingCheckConfig(
host=monitor.target,
timeout_seconds=float(monitor.config.get("timeout_seconds", 5)),
)
return await run_ping_check(config)
if monitor.monitor_type == "tcp":
config = TcpCheckConfig(
host=str(monitor.config.get("host") or monitor.target),
port=int(monitor.config.get("port")),
timeout_seconds=float(monitor.config.get("timeout_seconds", 5)),
)
return await run_tcp_check(config)
if monitor.monitor_type == "snmp":
return await self._collect_snmp_monitor_result(db, monitor)
raise ValueError(f"Unsupported monitor type: {monitor.monitor_type}")
async def _collect_snmp_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
profile_id = monitor.config.get("credential_profile_id")
if not isinstance(profile_id, int):
return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile is not configured")
profile = db.get(Credential, profile_id)
if profile is None or profile.credential_type != "snmp":
return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile was not found")
community = decrypt_secret(profile.encrypted_secret)
if not community:
return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile has no usable community string")
extra = dict(profile.extra or {})
config = SnmpCheckConfig(
host=monitor.target,
community=community,
item_id=str(monitor.config.get("item_id") or ""),
item_type=str(monitor.config.get("item_type") or ""),
label=monitor.config.get("label") if isinstance(monitor.config.get("label"), str) else None,
unit=monitor.config.get("unit") if isinstance(monitor.config.get("unit"), str) else None,
port=int(extra.get("port") or 161),
timeout_seconds=float(extra.get("timeout_seconds") or 5),
retries=int(extra.get("retries") or 1),
)
return await run_snmp_check(config)
async def _evaluate_rule(self, db: Session, monitor: Monitor, rule: AlertRule, now: datetime, message: str) -> None:
open_incident = db.scalar(
select(Incident).where(
Incident.monitor_id == monitor.id,
Incident.alert_rule_id == rule.id,
Incident.status == "open",
)
)
if monitor.status == "up":
if open_incident is not None:
open_incident.status = "resolved"
open_incident.resolved_at = now
open_incident.details = {**(open_incident.details or {}), "recovery_message": message}
await self._send_incident_notifications(db, open_incident, monitor, "resolved", now)
return
recent_statuses = list(
db.scalars(
select(CheckResult.status)
.where(CheckResult.monitor_id == monitor.id)
.order_by(CheckResult.observed_at.desc())
.limit(rule.failure_threshold)
)
)
threshold_met = len(recent_statuses) >= rule.failure_threshold and all(status != "up" for status in recent_statuses)
if threshold_met and open_incident is None:
if rule.cooldown_seconds > 0:
latest_incident = db.scalar(
select(Incident)
.where(
Incident.monitor_id == monitor.id,
Incident.alert_rule_id == rule.id,
)
.order_by(Incident.opened_at.desc())
.limit(1)
)
if (
latest_incident is not None
and self._as_utc(latest_incident.opened_at) + timedelta(seconds=rule.cooldown_seconds) > now
):
return
incident = Incident(
asset_id=monitor.asset_id,
monitor_id=monitor.id,
alert_rule_id=rule.id,
title=f"{monitor.name} is failing",
severity=rule.severity,
status="open",
opened_at=now,
details={"last_message": message, "failure_threshold": rule.failure_threshold},
)
db.add(incident)
db.flush()
await self._send_incident_notifications(db, incident, monitor, "opened", now)
async def _send_incident_notifications(
self,
db: Session,
incident: Incident,
monitor: Monitor,
event_type: str,
now: datetime,
) -> None:
state_key = "opened_sent_at" if event_type == "opened" else "resolved_sent_at"
notification_state = dict((incident.details or {}).get("notification_state") or {})
if notification_state.get(state_key):
return
channels = db.scalars(
select(NotificationChannel).where(
NotificationChannel.is_enabled.is_(True),
NotificationChannel.channel_type.in_(["generic_webhook", "webhook", "mattermost", "zoom", "zoom_team_chat"]),
)
).all()
if not channels:
return
sent_channels: list[str] = []
for channel in channels:
url = decrypt_secret(channel.encrypted_secret)
if not url:
logger.warning("Skipping notification channel %s because its secret cannot be decrypted", channel.id)
continue
try:
await self._post_webhook(
url,
self._format_incident_message(incident, monitor, event_type),
str((channel.settings or {}).get("username") or "OrbitWard"),
)
except httpx.HTTPError:
logger.exception("Notification delivery failed for channel %s", channel.id)
continue
sent_channels.append(channel.name)
if sent_channels:
notification_state[state_key] = now.isoformat()
history = list((incident.details or {}).get("notification_history") or [])
history.append({"event": event_type, "sent_at": now.isoformat(), "channels": sent_channels})
incident.details = {**(incident.details or {}), "notification_state": notification_state, "notification_history": history}
async def _post_webhook(self, url: str, message: str, username: str) -> None:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(url, json={"username": username, "text": message})
response.raise_for_status()
@staticmethod
def _as_utc(value: datetime) -> datetime:
if value.tzinfo is None:
return value.replace(tzinfo=UTC)
return value.astimezone(UTC)
def _format_incident_message(self, incident: Incident, monitor: Monitor, event_type: str) -> str:
if event_type == "resolved":
title = f"RESOLVED: {monitor.name} recovered"
body = [
title,
"",
f"Monitor: {monitor.name}",
f"Target: {monitor.target}",
f"Resolved: {incident.resolved_at or datetime.now(UTC)}",
]
else:
title = f"{incident.severity.upper()}: {incident.title}"
body = [
title,
"",
f"Monitor: {monitor.name}",
f"Target: {monitor.target}",
f"Status: {monitor.status}",
f"Started: {incident.opened_at}",
]
last_message = (incident.details or {}).get("last_message")
if last_message:
body.append(f"Last response: {last_message}")
body.extend(["", f"View in OrbitWard: {settings.frontend_url}/incidents/{incident.id}"])
return "\n".join(str(line) for line in body)