diff --git a/services/api/Cargo.toml b/services/api/Cargo.toml index bab4d7e..2e8953f 100644 --- a/services/api/Cargo.toml +++ b/services/api/Cargo.toml @@ -25,7 +25,7 @@ async-trait = "0.1" reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio-rustls", "postgres", "chrono", "uuid"] } +sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio-rustls", "postgres", "chrono", "uuid", "derive"] } tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal", "time", "sync"] } tokio-util = { version = "0.7", features = ["rt"] } tower = { version = "0.5", features = ["util"] } @@ -47,11 +47,12 @@ hex = "0.4" base64 = "0.22" subtle = "2.5" ipnet = "2" +fastrand = "2.4.1" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } -testcontainers = { version = "0.23", features = ["tokio"] } -testcontainers-modules = { version = "0.11", features = ["redis", "tokio"] } +testcontainers = { version = "0.23" } +testcontainers-modules = { version = "0.11", features = ["redis", "postgres"] } [[bench]] name = "api_key_auth" diff --git a/services/api/src/audit.rs b/services/api/src/audit.rs index e839301..9925efa 100644 --- a/services/api/src/audit.rs +++ b/services/api/src/audit.rs @@ -1,7 +1,6 @@ use std::net::IpAddr; -pub mod client_ip; -pub use client_ip::{extract_client_ip, trusted_cidrs_from_env}; +pub use crate::client_ip::{extract_client_ip, trusted_cidrs_from_env}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; diff --git a/services/api/src/audit_middleware.rs b/services/api/src/audit_middleware.rs index a0bea08..04c17ba 100644 --- a/services/api/src/audit_middleware.rs +++ b/services/api/src/audit_middleware.rs @@ -1,5 +1,4 @@ -pub mod body_redact; -pub use body_redact::{body_logging_enabled, redact_sensitive, truncate_body}; +pub use crate::body_redact::{body_logging_enabled, redact_sensitive, truncate_body}; use std::sync::Arc; diff --git a/services/api/src/blockchain.rs b/services/api/src/blockchain.rs index 0caaffa..d9297d9 100644 --- a/services/api/src/blockchain.rs +++ b/services/api/src/blockchain.rs @@ -793,7 +793,7 @@ impl BlockchainClient { total_events = all_events.len(), "fetch_events_since paginated" ); - self.metrics.observe_invalidation("events_pagination_pages", pages); + self.metrics.observe_invalidation("events_pagination_pages", pages as usize); } Ok(all_events) diff --git a/services/api/src/cache/mod.rs b/services/api/src/cache/mod.rs index 2600d86..abc0730 100644 --- a/services/api/src/cache/mod.rs +++ b/services/api/src/cache/mod.rs @@ -7,11 +7,8 @@ use std::{ time::{Duration, Instant}, }; -use redis::redis_module::RedisResult; - - use anyhow::Context; -use deadpool_redis::{Config as PoolConfig, Pool, Runtime}; +use deadpool_redis::{Config as PoolConfig, Pool}; use redis::AsyncCommands; use serde::{de::DeserializeOwned, Serialize}; @@ -245,8 +242,9 @@ impl RedisCache { } // Deterministically hash the tag so the metadata key is stable. + use std::hash::{Hash, Hasher}; let mut hasher = std::collections::hash_map::DefaultHasher::new(); - std::hash::Hash::hash(&tag, &mut hasher); + tag.cache_keys().join("|").hash(&mut hasher); let tag_hash = format!("{:x}", hasher.finish()); let zset_key = self.tag_cfg.tag_key(&tag_hash); @@ -293,18 +291,19 @@ impl RedisCache { "#, ); - let mut over_evicted: i64 = 0; + let script = std::sync::Arc::new(script); self.exec(|mut conn| { let zset_key = zset_key.clone(); let seq_key = seq_key.clone(); let keys = tag_keys.clone(); + let script = script.clone(); async move { let mut argv: Vec = Vec::with_capacity(2 + keys.len()); argv.push(tag_ttl_secs.to_string()); argv.push(cap.to_string()); argv.extend(keys); - over_evicted = script + let _: i64 = script .key(&zset_key) .key(&seq_key) .arg(tag_ttl_secs) @@ -316,7 +315,6 @@ impl RedisCache { }) .await?; - // Note: we don't need the evicted count for correctness. Ok(()) } @@ -399,11 +397,14 @@ impl RedisCache { T: DeserializeOwned, { let key = key.to_owned(); - self.exec(|mut conn| async move { - let val: Option = conn.get(&key).await?; - match val { - Some(raw) => Ok(Some(serde_json::from_str(&raw)?)), - None => Ok(None), + self.exec(|mut conn| { + let key = key.clone(); + async move { + let val: Option = conn.get(&key).await?; + match val { + Some(raw) => Ok(Some(serde_json::from_str(&raw)?)), + None => Ok(None), + } } }) .await @@ -429,9 +430,12 @@ impl RedisCache { pub async fn del(&self, key: &str) -> anyhow::Result<()> { let key = key.to_owned(); - self.exec(|mut conn| async move { - let _: usize = conn.del(&key).await?; - Ok(()) + self.exec(|mut conn| { + let key = key.clone(); + async move { + let _: usize = conn.del(&key).await?; + Ok(()) + } }) .await } @@ -460,23 +464,25 @@ impl RedisCache { let pattern = pattern.to_owned(); loop { - let pattern_clone = pattern.clone(); let (next_cursor, batch_deleted) = self - .exec(|mut conn| async move { - let (next_cursor, keys): (u64, Vec) = redis::cmd("SCAN") - .arg(cursor) - .arg("MATCH") - .arg(&pattern_clone) - .arg("COUNT") - .arg(100u64) - .query_async(&mut conn) - .await?; - let deleted = if keys.is_empty() { - 0 - } else { - conn.del(keys).await? - }; - Ok((next_cursor, deleted)) + .exec(|mut conn| { + let pattern_clone = pattern.clone(); + async move { + let (next_cursor, keys): (u64, Vec) = redis::cmd("SCAN") + .arg(cursor) + .arg("MATCH") + .arg(&pattern_clone) + .arg("COUNT") + .arg(100u64) + .query_async(&mut conn) + .await?; + let deleted = if keys.is_empty() { + 0 + } else { + conn.del(keys).await? + }; + Ok((next_cursor, deleted)) + } }) .await?; diff --git a/services/api/src/compression.rs b/services/api/src/compression.rs index da25788..0379118 100644 --- a/services/api/src/compression.rs +++ b/services/api/src/compression.rs @@ -1,48 +1,25 @@ -use tower_http::compression::predicate::{NotForContentType, Predicate}; +use axum::http::{header, Extensions, HeaderMap, StatusCode, Version}; use tower_http::compression::CompressionLayer; -fn should_compress_text_based(content_type: Option<&str>) -> bool { - let Some(ct) = content_type else { - // If we can't determine content type, avoid wasting CPU. - return false; - }; +type CompressFn = fn(StatusCode, Version, &HeaderMap, &Extensions) -> bool; - // Remove common parameters like `charset=utf-8`. +fn should_compress( + _: StatusCode, + _: Version, + headers: &HeaderMap, + _: &Extensions, +) -> bool { + let ct = headers + .get(header::CONTENT_TYPE) + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); let ct = ct.split(';').next().unwrap_or(ct).trim(); - - // Only compress text-ish payloads. - // Note: application/json is explicitly included. ct == "application/json" || ct.starts_with("text/") } -pub fn compression_layer() -> CompressionLayer { - // Exclude already-compressed/binary formats to avoid CPU waste. - // (This primarily protects against cases where `content_type` might be - // missing/incorrect while still keeping the middleware safe.) - let not_for_binary = NotForContentType::new(vec![ - "application/zip", - "application/gzip", - "application/x-gzip", - "application/x-zip-compressed", - "application/pdf", - "image/jpeg", - "image/png", - "image/webp", - "image/gif", - "image/svg+xml", - "audio/mpeg", - "audio/mp4", - "video/mp4", - "application/octet-stream", - "application/x-bzip2", - "application/x-7z-compressed", - ]); - +pub fn compression_layer() -> CompressionLayer { CompressionLayer::new() .gzip(true) .br(true) - // Only apply compression to text-based responses. - .compress_when(Predicate::from_fn(should_compress_text_based)) - .filter(not_for_binary) + .compress_when(should_compress as CompressFn) } - diff --git a/services/api/src/config.rs b/services/api/src/config.rs index decf596..0893578 100644 --- a/services/api/src/config.rs +++ b/services/api/src/config.rs @@ -66,7 +66,7 @@ impl CorsConfig { .filter(|s| !s.is_empty()) .collect() }) - .unwrap_or_else(|| { + .unwrap_or_else(|_| { ["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"] .iter() .map(|s| s.to_string()) @@ -80,7 +80,7 @@ impl CorsConfig { .filter(|s| !s.is_empty()) .collect() }) - .unwrap_or_else(|| { + .unwrap_or_else(|_| { ["content-type", "authorization"] .iter() .map(|s| s.to_string()) diff --git a/services/api/src/db.rs b/services/api/src/db.rs index d1e3540..023362e 100644 --- a/services/api/src/db.rs +++ b/services/api/src/db.rs @@ -124,7 +124,7 @@ impl Database { /// Snapshot pool size/idle into Prometheus gauges. /// Call this just before rendering `/metrics` so the values are current. pub fn record_pool_metrics(&self) { - self.metrics.record_pool_metrics(self.pool.size(), self.pool.num_idle()); + self.metrics.observe_pool_connections("primary", self.pool.size() as i64, self.pool.num_idle() as i64); } pub async fn new( @@ -760,7 +760,6 @@ impl Database { Ok(count > 0) } } -} #[cfg(test)] mod tests { diff --git a/services/api/src/email/service.rs b/services/api/src/email/service.rs index cbe3421..4af7a76 100644 --- a/services/api/src/email/service.rs +++ b/services/api/src/email/service.rs @@ -1,5 +1,4 @@ use anyhow::{Context, Result}; -use redis::AsyncCommands as _; use serde_json::Value; use sha2::{Digest, Sha256}; use std::time::Duration; @@ -139,7 +138,7 @@ impl EmailService { // --- idempotency check --- if let (Some(cache), Some(key)) = (&self.cache, idem_key) { let redis_key = format!("email:idem:{key}"); - let mut conn = cache.manager.clone(); + let mut conn = cache.get_connection().await.context("idempotency Redis connection failed")?; // Try SET NX — only succeeds for the first send. let acquired: Option = redis::cmd("SET") diff --git a/services/api/src/handlers.rs b/services/api/src/handlers.rs index bbea700..cfd03e3 100644 --- a/services/api/src/handlers.rs +++ b/services/api/src/handlers.rs @@ -1,4 +1,3 @@ -use crate::content_type::require_json_content_type; use std::{ sync::Arc, time::{Duration, Instant}, @@ -565,7 +564,7 @@ pub async fn statistics(State(state): State>) -> Result) -> anyhow::Result<()> { let (mut succeeded, mut failed) = (0usize, 0usize); - warm!("db.statistics", state.db.statistics_cached().map(|r| r.map(|_| ())), succeeded, failed); - warm!("db.featured_markets", state.db.featured_markets_cached(state.config.featured_limit).map(|r| r.map(|_| ())), succeeded, failed); - warm!("blockchain.health", state.blockchain.health_check_cached().map(|r| r.map(|_| ())), succeeded, failed); - warm!("blockchain.platform_stats", state.blockchain.platform_statistics_cached().map(|r| r.map(|_| ())), succeeded, failed); - warm!("api.statistics", statistics(State(state.clone())).map(|r| r.map(|_| ()).map_err(|e| anyhow::anyhow!("{e:?}"))), succeeded, failed); - warm!("api.featured_markets", featured_markets(State(state.clone()), Query(PaginationQuery::default())).map(|r| r.map(|_| ()).map_err(|e| anyhow::anyhow!("{e:?}"))), succeeded, failed); - warm!("api.content", content(State(state.clone()), Query(PaginationQuery::default())).map(|r| r.map(|_| ()).map_err(|e| anyhow::anyhow!("{e:?}"))), succeeded, failed); + warm!("db.statistics", state.db.statistics_cached(), succeeded, failed); + warm!("db.featured_markets", state.db.featured_markets_cached(state.config.featured_limit), succeeded, failed); + warm!("blockchain.health", state.blockchain.health_check_cached(), succeeded, failed); + warm!("blockchain.platform_stats", state.blockchain.platform_statistics_cached(), succeeded, failed); + warm!("api.statistics", async { statistics(State(state.clone())).await.map(|_| ()).map_err(|e| anyhow::anyhow!("{e:?}")) }, succeeded, failed); + warm!("api.featured_markets", async { featured_markets(State(state.clone()), Query(PaginationQuery::default())).await.map(|_| ()).map_err(|e| anyhow::anyhow!("{e:?}")) }, succeeded, failed); + warm!("api.content", async { content(State(state.clone()), Query(PaginationQuery::default())).await.map(|_| ()).map_err(|e| anyhow::anyhow!("{e:?}")) }, succeeded, failed); tracing::info!(succeeded, failed, total = succeeded + failed, "cache warming complete"); Ok(()) diff --git a/services/api/src/lib.rs b/services/api/src/lib.rs index fbcc23b..701b3cb 100644 --- a/services/api/src/lib.rs +++ b/services/api/src/lib.rs @@ -1,5 +1,8 @@ pub mod audit; pub mod audit_middleware; +pub mod body_redact; +pub mod client_ip; +pub mod content_type; #[cfg(test)] mod resolve_market_tests; pub mod blockchain; diff --git a/services/api/src/main.rs b/services/api/src/main.rs index 115a596..89d3242 100644 --- a/services/api/src/main.rs +++ b/services/api/src/main.rs @@ -92,7 +92,7 @@ async fn main() -> anyhow::Result<()> { )?; // Validate required configuration before proceeding - config.validate()?; + config.validate().map_err(|e| anyhow::anyhow!("{e}"))?; let metrics = Metrics::new()?; let cache = RedisCache::new(&config.redis_url).await?; diff --git a/services/api/src/migrations.rs b/services/api/src/migrations.rs index 3fee312..90c59ae 100644 --- a/services/api/src/migrations.rs +++ b/services/api/src/migrations.rs @@ -15,7 +15,7 @@ use anyhow::{bail, Context}; use sha2::{Digest, Sha256}; use sqlx::PgPool; -use tracing::{info, warn}; +use tracing::info; /// A single migration file embedded at compile time. #[derive(Debug, Clone)] diff --git a/services/api/src/pagination.rs b/services/api/src/pagination.rs index 1e43466..200887d 100644 --- a/services/api/src/pagination.rs +++ b/services/api/src/pagination.rs @@ -77,6 +77,38 @@ pub fn validate_pagination(params: PaginationParams) -> Result, + pub cursor: Option, +} + +impl PaginationQuery { + pub fn limit(&self) -> u32 { + self.limit.unwrap_or(DEFAULT_LIMIT).min(MAX_PAGE_LIMIT) + } + + pub fn cursor(&self) -> Option { + self.cursor.clone() + } +} + +/// A single page of results returned by paginated endpoints. +#[derive(Debug, Serialize)] +pub struct PaginatedResponse { + pub items: Vec, + pub next_cursor: Option, + pub limit: u32, + pub has_more: bool, +} + +impl PaginatedResponse { + pub fn new(items: Vec, next_cursor: Option, limit: u32, has_more: bool) -> Self { + Self { items, next_cursor, limit, has_more } + } +} + #[axum::async_trait] impl axum::extract::FromRequestParts for ValidatedPaginationQuery where diff --git a/services/api/src/rate_limit.rs b/services/api/src/rate_limit.rs index 6a33277..2035348 100644 --- a/services/api/src/rate_limit.rs +++ b/services/api/src/rate_limit.rs @@ -19,7 +19,7 @@ use axum::{ response::{IntoResponse, Response}, Json, }; -use deadpool_redis::{Pool as RedisPool, redis::AsyncCommands}; +use deadpool_redis::Pool as RedisPool; use serde::Serialize; use std::time::{SystemTime, UNIX_EPOCH}; use std::sync::Arc; @@ -152,6 +152,57 @@ pub async fn rate_limit_middleware( } } +pub async fn newsletter_rate_limit_middleware( + State(state): State>, + headers: HeaderMap, + req: axum::extract::Request, + next: Next, +) -> Response { + let client_key = client_key_from_headers(&headers); + if !state + .newsletter_rate_limiter + .allow(&client_key, 10, std::time::Duration::from_secs(3600)) + .await + { + let body = RateLimitError { + error: "rate_limit_exceeded", + message: "Too many newsletter requests. Please try again later.".to_string(), + retry_after: 3600, + }; + return ( + StatusCode::TOO_MANY_REQUESTS, + [("Retry-After", "3600".to_string())], + Json(body), + ) + .into_response(); + } + next.run(req).await +} + +pub async fn admin_rate_limit_middleware( + State(limiter): State>, + headers: HeaderMap, + req: axum::extract::Request, + next: Next, +) -> Response { + let client_key = client_key_from_headers(&headers); + let config = crate::security::RateLimitConfig::new(50, std::time::Duration::from_secs(60)); + if !limiter.check(&client_key, &config).await { + let body = RateLimitError { + error: "rate_limit_exceeded", + message: "Too many admin requests. Please try again later.".to_string(), + retry_after: 60, + }; + return ( + StatusCode::TOO_MANY_REQUESTS, + [("Retry-After", "60".to_string())], + Json(body), + ) + .into_response(); + } + next.run(req).await +} + fn client_key_from_headers(headers: &HeaderMap) -> String { headers .get("x-forwarded-for") diff --git a/services/api/src/validation.rs b/services/api/src/validation.rs index 6eac517..c78bb5e 100644 --- a/services/api/src/validation.rs +++ b/services/api/src/validation.rs @@ -8,11 +8,53 @@ //! //! This is a defence-in-depth layer; the frontend MUST also escape output. -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::Json; +use axum::{ + body::Body, + http::{Request, StatusCode}, + middleware::Next, + response::{IntoResponse, Response}, + Json, +}; use serde::Serialize; +const MAX_REQUEST_BODY_BYTES: u64 = 1 * 1024 * 1024; // 1 MB + +#[derive(Serialize)] +struct RequestTooLargeError { + error: &'static str, + message: String, + max_bytes: u64, +} + +pub async fn content_type_validation_middleware(req: Request, next: Next) -> Response { + crate::content_type::require_json_content_type(req, next).await +} + +pub async fn request_size_validation_middleware(req: Request, next: Next) -> Response { + if let Some(content_length) = req + .headers() + .get(axum::http::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + { + if content_length > MAX_REQUEST_BODY_BYTES { + let body = RequestTooLargeError { + error: "request_too_large", + message: format!( + "Request body exceeds the {MAX_REQUEST_BODY_BYTES}-byte limit." + ), + max_bytes: MAX_REQUEST_BODY_BYTES, + }; + return (StatusCode::PAYLOAD_TOO_LARGE, Json(body)).into_response(); + } + } + next.run(req).await +} + +pub async fn request_validation_middleware(req: Request, next: Next) -> Response { + request_size_validation_middleware(req, next).await +} + #[derive(Debug, Serialize)] pub struct ValidationError { pub error: &'static str,