heartwood every commit a ring
14.8 KB raw
use crate::alerts;
use crate::db::now_ms;
use crate::models::PropertyRow;
use anyhow::{anyhow, Context};
use chrono::Timelike;
use hickory_resolver::TokioAsyncResolver;
use rustls::pki_types::ServerName;
use serde_json::json;
use sqlx::SqlitePool;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use url::Url;
use uuid::Uuid;

/// `user-agent` header sent with every probe request. It mimics desktop
/// Chrome (presumably so sites respond as they would to a real visitor) and
/// ends with a `Status/2.0.0` product token.
const USER_AGENT: &str = concat!(
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ",
    "(KHTML, like Gecko) Chrome/102.0.5005.115 Safari/537.36 Status/2.0.0"
);
/// Wall-clock budget, in seconds, for one whole probe including every
/// redirect hop. Failed probes also record this value as their response time.
const HTTP_TIMEOUT_SECS: u64 = 10;
/// Maximum number of redirect hops followed after the first request.
const MAX_REDIRECTS: usize = 5;

/// Millisecond latency breakdown of a single HTTP probe.
///
/// The four phase fields are `Option<i64>`. `None` means the phase never ran
/// because the probe failed before reaching it. The option type also lets
/// rows written before the phase columns existed (migration 0002 left them
/// NULL) keep deserializing.
///
/// `total_ms` is the first hop's wall-clock time from start to finish. It is
/// also stored in `checks.response_ms` for backward compatibility with the
/// average shown in alert emails.
#[derive(Clone, Copy, Debug, Default)]
pub struct PhaseTimings {
    /// Name resolution time.
    pub dns_ms: Option<i64>,
    /// TCP connect time.
    pub tcp_ms: Option<i64>,
    /// TLS handshake time.
    pub tls_ms: Option<i64>,
    /// Time to first byte of the response.
    pub ttfb_ms: Option<i64>,
    /// End-to-end time of the first hop.
    pub total_ms: i64,
}

/// What one probe produced, in the shape persisted to the `checks` table.
struct ProbeOutcome {
    /// Final HTTP status after following redirects, or a synthetic code
    /// (526 / 408) when the probe failed.
    status_code: i64,
    /// JSON object of the last response's headers; `"{}"` on failure.
    headers_json: String,
    /// Phase timings of the first hop only.
    timings: PhaseTimings,
}

/// Outcome of a single request/response exchange. It carries everything
/// the redirect loop needs to decide whether to follow a `Location` header.
struct HopResult {
    /// HTTP status code of this hop's response.
    status_code: i64,
    /// Response headers keyed by lowercase name. Only values that convert
    /// to `&str` are kept, and a repeated header keeps its last value.
    headers: BTreeMap<String, String>,
    /// The same headers serialized as a JSON object string.
    raw_headers_json: String,
    /// Phase timings measured for this hop.
    timings: PhaseTimings,
}

/// Create a brand-new rustls client configuration for one probe.
///
/// The config is built on every call, so no session state carries over
/// between probes and each probe pays a real handshake. Trust anchors are
/// the bundled `webpki_roots` set, with no client authentication.
///
/// ALPN offers only `h2`, in line with the project's "no HTTP/1.1" stance.
/// rustls does not by itself fail a TLS-over-TCP handshake with a server
/// that ignores ALPN. Getting a real h2 peer therefore depends on the
/// negotiated protocol being inspected after the handshake.
fn tls_config_h2() -> Arc<rustls::ClientConfig> {
    let trust_anchors = {
        let mut store = rustls::RootCertStore::empty();
        store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        store
    };

    let mut config = rustls::ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();
    config.alpn_protocols = vec![b"h2".to_vec()];

    Arc::new(config)
}

/// Convert a phase duration to whole milliseconds, never reporting a phase
/// that actually ran as 0 ms.
///
/// Traffic to the host's own public IP is routed over `lo` on Linux, so TCP
/// and handshake phases can take only 200-500 µs, which `as_millis()`
/// truncates to 0. The chart would read that as "phase skipped" rather than
/// "phase was instant". Any non-zero duration below 1 ms is therefore
/// reported as 1. An exactly-zero duration is still reported as 0. Totals
/// don't go through this helper because a real probe always takes well over
/// 1 ms.
fn elapsed_ms_atleast1(d: std::time::Duration) -> i64 {
    let whole_ms = d.as_millis() as i64;
    if whole_ms <= 0 && !d.is_zero() {
        1
    } else {
        whole_ms
    }
}

