import asyncio
import logging
from datetime import UTC, datetime, timedelta

import httpx
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from app.collectors.network import PingCheckConfig, TcpCheckConfig, run_ping_check, run_tcp_check
from app.collectors.snmp import SnmpCheckConfig, SnmpCheckResult, run_snmp_check
from app.collectors.website import WebsiteCheckConfig, run_website_check
from app.config import settings
from app.db import session_scope
from app.models import AlertRule, Asset, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
from app.secrets import decrypt_secret

logger = logging.getLogger(__name__)

# Monitor types this worker can collect (see Scheduler._collect_monitor_result).
_SUPPORTED_MONITOR_TYPES = ("http", "ping", "tcp", "snmp")
# Notification channel types delivered via the JSON webhook in Scheduler._post_webhook.
_WEBHOOK_CHANNEL_TYPES = ("generic_webhook", "webhook", "mattermost", "zoom", "zoom_team_chat")
# Default cap on how many due monitors are checked in a single tick.
_MAX_MONITORS_PER_TICK = 50


class Scheduler:
    """Poll for due monitors, run their checks, persist results, and drive alerting.

    Each tick opens one database session, runs every due monitor in turn,
    records ``CheckResult``/``Metric`` rows, evaluates the monitor's enabled
    ``AlertRule`` rows, and sends webhook notifications when incidents open
    or resolve.
    """

    def __init__(self, poll_interval_seconds: int = 10) -> None:
        """Create a scheduler that looks for due monitors every *poll_interval_seconds*."""
        self.poll_interval_seconds = poll_interval_seconds
        # Set by stop(); also awaited as an interruptible sleep between ticks.
        self._stopped = asyncio.Event()

    async def run(self) -> None:
        """Run ticks until stop() is called, sleeping ``poll_interval_seconds`` between them.

        The sleep waits on the stop event, so stop() takes effect without
        waiting out the rest of the interval.
        """
        logger.info("OrbitalWard worker started for %s", settings.orbitalward_env)
        while not self._stopped.is_set():
            try:
                await self.tick()
            except Exception:
                # Last-resort guard: an unexpected error in one tick must not end
                # the worker loop. asyncio.CancelledError derives from BaseException,
                # so task cancellation still propagates.
                logger.exception("Worker tick failed unexpectedly")
            try:
                await asyncio.wait_for(self._stopped.wait(), timeout=self.poll_interval_seconds)
            except TimeoutError:
                continue

    async def tick(self) -> None:
        """Check every currently due monitor inside a single database session.

        A non-database failure while checking one monitor is logged and the
        remaining monitors still run. Database errors propagate out of the
        session scope and are logged here: after a failed flush SQLAlchemy
        requires a rollback before the session can be used again, so
        continuing with other monitors would be pointless.
        """
        try:
            with session_scope() as db:
                for monitor in self._load_due_monitors(db):
                    try:
                        await self._run_monitor(db, monitor)
                    except SQLAlchemyError:
                        raise
                    except Exception:
                        logger.exception("Check failed for monitor %s (%s)", monitor.id, monitor.name)
                db.commit()
        except SQLAlchemyError:
            logger.exception("Worker tick failed while talking to the database")

    def stop(self) -> None:
        """Ask run() to exit after the current tick; also wakes it from its sleep."""
        self._stopped.set()

    def _load_due_monitors(self, db: Session, limit: int = _MAX_MONITORS_PER_TICK) -> list[Monitor]:
        """Return up to *limit* monitors whose next check is due, most overdue first.

        Never-checked monitors come first. The rest are ordered by the time
        their next check fell due; ties keep ``Monitor.id`` order.

        The due filter runs before the cap is applied. A plain SQL ``LIMIT``
        on id order would only ever consider the lowest-id monitors, and any
        monitor beyond them would never be checked.
        """
        now = datetime.now(UTC)
        monitors = db.scalars(
            select(Monitor).where(Monitor.monitor_type.in_(_SUPPORTED_MONITOR_TYPES)).order_by(Monitor.id)
        ).all()

        never_checked: list[Monitor] = []
        overdue: list[tuple[datetime, Monitor]] = []
        for monitor in monitors:
            if monitor.last_checked_at is None:
                never_checked.append(monitor)
                continue
            # last_checked_at may come back naive (e.g. from SQLite); comparing a
            # naive datetime with the aware `now` would raise TypeError.
            next_due_at = self._as_utc(monitor.last_checked_at) + timedelta(seconds=monitor.interval_seconds)
            if next_due_at <= now:
                overdue.append((next_due_at, monitor))

        # list.sort is stable, so monitors due at the same instant keep id order.
        overdue.sort(key=lambda item: item[0])
        due = never_checked + [monitor for _, monitor in overdue]
        return due[: max(limit, 0)]

    async def _run_monitor(self, db: Session, monitor: Monitor) -> None:
        """Run one monitor's check, record the outcome, and evaluate its alert rules.

        Updates the monitor's status and ``last_checked_at`` (and its asset's
        status, if it has one). Stores a ``CheckResult`` plus any ``Metric``
        rows the result carries, then evaluates each enabled ``AlertRule``
        bound to the monitor.
        """
        result = await self._collect_monitor_result(db, monitor)
        now = datetime.now(UTC)
        monitor.status = result.status
        monitor.last_checked_at = now
        db.add(
            CheckResult(
                monitor_id=monitor.id,
                status=result.status,
                response_time_ms=result.response_time_ms,
                message=result.message,
                observed_at=now,
            )
        )
        # Flush so the new CheckResult is visible to the recent-status query in _evaluate_rule.
        db.flush()

        # Not every result type carries metrics; a missing or None attribute means "none".
        for metric in getattr(result, "metrics", None) or ():
            db.add(
                Metric(
                    monitor_id=monitor.id,
                    name=metric.name,
                    value=metric.value,
                    unit=metric.unit,
                    observed_at=now,
                )
            )
        db.flush()

        if monitor.asset_id is not None:
            asset = db.get(Asset, monitor.asset_id)
            if asset is not None:
                asset.status = result.status

        rules = db.scalars(
            select(AlertRule).where(AlertRule.monitor_id == monitor.id, AlertRule.is_enabled.is_(True))
        ).all()
        for rule in rules:
            await self._evaluate_rule(db, monitor, rule, now, result.message)

        logger.info("Checked %s: %s (%s ms)", monitor.name, result.status, result.response_time_ms)

    async def _collect_monitor_result(self, db: Session, monitor: Monitor):
        """Dispatch to the collector for ``monitor.monitor_type`` and return its result.

        Per-type options are read from ``monitor.config``; a None config is
        treated as empty.

        Raises:
            ValueError: for an unsupported monitor type, or a TCP monitor with
                no ``port`` configured.
        """
        options = monitor.config or {}
        if monitor.monitor_type == "http":
            config = WebsiteCheckConfig(
                url=monitor.target,
                expected_status=int(options.get("expected_status", 200)),
                expected_text=options.get("expected_text") or None,
                unexpected_text=options.get("unexpected_text") or None,
                timeout_seconds=float(options.get("timeout_seconds", 10)),
                check_tls_expiry=bool(options.get("check_tls_expiry", False)),
                tls_warning_days=int(options.get("tls_warning_days", 30)),
            )
            return await run_website_check(config)
        if monitor.monitor_type == "ping":
            config = PingCheckConfig(
                host=monitor.target,
                timeout_seconds=float(options.get("timeout_seconds", 5)),
            )
            return await run_ping_check(config)
        if monitor.monitor_type == "tcp":
            port = options.get("port")
            if port is None:
                # Fail with a clear message instead of int(None)'s TypeError.
                raise ValueError(f"TCP monitor {monitor.id} has no port configured")
            config = TcpCheckConfig(
                # A configured "host" overrides the monitor's target.
                host=str(options.get("host") or monitor.target),
                port=int(port),
                timeout_seconds=float(options.get("timeout_seconds", 5)),
            )
            return await run_tcp_check(config)
        if monitor.monitor_type == "snmp":
            return await self._collect_snmp_monitor_result(db, monitor)
        raise ValueError(f"Unsupported monitor type: {monitor.monitor_type}")

    async def _collect_snmp_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
        """Run an SNMP check using the credential profile referenced by the monitor.

        Configuration problems are returned as a ``"down"`` result rather
        than raised, so they are recorded like any other failed check. These
        cover a missing or non-integer profile id, a missing or non-SNMP
        profile, and an empty decrypted community string. Port, timeout and
        retries come from the profile's ``extra`` mapping.
        """
        options = monitor.config or {}
        profile_id = options.get("credential_profile_id")
        if not isinstance(profile_id, int):
            return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile is not configured")
        profile = db.get(Credential, profile_id)
        if profile is None or profile.credential_type != "snmp":
            return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile was not found")
        community = decrypt_secret(profile.encrypted_secret)
        if not community:
            return SnmpCheckResult(
                status="down", response_time_ms=None, message="SNMP credential profile has no usable community string"
            )
        extra = dict(profile.extra or {})
        config = SnmpCheckConfig(
            host=monitor.target,
            community=community,
            item_id=str(options.get("item_id") or ""),
            item_type=str(options.get("item_type") or ""),
            port=int(extra.get("port") or 161),
            timeout_seconds=float(extra.get("timeout_seconds") or 5),
            retries=int(extra.get("retries") or 1),
        )
        return await run_snmp_check(config)

    async def _evaluate_rule(
        self, db: Session, monitor: Monitor, rule: AlertRule, now: datetime, message: str
    ) -> None:
        """Open or resolve the incident for *rule* based on the monitor's latest status.

        - Monitor "up": resolve the rule's open incident, if any, and send a
          "resolved" notification.
        - Otherwise: open an incident once the newest ``failure_threshold``
          check results are all non-"up". Skip this when an incident is
          already open, or when the rule's latest incident was opened less
          than ``cooldown_seconds`` ago.
        """
        open_incident = db.scalar(
            select(Incident).where(
                Incident.monitor_id == monitor.id,
                Incident.alert_rule_id == rule.id,
                Incident.status == "open",
            )
        )
        if monitor.status == "up":
            if open_incident is not None:
                open_incident.status = "resolved"
                open_incident.resolved_at = now
                open_incident.details = {**(open_incident.details or {}), "recovery_message": message}
                await self._send_incident_notifications(db, open_incident, monitor, "resolved", now)
            return

        recent_statuses = list(
            db.scalars(
                select(CheckResult.status)
                .where(CheckResult.monitor_id == monitor.id)
                .order_by(CheckResult.observed_at.desc())
                .limit(rule.failure_threshold)
            )
        )
        # Require a full window of results so a brand-new monitor doesn't alert early.
        threshold_met = len(recent_statuses) >= rule.failure_threshold and all(
            status != "up" for status in recent_statuses
        )
        if not threshold_met or open_incident is not None:
            return

        if rule.cooldown_seconds > 0:
            latest_incident = db.scalar(
                select(Incident)
                .where(
                    Incident.monitor_id == monitor.id,
                    Incident.alert_rule_id == rule.id,
                )
                .order_by(Incident.opened_at.desc())
                .limit(1)
            )
            if (
                latest_incident is not None
                and self._as_utc(latest_incident.opened_at) + timedelta(seconds=rule.cooldown_seconds) > now
            ):
                return

        incident = Incident(
            asset_id=monitor.asset_id,
            monitor_id=monitor.id,
            alert_rule_id=rule.id,
            title=f"{monitor.name} is failing",
            severity=rule.severity,
            status="open",
            opened_at=now,
            details={"last_message": message, "failure_threshold": rule.failure_threshold},
        )
        db.add(incident)
        # Flush so the incident has an id for the notification link.
        db.flush()
        await self._send_incident_notifications(db, incident, monitor, "opened", now)

    async def _send_incident_notifications(
        self,
        db: Session,
        incident: Incident,
        monitor: Monitor,
        event_type: str,
        now: datetime,
    ) -> None:
        """Post the incident *event_type* ("opened" or "resolved") to every enabled webhook channel.

        Once delivery succeeds on at least one channel, the send time is stored
        in ``incident.details["notification_state"]`` and later calls for the
        same event return early. An entry is also appended to
        ``incident.details["notification_history"]``. Channels whose secret
        decrypts to a falsy value, and channels whose delivery fails, are
        logged and skipped.
        """
        state_key = "opened_sent_at" if event_type == "opened" else "resolved_sent_at"
        notification_state = dict((incident.details or {}).get("notification_state") or {})
        if notification_state.get(state_key):
            return

        channels = db.scalars(
            select(NotificationChannel).where(
                NotificationChannel.is_enabled.is_(True),
                NotificationChannel.channel_type.in_(_WEBHOOK_CHANNEL_TYPES),
            )
        ).all()
        if not channels:
            return

        # Same text for every channel; build it once.
        message = self._format_incident_message(incident, monitor, event_type)
        sent_channels: list[str] = []
        for channel in channels:
            url = decrypt_secret(channel.encrypted_secret)
            if not url:
                logger.warning("Skipping notification channel %s because its secret cannot be decrypted", channel.id)
                continue
            username = str((channel.settings or {}).get("username") or "OrbitalWard")
            try:
                await self._post_webhook(url, message, username)
            except (httpx.HTTPError, httpx.InvalidURL):
                # InvalidURL is not an HTTPError subclass; a malformed channel URL
                # must not abort delivery to the remaining channels.
                logger.exception("Notification delivery failed for channel %s", channel.id)
                continue
            sent_channels.append(channel.name)

        if sent_channels:
            notification_state[state_key] = now.isoformat()
            history = list((incident.details or {}).get("notification_history") or [])
            history.append({"event": event_type, "sent_at": now.isoformat(), "channels": sent_channels})
            # Reassign the whole dict so the change to the details column is picked up.
            incident.details = {
                **(incident.details or {}),
                "notification_state": notification_state,
                "notification_history": history,
            }

    async def _post_webhook(self, url: str, message: str, username: str) -> None:
        """POST ``{"username", "text"}`` JSON to *url*; raises for non-2xx responses."""
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.post(url, json={"username": username, "text": message})
            response.raise_for_status()

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as an aware UTC datetime; naive values are assumed to already be UTC."""
        if value.tzinfo is None:
            return value.replace(tzinfo=UTC)
        return value.astimezone(UTC)

    def _format_incident_message(self, incident: Incident, monitor: Monitor, event_type: str) -> str:
        """Build the plain-text notification body for an incident event.

        "Last response" shows the most recent check message known for the
        event: the failure message for an opened incident, or the recovery
        message for a resolved one.
        """
        details = incident.details or {}
        if event_type == "resolved":
            title = f"RESOLVED: {monitor.name} recovered"
            body = [
                title,
                "",
                f"Monitor: {monitor.name}",
                f"Target: {monitor.target}",
                f"Resolved: {incident.resolved_at or datetime.now(UTC)}",
            ]
            last_message = details.get("recovery_message")
        else:
            title = f"{incident.severity.upper()}: {incident.title}"
            body = [
                title,
                "",
                f"Monitor: {monitor.name}",
                f"Target: {monitor.target}",
                f"Status: {monitor.status}",
                f"Started: {incident.opened_at}",
            ]
            last_message = details.get("last_message")
        if last_message:
            body.append(f"Last response: {last_message}")
        body.extend(["", f"View in OrbitalWard: {settings.frontend_url}/incidents/{incident.id}"])
        return "\n".join(str(line) for line in body)