//! HTTP/2 uptime prober: per-phase request timing (DNS/TCP/TLS/TTFB),
//! redirect following, persistence to the `checks` table, and the
//! UP/DOWN alert state machine.
use crate::alerts;
use crate::db::now_ms;
use crate::models::PropertyRow;
use anyhow::{anyhow, Context};
use chrono::Timelike;
use hickory_resolver::TokioAsyncResolver;
use rustls::pki_types::ServerName;
use serde_json::json;
use sqlx::SqlitePool;
use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use url::Url;
use uuid::Uuid;
// Desktop-Chrome UA with a trailing "Status/2.0.0" product token so server
// logs can tell these probes apart from real browsers.
const USER_AGENT: &str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 \
(KHTML, like Gecko) Chrome/102.0.5005.115 Safari/537.36 Status/2.0.0";
// Wall-clock budget for a whole probe (all redirect hops share it); also the
// value recorded as `total_ms` when a probe errors out entirely.
const HTTP_TIMEOUT_SECS: u64 = 10;
// Cap on 3xx hops followed when discovering the final status code.
const MAX_REDIRECTS: usize = 5;
/// Phase-by-phase timings for a single HTTP probe. `None` means the phase
/// didn't run (the probe errored before reaching it). `total_ms` is
/// wall-clock first-hop end-to-end and is also written to
/// `checks.response_ms` for backward compat with the alert email avg.
/// Fields are `Option<i64>` so that pre-rewrite rows (migration 0002 left
/// them NULL) keep deserializing.
#[derive(Debug, Default, Clone, Copy)]
pub struct PhaseTimings {
    /// DNS resolution time (system resolver; first returned address used).
    pub dns_ms: Option<i64>,
    /// TCP connect time to the resolved address.
    pub tcp_ms: Option<i64>,
    /// TLS handshake time (ALPN pinned to h2).
    pub tls_ms: Option<i64>,
    /// h2 setup + request send until the response HEADERS frame arrives.
    pub ttfb_ms: Option<i64>,
    /// End-to-end first-hop wall clock; always populated.
    pub total_ms: i64,
}
/// Final result of a probe after redirect following: the status code and
/// headers of the last hop reached, paired with the first hop's timings.
struct ProbeOutcome {
    status_code: i64,
    /// Response headers serialized as a flat JSON object (lowercased keys).
    headers_json: String,
    timings: PhaseTimings,
}
/// One-hop result with everything we need to decide whether to follow a
/// redirect.
struct HopResult {
    status_code: i64,
    /// Lowercased header map; `location` is read from here for redirects.
    headers: BTreeMap<String, String>,
    /// The same headers pre-serialized to JSON for storage.
    raw_headers_json: String,
    timings: PhaseTimings,
}
/// Build a fresh rustls config per probe. Keeps the "fresh client per
/// probe = real handshake cost" invariant. ALPN-pinned to `h2` only:
/// servers that don't speak HTTP/2 will fail the handshake (mapped to
/// 526), which matches the project's "no HTTP/1.1" stance.
fn tls_config_h2() -> Arc<rustls::ClientConfig> {
    // Trust anchors come from the bundled webpki-roots set, not the OS store.
    let mut root_store = rustls::RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let mut config = rustls::ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();
    // Offer only "h2" during ALPN; no HTTP/1.1 fallback.
    config.alpn_protocols = vec![b"h2".to_vec()];
    Arc::new(config)
}
/// Round a duration down to whole milliseconds, flooring any *non-zero*
/// sub-millisecond duration at 1. The Linux kernel routes traffic destined
/// for the host's own public IP via `lo`, so loopback TCP/handshake phases
/// genuinely take 200-500 microseconds, which `as_millis()` truncates to
/// zero. A `0` in the chart reads as "this phase didn't happen" rather
/// than "this phase was instant", so report 1 ms when the phase did run.
/// A true zero duration is still reported as 0.
fn elapsed_ms_atleast1(d: std::time::Duration) -> i64 {
    match d.as_millis() as i64 {
        // Sub-millisecond but real work: floor at 1 so the chart shows it.
        0 if !d.is_zero() => 1,
        ms => ms,
    }
}
/// Heuristic: does this probe error look TLS-related? Matches on the
/// Debug rendering of the whole error chain, lowercased. Used to decide
/// between recording 526 (SSL) and 408 (timeout/other).
fn looks_like_ssl_error(e: &anyhow::Error) -> bool {
    let rendered = format!("{e:?}").to_lowercase();
    // "invalidcertificate" is kept even though "certificate" subsumes it,
    // to preserve the original match set exactly.
    ["certificate", "invalidcertificate", "tls", "handshake"]
        .iter()
        .any(|needle| rendered.contains(needle))
}
/// Run a single HTTP check and persist the result. Maps SSL errors to 526
/// (Cloudflare convention) and timeouts to 408 so the dashboard can show
/// failure reasons without piping arbitrary error messages.
pub async fn run_check(pool: &SqlitePool, prop: &PropertyRow) -> sqlx::Result<i64> {
    let outcome = probe_with_redirects(&prop.url).await.unwrap_or_else(|err| {
        // Probe never completed: synthesize a failure row with empty
        // headers, the full timeout as total, and no per-phase timings.
        let synthetic_code = if looks_like_ssl_error(&err) { 526 } else { 408 };
        ProbeOutcome {
            status_code: synthetic_code,
            headers_json: String::from("{}"),
            timings: PhaseTimings {
                total_ms: (HTTP_TIMEOUT_SECS as i64) * 1000,
                ..PhaseTimings::default()
            },
        }
    });
    let property_id = prop.id.clone();
    sqlx::query(
        "INSERT INTO checks (property_id, status_code, response_ms, headers, dns_ms, tcp_ms, tls_ms, ttfb_ms, created_at) \
         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
    )
    .bind(&property_id)
    .bind(outcome.status_code)
    .bind(outcome.timings.total_ms)
    .bind(&outcome.headers_json)
    .bind(outcome.timings.dns_ms)
    .bind(outcome.timings.tcp_ms)
    .bind(outcome.timings.tls_ms)
    .bind(outcome.timings.ttfb_ms)
    .bind(now_ms())
    .execute(pool)
    .await?;
    Ok(outcome.status_code)
}
/// First-hop with full phase timings, then follow up to MAX_REDIRECTS
/// 3xx hops to discover the final status code (so the alert state machine
/// keeps working when a property uses an http→https or apex→www
/// redirect). Phase timings always reflect the first hop only — that's
/// the latency a fresh visitor pays before being redirected, and it's the
/// only number that's meaningful when later hops live on different
/// servers/domains.
async fn probe_with_redirects(url_str: &str) -> anyhow::Result<ProbeOutcome> {
    let url = Url::parse(url_str).context("invalid URL")?;
    // One outer deadline covers the whole redirect chain (not per hop),
    // so a slow chain can't multiply the budget by MAX_REDIRECTS.
    let outer = tokio::time::timeout(
        Duration::from_secs(HTTP_TIMEOUT_SECS),
        async {
            let mut current = url;
            // Fixed: these hop calls previously read the mangled token
            // `¤t` (an HTML-entity corruption of `&current`), which
            // did not compile.
            let first = phased_hop(&current).await?;
            // Timings are frozen at the first hop; later hops only update
            // the status/headers we report.
            let first_timings = first.timings;
            let mut status = first.status_code;
            let mut headers_json = first.raw_headers_json;
            let mut headers = first.headers;
            let mut hops = 0usize;
            while is_redirect(status) && hops < MAX_REDIRECTS {
                // A missing/invalid Location or a failing hop ends the
                // chain gracefully: report the last status actually seen.
                let Some(loc) = headers.get("location").cloned() else { break };
                let Ok(next) = current.join(&loc) else { break };
                current = next;
                let hop = match phased_hop(&current).await {
                    Ok(h) => h,
                    Err(_) => break,
                };
                status = hop.status_code;
                headers_json = hop.raw_headers_json;
                headers = hop.headers;
                hops += 1;
            }
            Ok::<_, anyhow::Error>(ProbeOutcome {
                status_code: status,
                headers_json,
                timings: first_timings,
            })
        },
    )
    .await;
    match outer {
        Ok(Ok(o)) => Ok(o),
        Ok(Err(e)) => Err(e),
        // Outer Err means the deadline elapsed before the chain finished.
        Err(_) => Err(anyhow!("timeout after {HTTP_TIMEOUT_SECS}s")),
    }
}
/// True for the 3xx codes that carry a `Location` we should follow.
/// Deliberately excludes 300 (Multiple Choices) and 304 (Not Modified).
fn is_redirect(code: i64) -> bool {
    [301, 302, 303, 307, 308].contains(&code)
}
/// Resolve, connect, TLS-handshake and issue one HTTP/2 GET for a single
/// URL, timing each phase with its own `Instant`. Phases run strictly in
/// sequence, so each `_ms` value measures only its own step.
async fn phased_hop(url: &Url) -> anyhow::Result<HopResult> {
    let host = url.host_str().context("URL missing host")?.to_string();
    let port = url.port_or_known_default().context("URL missing port")?;
    // `Url::path()` excludes the query string, so re-attach it by hand.
    let path_q = match url.query() {
        Some(q) if !q.is_empty() => format!("{}?{}", url.path(), q),
        _ => url.path().to_string(),
    };
    let path_q = if path_q.is_empty() { "/".to_string() } else { path_q };
    let is_https = url.scheme() == "https";
    let total_start = Instant::now();
    // DNS phase: fresh resolver per hop from system config; the first
    // returned address wins (no happy-eyeballs fallback).
    let dns_start = Instant::now();
    let resolver =
        TokioAsyncResolver::tokio_from_system_conf().context("creating dns resolver")?;
    let lookup = resolver.lookup_ip(host.as_str()).await.context("dns lookup")?;
    let ip = lookup
        .iter()
        .next()
        .ok_or_else(|| anyhow!("no addresses for {host}"))?;
    let dns_ms = elapsed_ms_atleast1(dns_start.elapsed());
    let addr = SocketAddr::new(ip, port);
    // TCP phase. NODELAY is best-effort; failure to set it is ignored.
    let tcp_start = Instant::now();
    let tcp = TcpStream::connect(addr).await.context("tcp connect")?;
    tcp.set_nodelay(true).ok();
    let tcp_ms = elapsed_ms_atleast1(tcp_start.elapsed());
    if !is_https {
        // h2 over plain TCP (h2c with prior knowledge) is rare in the
        // wild, and the project is HTTP/2-only, so reject http:// URLs
        // explicitly rather than silently downgrading.
        return Err(anyhow!("plain HTTP not supported; use https:// (HTTP/2 only)"));
    }
    // TLS phase: SNI taken from the URL host, ALPN pinned to h2 via
    // tls_config_h2().
    let tls_start = Instant::now();
    let server_name =
        ServerName::try_from(host.clone()).context("invalid TLS server name")?;
    let connector = TlsConnector::from(tls_config_h2());
    let tls_stream = connector
        .connect(server_name, tcp)
        .await
        .context("tls handshake")?;
    let tls_ms = elapsed_ms_atleast1(tls_start.elapsed());
    let (status, headers, raw_headers_json, ttfb_ms) =
        h2_request(tls_stream, &host, &path_q).await?;
    // Total is raw truncation (no 1 ms floor) — any real probe is well
    // over 1 ms end-to-end; see elapsed_ms_atleast1.
    let total_ms = total_start.elapsed().as_millis() as i64;
    Ok(HopResult {
        status_code: status,
        headers,
        raw_headers_json,
        timings: PhaseTimings {
            dns_ms: Some(dns_ms),
            tcp_ms: Some(tcp_ms),
            tls_ms: Some(tls_ms),
            ttfb_ms: Some(ttfb_ms),
            total_ms,
        },
    })
}
/// Run an HTTP/2 GET over an established TLS stream and return
/// (status_code, headers, headers_json, ttfb_ms). TTFB is measured from
/// the start of the h2 client handshake (SETTINGS exchange) to the
/// arrival of the response HEADERS frame, so it includes h2 protocol
/// setup; the user-facing chart treats it as "everything between secure
/// connection ready and first server byte", which matches curl's
/// `time_starttransfer` minus `time_appconnect`.
async fn h2_request(
    tls_stream: tokio_rustls::client::TlsStream<TcpStream>,
    host: &str,
    path: &str,
) -> anyhow::Result<(i64, BTreeMap<String, String>, String, i64)> {
    let ttfb_start = Instant::now();
    let (sr, connection) = h2::client::handshake(tls_stream)
        .await
        .context("h2 handshake")?;
    // h2 needs someone to drive the connection's I/O loop. Spawn a task
    // that lives just as long as this probe; we abort it on the way out.
    let conn_task = tokio::spawn(async move {
        let _ = connection.await;
    });
    let mut sr = sr.ready().await.context("h2 send-request ready")?;
    let req = http::Request::builder()
        .method("GET")
        .uri(format!("https://{host}{path}"))
        .header("user-agent", USER_AGENT)
        .header("accept", "*/*")
        .body(())
        .context("h2 request build")?;
    // `true` = end-of-stream: a GET carries no body, so close our half
    // of the stream immediately.
    let (rsp_fut, _send_stream) = sr.send_request(req, true).context("h2 send_request")?;
    let rsp = rsp_fut.await.context("h2 response")?;
    // HEADERS frame just arrived — this is "first server byte".
    let ttfb_ms = elapsed_ms_atleast1(ttfb_start.elapsed());
    let status = rsp.status().as_u16() as i64;
    let mut headers = BTreeMap::new();
    for (k, v) in rsp.headers().iter() {
        // Header values that aren't valid UTF-8 are skipped rather than
        // lossily decoded.
        if let Ok(s) = v.to_str() {
            headers.insert(k.as_str().to_lowercase(), s.to_string());
        }
    }
    let raw_headers_json = serde_json::Value::Object(
        headers.iter().map(|(k, v)| (k.clone(), json!(v))).collect(),
    )
    .to_string();
    // The response body is never read: drop the request handle and abort
    // the connection driver so the probe holds no resources past here.
    drop(sr);
    conn_task.abort();
    Ok((status, headers, raw_headers_json, ttfb_ms))
}
/// Run a check, persist it, then advance the alert state machine and fire
/// notifications on transitions.
pub async fn process_check(
    pool: &SqlitePool,
    config: &crate::Config,
    prop: &PropertyRow,
) -> anyhow::Result<()> {
    let final_status = run_check(pool, prop).await?;
    advance_alert_state(pool, config, prop, final_status).await
}
/// State transitions:
/// UP -> DOWN: requires 2 consecutive non-200 checks (avoids false positives)
/// DOWN -> UP: immediate on 200
/// Commits state inside a transaction *before* firing notifications so a
/// crash mid-alert can't cause duplicate sends.
async fn advance_alert_state(
    pool: &SqlitePool,
    config: &crate::Config,
    prop: &PropertyRow,
    status_code: i64,
) -> anyhow::Result<()> {
    // Only an exact 200 counts as "up"; redirect chains were already
    // resolved to their final status before this point.
    let is_up = status_code == 200;
    let mut tx = pool.begin().await?;
    // Re-read the state inside the transaction rather than trusting
    // `prop`, which may be stale by the time this check finished.
    let row: Option<(String,)> =
        sqlx::query_as("SELECT alert_state FROM properties WHERE id = ?")
            .bind(prop.id.clone())
            .fetch_optional(&mut *tx)
            .await?;
    let Some((current_state,)) = row else {
        // Property was deleted mid-check; nothing to do.
        return Ok(());
    };
    let mut transition: Option<&str> = None;
    if is_up && current_state == "down" {
        transition = Some("recovery");
    } else if !is_up && current_state == "up" {
        // Need 2 consecutive non-200s. We just inserted one; check the prior.
        let recent: Vec<(i64,)> = sqlx::query_as(
            "SELECT status_code FROM checks WHERE property_id = ? ORDER BY created_at DESC LIMIT 2",
        )
        .bind(prop.id.clone())
        .fetch_all(&mut *tx)
        .await?;
        if recent.len() >= 2 && recent[0].0 != 200 && recent[1].0 != 200 {
            transition = Some("down");
        }
    }
    if let Some(kind) = transition {
        let new_state = if kind == "recovery" { "up" } else { "down" };
        sqlx::query(
            "UPDATE properties SET alert_state = ?, last_alert_sent = ?, updated_at = ? WHERE id = ?",
        )
        .bind(new_state)
        .bind(now_ms())
        .bind(now_ms())
        .bind(prop.id.clone())
        .execute(&mut *tx)
        .await?;
    }
    // Commit first: if the process dies after this point the alert may be
    // lost, but it can never be double-sent on restart.
    tx.commit().await?;
    if let Some(kind) = transition {
        // Best-effort rolling average for the notification body; 0 when
        // there's no history or the query fails.
        let avg_response_time = recent_avg_response_ms(pool, &prop.id).await.unwrap_or(0);
        let ctx = alerts::EmailContext {
            id: prop.uuid(),
            name: prop.name(),
            url: prop.url.clone(),
            current_status: status_code,
            avg_response_time,
        };
        // Clone everything the spawned task needs; it outlives this call.
        let alert_email = config.alert_email.clone();
        let webhook = config.discord_webhook_url.clone();
        let base_url = config.base_url.clone();
        let url_for_log = prop.url.clone();
        // Fire-and-forget: alerts don't block the scheduler tick.
        tokio::spawn(async move {
            if let Err(e) =
                alerts::fire(kind, &ctx, &base_url, alert_email.as_deref(), webhook.as_deref()).await
            {
                tracing::warn!("alert dispatch failed for {url_for_log}: {e}");
            }
        });
    }
    Ok(())
}
/// Average response time across the most recent 31 checks. Mirrors the
/// dashboard's "rolling avg" tile so the email matches what the user sees.
/// Returns 0 when the property has no checks yet (integer division
/// otherwise, truncating toward zero).
async fn recent_avg_response_ms(pool: &SqlitePool, id: &[u8]) -> sqlx::Result<i64> {
    let samples: Vec<(i64,)> = sqlx::query_as(
        "SELECT response_ms FROM checks WHERE property_id = ? ORDER BY created_at DESC LIMIT 31",
    )
    .bind(id.to_vec())
    .fetch_all(pool)
    .await?;
    match samples.len() as i64 {
        0 => Ok(0),
        count => Ok(samples.iter().map(|row| row.0).sum::<i64>() / count),
    }
}
/// Compute the next due time aligned to a 3-minute boundary, matching Django.
///
/// Floors "now" to the current 3-minute mark (seconds and nanos zeroed),
/// then steps one interval forward; returned as epoch milliseconds.
pub fn next_3min_boundary() -> i64 {
    let now = chrono::Utc::now();
    let floored = now
        .with_minute((now.minute() / 3) * 3)
        .and_then(|d| d.with_second(0))
        .and_then(|d| d.with_nanosecond(0))
        // `with_*` can't actually fail for these in-range values; fall
        // back to `now` rather than panic if chrono ever disagrees.
        .unwrap_or(now);
    (floored + chrono::Duration::minutes(3)).timestamp_millis()
}
/// Convert a 16-byte property id blob into a `Uuid`, falling back to the
/// nil UUID for malformed input. (`async` is kept purely for signature
/// compatibility with existing `.await` call sites.)
#[allow(dead_code)]
pub async fn property_id_to_uuid(id_blob: &[u8]) -> Uuid {
    match Uuid::from_slice(id_blob) {
        Ok(uuid) => uuid,
        Err(_) => Uuid::nil(),
    }
}