/// Heuristically decide whether a probe error came from the TLS layer.
///
/// The error is rendered with `{:#}`, anyhow's alternate Display, which
/// prints every context layer joined by `: `. The lowercased text is then
/// searched for TLS-related keywords. Any error wrapped in a context that
/// mentions TLS or a handshake (e.g. the `tls handshake` context on the TLS
/// connect) matches, as do certificate-validation errors.
///
/// `{:?}` is deliberately avoided. anyhow's Debug output appends the captured
/// backtrace when `RUST_BACKTRACE` / `RUST_LIB_BACKTRACE` is set. That forces
/// expensive symbol resolution on every failed probe and lets stack-frame
/// names leak into the keyword match.
fn looks_like_ssl_error(e: &anyhow::Error) -> bool {
    // "certificate" also covers rustls' `InvalidCertificate` wording.
    const TLS_MARKERS: [&str; 3] = ["certificate", "tls", "handshake"];
    let chain = format!("{e:#}").to_lowercase();
    TLS_MARKERS.iter().any(|marker| chain.contains(marker))
}

/// Run a single HTTP check and persist the result. Maps SSL errors to 526
/// (Cloudflare convention) and timeouts to 408 so the dashboard can show
/// failure reasons without piping arbitrary error messages.
pub async fn run_check(pool: &SqlitePool, prop: &PropertyRow) -> sqlx::Result<i64> {
    let outcome = match probe_with_redirects(&prop.url).await {
        Ok(o) => o,
        Err(e) => {
            let code = if looks_like_ssl_error(&e) { 526 } else { 408 };
            ProbeOutcome {
                status_code: code,
                headers_json: "{}".to_string(),
                timings: PhaseTimings {
                    total_ms: (HTTP_TIMEOUT_SECS as i64) * 1000,
                    ..PhaseTimings::default()
                },
            }
        }
    };

    let id = prop.id.clone();
    sqlx::query(
        "INSERT INTO checks (property_id, status_code, response_ms, headers, dns_ms, tcp_ms, tls_ms, ttfb_ms, created_at) \
         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
    )
    .bind(&id)
    .bind(outcome.status_code)
    .bind(outcome.timings.total_ms)
    .bind(&outcome.headers_json)
    .bind(outcome.timings.dns_ms)
    .bind(outcome.timings.tcp_ms)
    .bind(outcome.timings.tls_ms)
    .bind(outcome.timings.ttfb_ms)
    .bind(now_ms())
    .execute(pool)
    .await?;

    Ok(outcome.status_code)
}

/// First-hop with full phase timings, then follow up to MAX_REDIRECTS
/// 3xx hops to discover the final status code (so the alert state machine
/// keeps working when a property uses an http→https or apex→www
/// redirect). Phase timings always reflect the first hop only — that's
/// the latency a fresh visitor pays before being redirected, and it's the
/// only number that's meaningful when later hops live on different
/// servers/domains.
async fn probe_with_redirects(url_str: &str) -> anyhow::Result<ProbeOutcome> {
    let url = Url::parse(url_str).context("invalid URL")?;

    let outer = tokio::time::timeout(
        Duration::from_secs(HTTP_TIMEOUT_SECS),
        async {
            let mut current = url.clone();
            let first = phased_hop(&current).await?;
            let first_timings = first.timings;
            let mut status = first.status_code;
            let mut headers_json = first.raw_headers_json;
            let mut headers = first.headers;
            let mut hops = 0usize;

            while is_redirect(status) && hops < MAX_REDIRECTS {
                let Some(loc) = headers.get("location").cloned() else { break };
                let Ok(next) = current.join(&loc) else { break };
                current = next;
                let hop = match phased_hop(&current).await {
                    Ok(h) => h,
                    Err(_) => break,
                };
                status = hop.status_code;
                headers_json = hop.raw_headers_json;
                headers = hop.headers;
                hops += 1;
            }

            Ok::<_, anyhow::Error>(ProbeOutcome {
                status_code: status,
                headers_json,
                timings: first_timings,
            })
        },
    )
    .await;

    match outer {
        Ok(Ok(o)) => Ok(o),
        Ok(Err(e)) => Err(e),
        Err(_) => Err(anyhow!("timeout after {HTTP_TIMEOUT_SECS}s")),
    }
}

/// True for the 3xx codes whose `Location` header the probe follows:
/// 301, 302, 303, 307 and 308. 300 and 304 are not included.
fn is_redirect(code: i64) -> bool {
    const FOLLOWED: [i64; 5] = [301, 302, 303, 307, 308];
    FOLLOWED.contains(&code)
}

