16932957b2
Adds ping and TCP monitor creation APIs, worker collectors, network checks UI, dashboard monitor status support, and progress documentation.
236 lines
9.6 KiB
Python
236 lines
9.6 KiB
Python
import asyncio
|
|
import logging
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from sqlalchemy.orm import Session
|
|
import httpx
|
|
|
|
from app.collectors.website import WebsiteCheckConfig, run_website_check
|
|
from app.collectors.network import PingCheckConfig, TcpCheckConfig, run_ping_check, run_tcp_check
|
|
from app.config import settings
|
|
from app.db import session_scope
|
|
from app.models import AlertRule, Asset, CheckResult, Incident, Monitor, NotificationChannel
|
|
from app.secrets import decrypt_secret
|
|
|
|
# Module-level logger, named after this module so its records can be filtered by module path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Scheduler:
|
|
def __init__(self, poll_interval_seconds: int = 10) -> None:
|
|
self.poll_interval_seconds = poll_interval_seconds
|
|
self._stopped = asyncio.Event()
|
|
|
|
async def run(self) -> None:
|
|
logger.info("OrbitalWard worker started for %s", settings.orbitalward_env)
|
|
while not self._stopped.is_set():
|
|
await self.tick()
|
|
try:
|
|
await asyncio.wait_for(self._stopped.wait(), timeout=self.poll_interval_seconds)
|
|
except TimeoutError:
|
|
continue
|
|
|
|
async def tick(self) -> None:
|
|
try:
|
|
with session_scope() as db:
|
|
due_monitors = self._load_due_monitors(db)
|
|
for monitor in due_monitors:
|
|
await self._run_monitor(db, monitor)
|
|
db.commit()
|
|
except SQLAlchemyError:
|
|
logger.exception("Worker tick failed while talking to the database")
|
|
|
|
def stop(self) -> None:
|
|
self._stopped.set()
|
|
|
|
def _load_due_monitors(self, db: Session) -> list[Monitor]:
|
|
now = datetime.now(UTC)
|
|
monitors = db.scalars(
|
|
select(Monitor).where(Monitor.monitor_type.in_(["http", "ping", "tcp"])).order_by(Monitor.id).limit(50)
|
|
).all()
|
|
due: list[Monitor] = []
|
|
for monitor in monitors:
|
|
if monitor.last_checked_at is None:
|
|
due.append(monitor)
|
|
continue
|
|
next_due_at = monitor.last_checked_at + timedelta(seconds=monitor.interval_seconds)
|
|
if next_due_at <= now:
|
|
due.append(monitor)
|
|
return due
|
|
|
|
async def _run_monitor(self, db: Session, monitor: Monitor) -> None:
|
|
result = await self._collect_monitor_result(monitor)
|
|
now = datetime.now(UTC)
|
|
|
|
monitor.status = result.status
|
|
monitor.last_checked_at = now
|
|
db.add(
|
|
CheckResult(
|
|
monitor_id=monitor.id,
|
|
status=result.status,
|
|
response_time_ms=result.response_time_ms,
|
|
message=result.message,
|
|
observed_at=now,
|
|
)
|
|
)
|
|
db.flush()
|
|
|
|
if monitor.asset_id is not None:
|
|
asset = db.get(Asset, monitor.asset_id)
|
|
if asset is not None:
|
|
asset.status = result.status
|
|
|
|
rules = db.scalars(select(AlertRule).where(AlertRule.monitor_id == monitor.id, AlertRule.is_enabled.is_(True))).all()
|
|
for rule in rules:
|
|
await self._evaluate_rule(db, monitor, rule, now, result.message)
|
|
|
|
logger.info("Checked %s: %s (%s ms)", monitor.name, result.status, result.response_time_ms)
|
|
|
|
async def _collect_monitor_result(self, monitor: Monitor):
|
|
if monitor.monitor_type == "http":
|
|
config = WebsiteCheckConfig(
|
|
url=monitor.target,
|
|
expected_status=int(monitor.config.get("expected_status", 200)),
|
|
expected_text=monitor.config.get("expected_text") or None,
|
|
unexpected_text=monitor.config.get("unexpected_text") or None,
|
|
timeout_seconds=float(monitor.config.get("timeout_seconds", 10)),
|
|
check_tls_expiry=bool(monitor.config.get("check_tls_expiry", False)),
|
|
tls_warning_days=int(monitor.config.get("tls_warning_days", 30)),
|
|
)
|
|
return await run_website_check(config)
|
|
|
|
if monitor.monitor_type == "ping":
|
|
config = PingCheckConfig(
|
|
host=monitor.target,
|
|
timeout_seconds=float(monitor.config.get("timeout_seconds", 5)),
|
|
)
|
|
return await run_ping_check(config)
|
|
|
|
if monitor.monitor_type == "tcp":
|
|
config = TcpCheckConfig(
|
|
host=str(monitor.config.get("host") or monitor.target),
|
|
port=int(monitor.config.get("port")),
|
|
timeout_seconds=float(monitor.config.get("timeout_seconds", 5)),
|
|
)
|
|
return await run_tcp_check(config)
|
|
|
|
raise ValueError(f"Unsupported monitor type: {monitor.monitor_type}")
|
|
|
|
async def _evaluate_rule(self, db: Session, monitor: Monitor, rule: AlertRule, now: datetime, message: str) -> None:
|
|
open_incident = db.scalar(
|
|
select(Incident).where(
|
|
Incident.monitor_id == monitor.id,
|
|
Incident.alert_rule_id == rule.id,
|
|
Incident.status == "open",
|
|
)
|
|
)
|
|
|
|
if monitor.status == "up":
|
|
if open_incident is not None:
|
|
open_incident.status = "resolved"
|
|
open_incident.resolved_at = now
|
|
open_incident.details = {**(open_incident.details or {}), "recovery_message": message}
|
|
await self._send_incident_notifications(db, open_incident, monitor, "resolved", now)
|
|
return
|
|
|
|
recent_statuses = list(
|
|
db.scalars(
|
|
select(CheckResult.status)
|
|
.where(CheckResult.monitor_id == monitor.id)
|
|
.order_by(CheckResult.observed_at.desc())
|
|
.limit(rule.failure_threshold)
|
|
)
|
|
)
|
|
threshold_met = len(recent_statuses) >= rule.failure_threshold and all(status != "up" for status in recent_statuses)
|
|
if threshold_met and open_incident is None:
|
|
incident = Incident(
|
|
asset_id=monitor.asset_id,
|
|
monitor_id=monitor.id,
|
|
alert_rule_id=rule.id,
|
|
title=f"{monitor.name} is failing",
|
|
severity=rule.severity,
|
|
status="open",
|
|
opened_at=now,
|
|
details={"last_message": message, "failure_threshold": rule.failure_threshold},
|
|
)
|
|
db.add(incident)
|
|
db.flush()
|
|
await self._send_incident_notifications(db, incident, monitor, "opened", now)
|
|
|
|
async def _send_incident_notifications(
|
|
self,
|
|
db: Session,
|
|
incident: Incident,
|
|
monitor: Monitor,
|
|
event_type: str,
|
|
now: datetime,
|
|
) -> None:
|
|
state_key = "opened_sent_at" if event_type == "opened" else "resolved_sent_at"
|
|
notification_state = dict((incident.details or {}).get("notification_state") or {})
|
|
if notification_state.get(state_key):
|
|
return
|
|
|
|
channels = db.scalars(
|
|
select(NotificationChannel).where(
|
|
NotificationChannel.is_enabled.is_(True),
|
|
NotificationChannel.channel_type.in_(["generic_webhook", "webhook", "mattermost", "zoom", "zoom_team_chat"]),
|
|
)
|
|
).all()
|
|
if not channels:
|
|
return
|
|
|
|
sent_channels: list[str] = []
|
|
for channel in channels:
|
|
url = decrypt_secret(channel.encrypted_secret)
|
|
if not url:
|
|
logger.warning("Skipping notification channel %s because its secret cannot be decrypted", channel.id)
|
|
continue
|
|
try:
|
|
await self._post_webhook(
|
|
url,
|
|
self._format_incident_message(incident, monitor, event_type),
|
|
str((channel.settings or {}).get("username") or "OrbitalWard"),
|
|
)
|
|
except httpx.HTTPError:
|
|
logger.exception("Notification delivery failed for channel %s", channel.id)
|
|
continue
|
|
sent_channels.append(channel.name)
|
|
|
|
if sent_channels:
|
|
notification_state[state_key] = now.isoformat()
|
|
history = list((incident.details or {}).get("notification_history") or [])
|
|
history.append({"event": event_type, "sent_at": now.isoformat(), "channels": sent_channels})
|
|
incident.details = {**(incident.details or {}), "notification_state": notification_state, "notification_history": history}
|
|
|
|
async def _post_webhook(self, url: str, message: str, username: str) -> None:
|
|
async with httpx.AsyncClient(timeout=10) as client:
|
|
response = await client.post(url, json={"username": username, "text": message})
|
|
response.raise_for_status()
|
|
|
|
def _format_incident_message(self, incident: Incident, monitor: Monitor, event_type: str) -> str:
|
|
if event_type == "resolved":
|
|
title = f"RESOLVED: {monitor.name} recovered"
|
|
body = [
|
|
title,
|
|
"",
|
|
f"Monitor: {monitor.name}",
|
|
f"Target: {monitor.target}",
|
|
f"Resolved: {incident.resolved_at or datetime.now(UTC)}",
|
|
]
|
|
else:
|
|
title = f"{incident.severity.upper()}: {incident.title}"
|
|
body = [
|
|
title,
|
|
"",
|
|
f"Monitor: {monitor.name}",
|
|
f"Target: {monitor.target}",
|
|
f"Status: {monitor.status}",
|
|
f"Started: {incident.opened_at}",
|
|
]
|
|
last_message = (incident.details or {}).get("last_message")
|
|
if last_message:
|
|
body.append(f"Last response: {last_message}")
|
|
body.extend(["", f"View in OrbitalWard: {settings.frontend_url}/incidents/{incident.id}"])
|
|
return "\n".join(str(line) for line in body)
|