Files
OrbitWard/worker/app/scheduler.py
T
2026-05-22 17:36:40 -06:00

210 lines
8.4 KiB
Python

import asyncio
import logging
from datetime import UTC, datetime, timedelta
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
import httpx
from app.collectors.website import WebsiteCheckConfig, run_website_check
from app.config import settings
from app.db import session_scope
from app.models import AlertRule, Asset, CheckResult, Incident, Monitor, NotificationChannel
from app.secrets import decrypt_secret
# Module-level logger named after this module, shared by the scheduler below.
logger: logging.Logger = logging.getLogger(name=__name__)
class Scheduler:
def __init__(self, poll_interval_seconds: int = 10) -> None:
self.poll_interval_seconds = poll_interval_seconds
self._stopped = asyncio.Event()
async def run(self) -> None:
logger.info("InfraPulse worker started for %s", settings.infrapulse_env)
while not self._stopped.is_set():
await self.tick()
try:
await asyncio.wait_for(self._stopped.wait(), timeout=self.poll_interval_seconds)
except TimeoutError:
continue
async def tick(self) -> None:
try:
with session_scope() as db:
due_monitors = self._load_due_website_monitors(db)
for monitor in due_monitors:
await self._run_monitor(db, monitor)
db.commit()
except SQLAlchemyError:
logger.exception("Worker tick failed while talking to the database")
def stop(self) -> None:
self._stopped.set()
def _load_due_website_monitors(self, db: Session) -> list[Monitor]:
now = datetime.now(UTC)
monitors = db.scalars(select(Monitor).where(Monitor.monitor_type == "http").order_by(Monitor.id).limit(50)).all()
due: list[Monitor] = []
for monitor in monitors:
if monitor.last_checked_at is None:
due.append(monitor)
continue
next_due_at = monitor.last_checked_at + timedelta(seconds=monitor.interval_seconds)
if next_due_at <= now:
due.append(monitor)
return due
async def _run_monitor(self, db: Session, monitor: Monitor) -> None:
config = WebsiteCheckConfig(
url=monitor.target,
expected_status=int(monitor.config.get("expected_status", 200)),
expected_text=monitor.config.get("expected_text") or None,
unexpected_text=monitor.config.get("unexpected_text") or None,
timeout_seconds=float(monitor.config.get("timeout_seconds", 10)),
)
result = await run_website_check(config)
now = datetime.now(UTC)
monitor.status = result.status
monitor.last_checked_at = now
db.add(
CheckResult(
monitor_id=monitor.id,
status=result.status,
response_time_ms=result.response_time_ms,
message=result.message,
observed_at=now,
)
)
db.flush()
if monitor.asset_id is not None:
asset = db.get(Asset, monitor.asset_id)
if asset is not None:
asset.status = result.status
rules = db.scalars(select(AlertRule).where(AlertRule.monitor_id == monitor.id, AlertRule.is_enabled.is_(True))).all()
for rule in rules:
await self._evaluate_rule(db, monitor, rule, now, result.message)
logger.info("Checked %s: %s (%s ms)", monitor.name, result.status, result.response_time_ms)
async def _evaluate_rule(self, db: Session, monitor: Monitor, rule: AlertRule, now: datetime, message: str) -> None:
open_incident = db.scalar(
select(Incident).where(
Incident.monitor_id == monitor.id,
Incident.alert_rule_id == rule.id,
Incident.status == "open",
)
)
if monitor.status == "up":
if open_incident is not None:
open_incident.status = "resolved"
open_incident.resolved_at = now
open_incident.details = {**(open_incident.details or {}), "recovery_message": message}
await self._send_incident_notifications(db, open_incident, monitor, "resolved", now)
return
recent_statuses = list(
db.scalars(
select(CheckResult.status)
.where(CheckResult.monitor_id == monitor.id)
.order_by(CheckResult.observed_at.desc())
.limit(rule.failure_threshold)
)
)
threshold_met = len(recent_statuses) >= rule.failure_threshold and all(status != "up" for status in recent_statuses)
if threshold_met and open_incident is None:
incident = Incident(
asset_id=monitor.asset_id,
monitor_id=monitor.id,
alert_rule_id=rule.id,
title=f"{monitor.name} is failing",
severity=rule.severity,
status="open",
opened_at=now,
details={"last_message": message, "failure_threshold": rule.failure_threshold},
)
db.add(incident)
db.flush()
await self._send_incident_notifications(db, incident, monitor, "opened", now)
async def _send_incident_notifications(
self,
db: Session,
incident: Incident,
monitor: Monitor,
event_type: str,
now: datetime,
) -> None:
state_key = "opened_sent_at" if event_type == "opened" else "resolved_sent_at"
notification_state = dict((incident.details or {}).get("notification_state") or {})
if notification_state.get(state_key):
return
channels = db.scalars(
select(NotificationChannel).where(
NotificationChannel.is_enabled.is_(True),
NotificationChannel.channel_type.in_(["generic_webhook", "webhook", "mattermost", "zoom", "zoom_team_chat"]),
)
).all()
if not channels:
return
sent_channels: list[str] = []
for channel in channels:
url = decrypt_secret(channel.encrypted_secret)
if not url:
logger.warning("Skipping notification channel %s because its secret cannot be decrypted", channel.id)
continue
try:
await self._post_webhook(
url,
self._format_incident_message(incident, monitor, event_type),
str((channel.settings or {}).get("username") or "InfraPulse"),
)
except httpx.HTTPError:
logger.exception("Notification delivery failed for channel %s", channel.id)
continue
sent_channels.append(channel.name)
if sent_channels:
notification_state[state_key] = now.isoformat()
history = list((incident.details or {}).get("notification_history") or [])
history.append({"event": event_type, "sent_at": now.isoformat(), "channels": sent_channels})
incident.details = {**(incident.details or {}), "notification_state": notification_state, "notification_history": history}
async def _post_webhook(self, url: str, message: str, username: str) -> None:
async with httpx.AsyncClient(timeout=10) as client:
response = await client.post(url, json={"username": username, "text": message})
response.raise_for_status()
def _format_incident_message(self, incident: Incident, monitor: Monitor, event_type: str) -> str:
if event_type == "resolved":
title = f"RESOLVED: {monitor.name} recovered"
body = [
title,
"",
f"Monitor: {monitor.name}",
f"Target: {monitor.target}",
f"Resolved: {incident.resolved_at or datetime.now(UTC)}",
]
else:
title = f"{incident.severity.upper()}: {incident.title}"
body = [
title,
"",
f"Monitor: {monitor.name}",
f"Target: {monitor.target}",
f"Status: {monitor.status}",
f"Started: {incident.opened_at}",
]
last_message = (incident.details or {}).get("last_message")
if last_message:
body.append(f"Last response: {last_message}")
body.extend(["", f"View in InfraPulse: {settings.frontend_url}/incidents/{incident.id}"])
return "\n".join(str(line) for line in body)