/// Make one HTTPS request over a brand-new connection, timing each phase.
///
/// The phases are:
/// - DNS, which here includes building a resolver from the system config;
/// - TCP connect to the first resolved address;
/// - the TLS handshake;
/// - the h2 exchange up to the response HEADERS (TTFB).
///
/// Each phase that runs is floored at 1 ms. `total_ms` spans all phases and
/// is not floored.
///
/// # Errors
/// - Non-`https` URLs, rejected before any network I/O.
/// - A missing host or port, a DNS failure or empty answer, a TCP connect
///   failure, an invalid TLS server name, or a failed TLS handshake.
/// - A server that completes TLS without selecting `h2` via ALPN.
/// - Any error from the HTTP/2 exchange.
async fn phased_hop(url: &Url) -> anyhow::Result<HopResult> {
    if url.scheme() != "https" {
        // h2 over plain TCP (h2c with prior knowledge) is rare in the wild,
        // and the project is HTTP/2-only, so reject http:// URLs explicitly
        // rather than silently downgrading. This is checked before DNS/TCP
        // because the outcome can't change. Otherwise a closed or filtered
        // port 80 would turn this clear error into a connect failure or a
        // full timeout.
        return Err(anyhow!("plain HTTP not supported; use https:// (HTTP/2 only)"));
    }

    let host = url.host_str().context("URL missing host")?.to_string();
    let port = url.port_or_known_default().context("URL missing port")?;
    // `Url::port()` is `None` when the port is the scheme default. The
    // request authority therefore carries a port only when the URL spelled
    // out a non-default one, so virtual hosting on e.g. :8443 sees the right
    // authority.
    let authority = match url.port() {
        Some(explicit_port) => format!("{host}:{explicit_port}"),
        None => host.clone(),
    };
    let path_q = match url.query() {
        Some(q) if !q.is_empty() => format!("{}?{}", url.path(), q),
        _ => url.path().to_string(),
    };
    let path_q = if path_q.is_empty() { "/".to_string() } else { path_q };

    let total_start = Instant::now();

    let dns_start = Instant::now();
    let resolver =
        TokioAsyncResolver::tokio_from_system_conf().context("creating dns resolver")?;
    let lookup = resolver.lookup_ip(host.as_str()).await.context("dns lookup")?;
    let ip = lookup
        .iter()
        .next()
        .ok_or_else(|| anyhow!("no addresses for {host}"))?;
    let dns_ms = elapsed_ms_atleast1(dns_start.elapsed());
    let addr = SocketAddr::new(ip, port);

    let tcp_start = Instant::now();
    let tcp = TcpStream::connect(addr).await.context("tcp connect")?;
    tcp.set_nodelay(true).ok();
    let tcp_ms = elapsed_ms_atleast1(tcp_start.elapsed());

    let tls_start = Instant::now();
    let server_name =
        ServerName::try_from(host.clone()).context("invalid TLS server name")?;
    let connector = TlsConnector::from(tls_config_h2());
    let tls_stream = connector
        .connect(server_name, tcp)
        .await
        .context("tls handshake")?;
    let tls_ms = elapsed_ms_atleast1(tls_start.elapsed());

    // Offering only `h2` via ALPN does not make rustls fail a TLS-over-TCP
    // handshake with a server that ignores the extension. Without this
    // check, an HTTP/1.1-only server would complete the handshake and then
    // fail inside the h2 exchange with an opaque protocol error. The message
    // mentions TLS so the failure is classified as an SSL error, not a
    // timeout.
    let (_, session) = tls_stream.get_ref();
    if session.alpn_protocol() != Some(b"h2".as_slice()) {
        return Err(anyhow!("TLS ALPN: server did not negotiate h2"));
    }

    let (status, headers, raw_headers_json, ttfb_ms) =
        h2_request(tls_stream, &authority, &path_q).await?;

    let total_ms = total_start.elapsed().as_millis() as i64;

    Ok(HopResult {
        status_code: status,
        headers,
        raw_headers_json,
        timings: PhaseTimings {
            dns_ms: Some(dns_ms),
            tcp_ms: Some(tcp_ms),
            tls_ms: Some(tls_ms),
            ttfb_ms: Some(ttfb_ms),
            total_ms,
        },
    })
}

