301 lines
12 KiB
Python
301 lines
12 KiB
Python
import asyncio
|
|
import logging
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from sqlalchemy.orm import Session
|
|
import httpx
|
|
|
|
from app.collectors.snmp import SnmpCheckConfig, SnmpCheckResult, run_snmp_check
|
|
from app.collectors.website import WebsiteCheckConfig, run_website_check
|
|
from app.collectors.network import PingCheckConfig, TcpCheckConfig, run_ping_check, run_tcp_check
|
|
from app.config import settings
|
|
from app.db import session_scope
|
|
from app.models import AlertRule, Asset, CheckResult, Credential, Incident, Metric, Monitor, NotificationChannel
|
|
from app.secrets import decrypt_secret
|
|
|
|
# Module-level logger used by Scheduler for worker start-up, per-check results,
# tick failures and notification delivery problems.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Scheduler:
    """Polling worker that runs due monitor checks and drives alerting.

    Every tick loads the monitors whose check interval has elapsed, runs the
    matching collector (``http``, ``ping``, ``tcp`` or ``snmp``), stores the
    outcome as a ``CheckResult`` (plus any ``Metric`` rows), mirrors the status
    onto the linked ``Asset``, evaluates the monitor's enabled ``AlertRule``
    rows, and posts webhook notifications when an incident opens or resolves.
    """

    def __init__(self, poll_interval_seconds: int = 10) -> None:
        """Create a scheduler.

        Args:
            poll_interval_seconds: Seconds to wait between ticks; the wait ends
                early once :meth:`stop` is called.
        """
        self.poll_interval_seconds = poll_interval_seconds
        # Set by stop(); run() checks it between ticks and waits on it while idle.
        self._stopped = asyncio.Event()

    async def run(self) -> None:
        """Run ticks until :meth:`stop` is called."""
        logger.info("OrbitWard worker started for %s", settings.orbitward_env)
        while not self._stopped.is_set():
            await self.tick()
            try:
                # Interruptible sleep: returns as soon as stop() sets the event.
                await asyncio.wait_for(self._stopped.wait(), timeout=self.poll_interval_seconds)
            except TimeoutError:
                continue

    async def tick(self) -> None:
        """Check every due monitor inside one database session, then commit.

        A failure confined to a single monitor (bad configuration, collector
        error, alerting bug) is logged and that monitor is skipped, so the
        remaining monitors are still checked and the worker loop survives. Any
        writes already made for the failing monitor stay in the session and are
        committed with the rest.

        Database errors abort the whole tick. They propagate out of
        ``session_scope`` (which presumably rolls the transaction back —
        confirm in ``app.db``) and are logged here.
        """
        try:
            with session_scope() as db:
                for monitor in self._load_due_monitors(db):
                    try:
                        await self._run_monitor(db, monitor)
                    except SQLAlchemyError:
                        # Session state is unknown after a DB error; abort the tick.
                        raise
                    except Exception:
                        logger.exception("Check for monitor %s failed; skipping it this tick", monitor.id)
                db.commit()
        except SQLAlchemyError:
            logger.exception("Worker tick failed while talking to the database")

    def stop(self) -> None:
        """Ask :meth:`run` to exit after the current tick (or wake it from its sleep)."""
        self._stopped.set()

    def _load_due_monitors(self, db: Session, limit: int = 50) -> list[Monitor]:
        """Return up to ``limit`` monitors whose next check is due, most overdue first.

        A monitor is due when it has never been checked, or when
        ``last_checked_at + interval_seconds`` is not in the future.

        The cap is applied *after* the due filter. Otherwise only the
        lowest-id rows would ever be considered, and monitors beyond them
        would never run. ``last_checked_at`` is normalised to UTC because it
        may be read back from the database without tzinfo, and a naive value
        cannot be compared with the aware ``now``.

        Args:
            db: Open ORM session.
            limit: Maximum number of monitors to return per tick.

        Returns:
            Due monitors: never-checked ones first, then in order of how long
            they have been due; ties keep ascending id order.
        """
        now = datetime.now(UTC)
        monitors = db.scalars(
            select(Monitor)
            .where(Monitor.monitor_type.in_(["http", "ping", "tcp", "snmp"]))
            .order_by(Monitor.id)
        ).all()

        never_checked = datetime.min.replace(tzinfo=UTC)
        due: list[tuple[datetime, Monitor]] = []
        for monitor in monitors:
            if monitor.last_checked_at is None:
                due.append((never_checked, monitor))
                continue
            next_due_at = self._as_utc(monitor.last_checked_at) + timedelta(seconds=monitor.interval_seconds)
            if next_due_at <= now:
                due.append((next_due_at, monitor))

        # list.sort is stable, so equal due times keep the id order from the query.
        due.sort(key=lambda item: item[0])
        return [monitor for _, monitor in due[:limit]]

    async def _run_monitor(self, db: Session, monitor: Monitor) -> None:
        """Run one check for ``monitor`` and record its outcome.

        Updates ``monitor.status`` and ``monitor.last_checked_at``, stores a
        ``CheckResult`` plus any metrics returned by the collector, copies the
        status onto the linked asset (if one exists), then evaluates every
        enabled alert rule attached to the monitor.
        """
        result = await self._collect_monitor_result(db, monitor)
        now = datetime.now(UTC)

        monitor.status = result.status
        monitor.last_checked_at = now
        db.add(
            CheckResult(
                monitor_id=monitor.id,
                status=result.status,
                response_time_ms=result.response_time_ms,
                message=result.message,
                observed_at=now,
            )
        )
        # Flush so the new row is visible to the recent-status query in _evaluate_rule.
        db.flush()

        # Not every collector result carries metrics; tolerate a missing or None attribute.
        for metric in getattr(result, "metrics", None) or []:
            db.add(
                Metric(
                    monitor_id=monitor.id,
                    name=metric.name,
                    value=metric.value,
                    unit=metric.unit,
                    observed_at=now,
                )
            )
        db.flush()

        if monitor.asset_id is not None:
            asset = db.get(Asset, monitor.asset_id)
            if asset is not None:
                asset.status = result.status

        rules = db.scalars(
            select(AlertRule).where(AlertRule.monitor_id == monitor.id, AlertRule.is_enabled.is_(True))
        ).all()
        for rule in rules:
            await self._evaluate_rule(db, monitor, rule, now, result.message)

        logger.info("Checked %s: %s (%s ms)", monitor.name, result.status, result.response_time_ms)

    async def _collect_monitor_result(self, db: Session, monitor: Monitor):
        """Build the collector config for ``monitor`` and run the matching check.

        Settings come from ``monitor.target`` and the ``monitor.config`` dict;
        missing optional keys fall back to the defaults shown below.

        Raises:
            ValueError: For an unsupported monitor type, or a TCP monitor with
                no ``port`` configured.
        """
        if monitor.monitor_type == "http":
            config = WebsiteCheckConfig(
                url=monitor.target,
                expected_status=int(monitor.config.get("expected_status", 200)),
                expected_text=monitor.config.get("expected_text") or None,
                unexpected_text=monitor.config.get("unexpected_text") or None,
                timeout_seconds=float(monitor.config.get("timeout_seconds", 10)),
                check_tls_expiry=bool(monitor.config.get("check_tls_expiry", False)),
                tls_warning_days=int(monitor.config.get("tls_warning_days", 30)),
            )
            return await run_website_check(config)

        if monitor.monitor_type == "ping":
            config = PingCheckConfig(
                host=monitor.target,
                timeout_seconds=float(monitor.config.get("timeout_seconds", 5)),
            )
            return await run_ping_check(config)

        if monitor.monitor_type == "tcp":
            port = monitor.config.get("port")
            if port is None:
                # Fail with a clear message instead of the opaque int(None) TypeError.
                raise ValueError(f"TCP monitor {monitor.id} has no port configured")
            config = TcpCheckConfig(
                host=str(monitor.config.get("host") or monitor.target),
                port=int(port),
                timeout_seconds=float(monitor.config.get("timeout_seconds", 5)),
            )
            return await run_tcp_check(config)

        if monitor.monitor_type == "snmp":
            return await self._collect_snmp_monitor_result(db, monitor)

        raise ValueError(f"Unsupported monitor type: {monitor.monitor_type}")

    async def _collect_snmp_monitor_result(self, db: Session, monitor: Monitor) -> SnmpCheckResult:
        """Run an SNMP check using the credential profile referenced by the monitor.

        Configuration problems are reported as a ``"down"`` result rather than
        raised. This covers a missing or non-int ``credential_profile_id``, a
        profile that does not exist or is not of type ``snmp``, and a community
        string that decrypts to an empty value.

        Port, timeout and retries come from the profile's ``extra`` dict
        (defaults 161, 5 s, 1). Item settings come from ``monitor.config``.
        """
        profile_id = monitor.config.get("credential_profile_id")
        if not isinstance(profile_id, int):
            return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile is not configured")

        profile = db.get(Credential, profile_id)
        if profile is None or profile.credential_type != "snmp":
            return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile was not found")

        community = decrypt_secret(profile.encrypted_secret)
        if not community:
            return SnmpCheckResult(status="down", response_time_ms=None, message="SNMP credential profile has no usable community string")

        extra = dict(profile.extra or {})
        config = SnmpCheckConfig(
            host=monitor.target,
            community=community,
            item_id=str(monitor.config.get("item_id") or ""),
            item_type=str(monitor.config.get("item_type") or ""),
            label=monitor.config.get("label") if isinstance(monitor.config.get("label"), str) else None,
            unit=monitor.config.get("unit") if isinstance(monitor.config.get("unit"), str) else None,
            port=int(extra.get("port") or 161),
            timeout_seconds=float(extra.get("timeout_seconds") or 5),
            retries=int(extra.get("retries") or 1),
        )
        return await run_snmp_check(config)

    async def _evaluate_rule(self, db: Session, monitor: Monitor, rule: AlertRule, now: datetime, message: str) -> None:
        """Open or resolve the incident tracked for ``(monitor, rule)``.

        - Monitor ``"up"``: an open incident (if any) is marked resolved at
          ``now``. ``message`` is stored as its ``recovery_message`` and a
          ``"resolved"`` notification is sent.
        - Any other status: a new incident is opened only when all of these
          hold:
          - the latest ``rule.failure_threshold`` check results are all
            non-``"up"``;
          - no incident is currently open for this rule;
          - ``rule.cooldown_seconds`` (when positive) has elapsed since the
            most recently opened incident for this rule.
          An ``"opened"`` notification is then sent.
        """
        open_incident = db.scalar(
            select(Incident).where(
                Incident.monitor_id == monitor.id,
                Incident.alert_rule_id == rule.id,
                Incident.status == "open",
            )
        )

        if monitor.status == "up":
            if open_incident is not None:
                open_incident.status = "resolved"
                open_incident.resolved_at = now
                # Reassign rather than mutate `details`; presumably so the ORM
                # detects the change on a JSON column — confirm the column type.
                open_incident.details = {**(open_incident.details or {}), "recovery_message": message}
                await self._send_incident_notifications(db, open_incident, monitor, "resolved", now)
            return

        recent_statuses = list(
            db.scalars(
                select(CheckResult.status)
                .where(CheckResult.monitor_id == monitor.id)
                .order_by(CheckResult.observed_at.desc())
                .limit(rule.failure_threshold)
            )
        )
        threshold_met = len(recent_statuses) >= rule.failure_threshold and all(status != "up" for status in recent_statuses)
        if threshold_met and open_incident is None:
            if rule.cooldown_seconds > 0:
                latest_incident = db.scalar(
                    select(Incident)
                    .where(
                        Incident.monitor_id == monitor.id,
                        Incident.alert_rule_id == rule.id,
                    )
                    .order_by(Incident.opened_at.desc())
                    .limit(1)
                )
                if (
                    latest_incident is not None
                    and self._as_utc(latest_incident.opened_at) + timedelta(seconds=rule.cooldown_seconds) > now
                ):
                    return

            incident = Incident(
                asset_id=monitor.asset_id,
                monitor_id=monitor.id,
                alert_rule_id=rule.id,
                title=f"{monitor.name} is failing",
                severity=rule.severity,
                status="open",
                opened_at=now,
                details={"last_message": message, "failure_threshold": rule.failure_threshold},
            )
            db.add(incident)
            # Flush so the incident gets its id before it is used in the notification link.
            db.flush()
            await self._send_incident_notifications(db, incident, monitor, "opened", now)

    async def _send_incident_notifications(
        self,
        db: Session,
        incident: Incident,
        monitor: Monitor,
        event_type: str,
        now: datetime,
    ) -> None:
        """Post an incident event to every enabled webhook-style channel, at most once.

        Delivery state lives in ``incident.details["notification_state"]``.
        If the key for this event (``opened_sent_at`` or ``resolved_sent_at``)
        is already set, nothing is sent.

        The state key and a ``notification_history`` entry are written only
        when at least one channel accepted the message. When every channel
        fails, the key stays unset.

        Channels whose secret cannot be decrypted are skipped. HTTP/transport
        failures and malformed URLs are logged per channel and do not stop
        delivery to the remaining channels.
        """
        state_key = "opened_sent_at" if event_type == "opened" else "resolved_sent_at"
        notification_state = dict((incident.details or {}).get("notification_state") or {})
        if notification_state.get(state_key):
            return

        channels = db.scalars(
            select(NotificationChannel).where(
                NotificationChannel.is_enabled.is_(True),
                NotificationChannel.channel_type.in_(["generic_webhook", "webhook", "mattermost", "zoom", "zoom_team_chat"]),
            )
        ).all()
        if not channels:
            return

        # The text does not depend on the channel, so build it once.
        message = self._format_incident_message(incident, monitor, event_type)
        sent_channels: list[str] = []
        for channel in channels:
            url = decrypt_secret(channel.encrypted_secret)
            if not url:
                logger.warning("Skipping notification channel %s because its secret cannot be decrypted", channel.id)
                continue
            try:
                await self._post_webhook(
                    url,
                    message,
                    str((channel.settings or {}).get("username") or "OrbitWard"),
                )
            except (httpx.HTTPError, httpx.InvalidURL) as exc:
                # The webhook URL is a stored secret, and httpx error messages can
                # embed it (e.g. HTTPStatusError, InvalidURL). Log only the error
                # type and status code, never the exception text or traceback.
                detail = type(exc).__name__
                if isinstance(exc, httpx.HTTPStatusError):
                    detail = f"{detail}, HTTP {exc.response.status_code}"
                logger.error("Notification delivery failed for channel %s (%s)", channel.id, detail)
                continue
            sent_channels.append(channel.name)

        if sent_channels:
            notification_state[state_key] = now.isoformat()
            history = list((incident.details or {}).get("notification_history") or [])
            history.append({"event": event_type, "sent_at": now.isoformat(), "channels": sent_channels})
            incident.details = {**(incident.details or {}), "notification_state": notification_state, "notification_history": history}

    async def _post_webhook(self, url: str, message: str, username: str) -> None:
        """POST ``{"username", "text"}`` JSON to ``url`` (10 s timeout).

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response.
            httpx.HTTPError / httpx.InvalidURL: On transport or URL errors.
        """
        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.post(url, json={"username": username, "text": message})
            response.raise_for_status()

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return ``value`` as an aware UTC datetime; naive values are taken to be UTC already."""
        if value.tzinfo is None:
            return value.replace(tzinfo=UTC)
        return value.astimezone(UTC)

    def _format_incident_message(self, incident: Incident, monitor: Monitor, event_type: str) -> str:
        """Render the plain-text notification body for an incident event.

        ``"resolved"`` events get a recovery headline and show the recovery
        check message. Any other event shows the severity, current monitor
        status, start time and the last failure message. Every message ends
        with a link to the incident in the frontend.
        """
        details = incident.details or {}
        if event_type == "resolved":
            title = f"RESOLVED: {monitor.name} recovered"
            body = [
                title,
                "",
                f"Monitor: {monitor.name}",
                f"Target: {monitor.target}",
                f"Resolved: {incident.resolved_at or datetime.now(UTC)}",
            ]
            # A recovery notice should quote the successful check, not the old failure.
            last_message = details.get("recovery_message")
        else:
            title = f"{incident.severity.upper()}: {incident.title}"
            body = [
                title,
                "",
                f"Monitor: {monitor.name}",
                f"Target: {monitor.target}",
                f"Status: {monitor.status}",
                f"Started: {incident.opened_at}",
            ]
            last_message = details.get("last_message")
        if last_message:
            body.append(f"Last response: {last_message}")
        body.extend(["", f"View in OrbitWard: {settings.frontend_url}/incidents/{incident.id}"])
        return "\n".join(str(line) for line in body)
|