/// Run an HTTP/2 GET over an established TLS stream.
///
/// Returns `(status_code, headers, headers_json, ttfb_ms)`. `host` becomes
/// the authority of the request URI (`https://{host}{path}`), so it may
/// include a `:port`. Header names are lowercased, values that are not
/// valid visible-ASCII strings are skipped, and a repeated header keeps its
/// last value.
///
/// TTFB is measured from the start of the h2 client handshake (SETTINGS
/// exchange) to the arrival of the response HEADERS frame, so it includes
/// h2 protocol setup. The user-facing chart treats it as "everything between
/// secure connection ready and first server byte", which matches curl's
/// `time_starttransfer` minus `time_appconnect`.
///
/// # Errors
/// Any failure of the h2 handshake, request construction, or response.
async fn h2_request(
    tls_stream: tokio_rustls::client::TlsStream<TcpStream>,
    host: &str,
    path: &str,
) -> anyhow::Result<(i64, BTreeMap<String, String>, String, i64)> {
    /// Aborts the wrapped task when dropped. Dropping a bare `JoinHandle`
    /// only detaches the task, so without this guard a `?` early return, or
    /// the caller's future being dropped by the probe timeout, would leave
    /// the connection driver running.
    struct AbortOnDrop(tokio::task::JoinHandle<()>);

    impl Drop for AbortOnDrop {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    let ttfb_start = Instant::now();
    let (sr, connection) = h2::client::handshake(tls_stream)
        .await
        .context("h2 handshake")?;
    // h2 needs someone to drive the connection's I/O loop. The task lives
    // exactly as long as this probe: the guard aborts it on every exit path.
    let driver = AbortOnDrop(tokio::spawn(async move {
        let _ = connection.await;
    }));

    let mut sr = sr.ready().await.context("h2 send-request ready")?;
    let req = http::Request::builder()
        .method("GET")
        .uri(format!("https://{host}{path}"))
        .header("user-agent", USER_AGENT)
        .header("accept", "*/*")
        .body(())
        .context("h2 request build")?;
    let (rsp_fut, _send_stream) = sr.send_request(req, true).context("h2 send_request")?;
    let rsp = rsp_fut.await.context("h2 response")?;
    let ttfb_ms = elapsed_ms_atleast1(ttfb_start.elapsed());

    let status = rsp.status().as_u16() as i64;
    let mut headers = BTreeMap::new();
    for (k, v) in rsp.headers().iter() {
        if let Ok(s) = v.to_str() {
            headers.insert(k.as_str().to_lowercase(), s.to_string());
        }
    }
    let raw_headers_json = serde_json::Value::Object(
        headers.iter().map(|(k, v)| (k.clone(), json!(v))).collect(),
    )
    .to_string();

    drop(sr);
    drop(driver);
    Ok((status, headers, raw_headers_json, ttfb_ms))
}

/// Check one property end to end. It probes the property, stores the result,
/// then feeds the stored status code into the alert state machine, which
/// may dispatch notifications on a transition.
///
/// # Errors
/// Propagates database errors from storing the check or from updating the
/// alert state. Probe failures are not errors here: `run_check` records them
/// as synthetic status codes.
pub async fn process_check(
    pool: &SqlitePool,
    config: &crate::Config,
    prop: &PropertyRow,
) -> anyhow::Result<()> {
    let recorded_status = run_check(pool, prop).await?;
    advance_alert_state(pool, config, prop, recorded_status).await
}

/// Advance a property's alert state after a check and fire notifications on
/// transitions.
///
/// State transitions:
///   UP -> DOWN: requires 2 consecutive non-200 checks (avoids false positives)
///   DOWN -> UP: immediate on 200
/// Any other combination, including an `alert_state` value other than
/// "up"/"down", leaves the state untouched. An id with no `properties` row
/// is ignored.
///
/// The new state is committed inside a transaction *before* notifications
/// are fired, so a crash mid-alert can't cause duplicate sends.
/// Notifications run on a detached task, and dispatch failures are only
/// logged.
///
/// # Errors
/// Database errors from reading or updating state, or from committing.
async fn advance_alert_state(
    pool: &SqlitePool,
    config: &crate::Config,
    prop: &PropertyRow,
    status_code: i64,
) -> anyhow::Result<()> {
    let is_up = status_code == 200;
    let mut tx = pool.begin().await?;

    let row: Option<(String,)> =
        sqlx::query_as("SELECT alert_state FROM properties WHERE id = ?")
            .bind(&prop.id)
            .fetch_optional(&mut *tx)
            .await?;
    let Some((current_state,)) = row else {
        return Ok(());
    };

    // (alert kind handed to `alerts::fire`, new `alert_state` value)
    let transition: Option<(&'static str, &'static str)> =
        match (is_up, current_state.as_str()) {
            (true, "down") => Some(("recovery", "up")),
            (false, "up") => {
                // Need 2 consecutive non-200s. This check was already
                // inserted, so the newest two rows are this one and the
                // previous one.
                let recent: Vec<(i64,)> = sqlx::query_as(
                    "SELECT status_code FROM checks WHERE property_id = ? ORDER BY created_at DESC LIMIT 2",
                )
                .bind(&prop.id)
                .fetch_all(&mut *tx)
                .await?;
                let both_failed =
                    recent.len() >= 2 && recent.iter().take(2).all(|&(code,)| code != 200);
                both_failed.then_some(("down", "down"))
            }
            _ => None,
        };

    if let Some((_, new_state)) = transition {
        // One timestamp for both columns so they always agree exactly.
        let changed_at = now_ms();
        sqlx::query(
            "UPDATE properties SET alert_state = ?, last_alert_sent = ?, updated_at = ? WHERE id = ?",
        )
        .bind(new_state)
        .bind(changed_at)
        .bind(changed_at)
        .bind(&prop.id)
        .execute(&mut *tx)
        .await?;
    }
    tx.commit().await?;

    let Some((kind, _)) = transition else {
        return Ok(());
    };

    let avg_response_time = recent_avg_response_ms(pool, &prop.id).await.unwrap_or(0);
    let ctx = alerts::EmailContext {
        id: prop.uuid(),
        name: prop.name(),
        url: prop.url.clone(),
        current_status: status_code,
        avg_response_time,
    };
    let alert_email = config.alert_email.clone();
    let webhook = config.discord_webhook_url.clone();
    let base_url = config.base_url.clone();
    let url_for_log = prop.url.clone();
    // Fire-and-forget: alerts don't block the scheduler tick.
    tokio::spawn(async move {
        if let Err(e) =
            alerts::fire(kind, &ctx, &base_url, alert_email.as_deref(), webhook.as_deref()).await
        {
            tracing::warn!("alert dispatch failed for {url_for_log}: {e}");
        }
    });
    Ok(())
}

/// Mean `response_ms` over the 31 most recent checks of property `id`,
/// using integer division. Returns 0 when the property has no checks. It is
/// meant to mirror the dashboard's "rolling avg" tile so the alert email
/// matches what the user sees.
async fn recent_avg_response_ms(pool: &SqlitePool, id: &[u8]) -> sqlx::Result<i64> {
    let samples: Vec<(i64,)> = sqlx::query_as(
        "SELECT response_ms FROM checks WHERE property_id = ? ORDER BY created_at DESC LIMIT 31",
    )
    .bind(id.to_vec())
    .fetch_all(pool)
    .await?;

    let count = samples.len() as i64;
    if count == 0 {
        return Ok(0);
    }
    let total = samples.into_iter().fold(0i64, |acc, (ms,)| acc + ms);
    Ok(total / count)
}

/// Epoch-millisecond timestamp of the next 3-minute UTC boundary (minute
/// :00, :03, :06, …), matching Django.
///
/// The boundary is always strictly after now. A call landing exactly on a
/// boundary returns the following one.
pub fn next_3min_boundary() -> i64 {
    const STEP_MINUTES: u32 = 3;
    let now = chrono::Utc::now();
    let floor_minute = now.minute() - now.minute() % STEP_MINUTES;
    let window_start = now
        .with_minute(floor_minute)
        .and_then(|t| t.with_second(0))
        .and_then(|t| t.with_nanosecond(0))
        .unwrap_or(now);
    let next = window_start + chrono::Duration::minutes(i64::from(STEP_MINUTES));
    next.timestamp_millis()
}

/// Decode a property id blob into a [`Uuid`]. Falls back to the nil UUID
/// when the blob is not a valid 16-byte UUID.
///
/// It is declared `async` although it does no asynchronous work; the
/// signature is kept as-is for compatibility with existing callers.
// NOTE(review): `dead_code` is allowed presumably because nothing in the
// crate calls this yet; confirm before removing.
#[allow(dead_code)]
pub async fn property_id_to_uuid(id_blob: &[u8]) -> Uuid {
    match Uuid::from_slice(id_blob) {
        Ok(uuid) => uuid,
        Err(_) => Uuid::nil(),
    }
}