diff --git a/Cargo.lock b/Cargo.lock index ff81719e6..0b9455553 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -369,6 +369,7 @@ dependencies = [ "jaq-core", "jaq-json", "jaq-std", + "js-sys", "md-5", "monty", "num-traits", @@ -378,9 +379,11 @@ dependencies = [ "rand 0.10.1", "regex", "reqwest", + "rexie", "russh", "rustls", "serde", + "serde-wasm-bindgen", "serde_json", "serial_test", "sha1", @@ -394,6 +397,10 @@ dependencies = [ "turso_core", "unit-prefix", "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "web-time", "zapcode-core", "zeroize", ] @@ -482,6 +489,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "bashkit-wasm-tests" +version = "0.9.0" +dependencies = [ + "async-trait", + "bashkit", + "getrandom 0.3.4", + "js-sys", + "wasm-bindgen-test", +] + [[package]] name = "bcrypt-pbkdf" version = "0.11.0" @@ -1792,9 +1810,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi 5.3.0", "wasip2", + "wasm-bindgen", ] [[package]] @@ -2170,6 +2190,21 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" +[[package]] +name = "idb" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6554f394e990a1af530a528a7fdcad6e01b29cb1b990f89df3ffd62cf15f7828" +dependencies = [ + "indexmap", + "js-sys", + "num-traits", + "thiserror 2.0.18", + "tokio", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "idna" version = "1.1.0" @@ -2718,6 +2753,16 @@ dependencies = [ "syn", ] +[[package]] +name = "minicov" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4869b6a491569605d66d3952bcdf03df789e5b536e5f0cf7758a7f08a55ae24d" +dependencies = [ + "cc", + "walkdir", +] + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -2930,6 +2975,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -4071,6 +4117,17 @@ dependencies = [ "web-sys", ] +[[package]] +name = "rexie" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "887466cfa8a12c08ee4b174998135cea8ff0fd84858627cd793e56535a045bc9" +dependencies = [ + "idb", + "thiserror 1.0.69", + "wasm-bindgen", +] + [[package]] name = "rfc6979" version = "0.5.0" @@ -4593,6 +4650,17 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-wasm-bindgen" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8302e169f0eddcc139c70f139d19d6467353af16f9fce27e8c30158036a1e16b" +dependencies = [ + "js-sys", + "serde", + "wasm-bindgen", +] + [[package]] name = "serde_core" version = "1.0.228" @@ -5873,6 +5941,45 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-bindgen-test" +version = "0.3.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af5ec93229ad9ccd0a545a516dec76dc276613f278f6a91aa6b463d5b33d42d0" +dependencies = [ + "async-trait", + "cast", + "js-sys", + "libm", + "minicov", + "nu-ansi-term", + "num-traits", + "oorandom", + "serde", + "serde_json", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-bindgen-test-macro", + "wasm-bindgen-test-shared", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c81b9fef827e575e0e54431736d1baa0d700315d8c62cfef1f61fa3aad0cbeb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "wasm-bindgen-test-shared" +version = "0.2.121" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4d8ae7ad5440360e9799dfd42857d126454a88441ddf72d288ef83fa47f527" + [[package]] name = "wasm-encoder" version = "0.244.0" @@ -5930,6 +6037,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-root-certs" version = "1.0.7" diff --git a/Cargo.toml b/Cargo.toml index 84aca2b70..3339e771e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,6 +52,8 @@ fancy-regex = "0.18" # Date/Time chrono = "0.4" +# Drop-in replacement for std::time::{SystemTime, Duration} on WASM targets +web-time = "1" # Compression flate2 = "1" diff --git a/crates/bashkit-wasm-tests/Cargo.toml b/crates/bashkit-wasm-tests/Cargo.toml new file mode 100644 index 000000000..ebfcb982f --- /dev/null +++ b/crates/bashkit-wasm-tests/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "bashkit-wasm-tests" +version = "0.9.0" +edition = "2024" +license = "MIT" +authors = ["siennathesane"] +description = "WASM smoke tests for bashkit" + +[dependencies] +bashkit = { path = "../bashkit", features = ["indexeddb", "http_client"] } +wasm-bindgen-test = "0.3" +getrandom = { version = "0.3", features = ["wasm_js"] } +js-sys = "0.3" +async-trait = "0.1" + +[lib] +crate-type = ["cdylib"] diff --git a/crates/bashkit-wasm-tests/src/lib.rs b/crates/bashkit-wasm-tests/src/lib.rs new file mode 100644 index 000000000..47cf39661 --- /dev/null +++ b/crates/bashkit-wasm-tests/src/lib.rs @@ -0,0 +1,717 @@ +//! WASM-specific smoke tests for bashkit. +//! +//! Run with: wasm-pack test --headless --chrome + +use wasm_bindgen_test::*; +wasm_bindgen_test_configure!(run_in_browser); + +#[wasm_bindgen_test] +fn system_time_now_does_not_panic() { + let _now = bashkit::time::SystemTime::now(); +} + +#[wasm_bindgen_test] +fn unix_epoch_is_before_now() { + let epoch = bashkit::time::UNIX_EPOCH; + let now = bashkit::time::SystemTime::now(); + let duration = now + .duration_since(epoch) + .expect("now should be after epoch"); + assert!(duration.as_secs() > 1_000_000_000, "expected year 2001+"); +} + +#[wasm_bindgen_test] +fn chrono_roundtrip_utc() { + let now = bashkit::time::SystemTime::now(); + let dt = bashkit::time::to_chrono_utc(now); + let back = bashkit::time::from_chrono(dt); + + let diff = now + .duration_since(back) + .unwrap_or_else(|e| e.duration()) + .as_millis(); + assert!(diff < 2, "roundtrip drift should be < 2ms, got {}ms", diff); +} + +#[wasm_bindgen_test] +fn duration_arithmetic() { + let a = bashkit::time::Duration::from_secs(10); + let b = bashkit::time::Duration::from_secs(5); + assert_eq!((a + b).as_secs(), 15); +} + +// --------------------------------------------------------------------------- +// IndexedDB filesystem tests +// --------------------------------------------------------------------------- + +use bashkit::{FileSystem, FsBackend, IndexedDbFs, PosixFs}; +use std::path::Path; +use std::sync::Arc; + +/// Unique DB name per test to avoid collisions. +fn db_name(test: &str) -> String { + format!("bashkit_test_{}", test) +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_write_and_read_file() { + let fs = IndexedDbFs::new(db_name("write_read")); + fs.write(Path::new("/tmp/test.txt"), b"hello world") + .await + .unwrap(); + let content = fs.read(Path::new("/tmp/test.txt")).await.unwrap(); + assert_eq!(content, b"hello world"); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_append_file() { + let fs = IndexedDbFs::new(db_name("append")); + fs.write(Path::new("/tmp/log.txt"), b"line1\n") + .await + .unwrap(); + fs.append(Path::new("/tmp/log.txt"), b"line2\n") + .await + .unwrap(); + let content = fs.read(Path::new("/tmp/log.txt")).await.unwrap(); + assert_eq!(content, b"line1\nline2\n"); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_mkdir_and_exists() { + let fs = IndexedDbFs::new(db_name("mkdir")); + fs.mkdir(Path::new("/data/nested"), true).await.unwrap(); + assert!(fs.exists(Path::new("/data")).await.unwrap()); + assert!(fs.exists(Path::new("/data/nested")).await.unwrap()); + assert!(!fs.exists(Path::new("/data/nested/missing")).await.unwrap()); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_read_dir() { + let fs = IndexedDbFs::new(db_name("read_dir")); + fs.mkdir(Path::new("/tmp/sub"), true).await.unwrap(); + fs.write(Path::new("/tmp/a.txt"), b"a").await.unwrap(); + fs.write(Path::new("/tmp/b.txt"), b"b").await.unwrap(); + + let entries = fs.read_dir(Path::new("/tmp")).await.unwrap(); + let mut names: Vec<_> = entries.iter().map(|e| e.name.clone()).collect(); + names.sort(); + assert_eq!(names, vec!["a.txt", "b.txt", "sub"]); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_remove_file() { + let fs = IndexedDbFs::new(db_name("remove_file")); + fs.write(Path::new("/tmp/del.txt"), b"x").await.unwrap(); + assert!(fs.exists(Path::new("/tmp/del.txt")).await.unwrap()); + fs.remove(Path::new("/tmp/del.txt"), false).await.unwrap(); + assert!(!fs.exists(Path::new("/tmp/del.txt")).await.unwrap()); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_remove_dir_recursive() { + let fs = IndexedDbFs::new(db_name("remove_dir")); + fs.mkdir(Path::new("/tmp/deep/nested"), true).await.unwrap(); + fs.write(Path::new("/tmp/deep/file.txt"), b"x") + .await + .unwrap(); + fs.remove(Path::new("/tmp/deep"), true).await.unwrap(); + assert!(!fs.exists(Path::new("/tmp/deep")).await.unwrap()); + assert!(!fs.exists(Path::new("/tmp/deep/nested")).await.unwrap()); + assert!(!fs.exists(Path::new("/tmp/deep/file.txt")).await.unwrap()); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_stat() { + let fs = IndexedDbFs::new(db_name("stat")); + fs.write(Path::new("/tmp/stats.txt"), b"12345") + .await + .unwrap(); + let meta = fs.stat(Path::new("/tmp/stats.txt")).await.unwrap(); + assert!(meta.file_type.is_file()); + assert_eq!(meta.size, 5); + assert_eq!(meta.mode, 0o644); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_rename_file() { + let fs = IndexedDbFs::new(db_name("rename_file")); + fs.write(Path::new("/tmp/old.txt"), b"data").await.unwrap(); + fs.rename(Path::new("/tmp/old.txt"), Path::new("/tmp/new.txt")) + .await + .unwrap(); + assert!(!fs.exists(Path::new("/tmp/old.txt")).await.unwrap()); + assert!(fs.exists(Path::new("/tmp/new.txt")).await.unwrap()); + assert_eq!(fs.read(Path::new("/tmp/new.txt")).await.unwrap(), b"data"); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_copy_file() { + let fs = IndexedDbFs::new(db_name("copy_file")); + fs.write(Path::new("/tmp/src.txt"), b"copy me") + .await + .unwrap(); + fs.copy(Path::new("/tmp/src.txt"), Path::new("/tmp/dst.txt")) + .await + .unwrap(); + assert_eq!( + fs.read(Path::new("/tmp/src.txt")).await.unwrap(), + b"copy me" + ); + assert_eq!( + fs.read(Path::new("/tmp/dst.txt")).await.unwrap(), + b"copy me" + ); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_symlink() { + let fs = IndexedDbFs::new(db_name("symlink")); + fs.symlink(Path::new("/tmp/target.txt"), Path::new("/tmp/link.txt")) + .await + .unwrap(); + let target = fs.read_link(Path::new("/tmp/link.txt")).await.unwrap(); + assert_eq!(target, Path::new("/tmp/target.txt")); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_chmod() { + let fs = IndexedDbFs::new(db_name("chmod")); + fs.write(Path::new("/tmp/perms.txt"), b"x").await.unwrap(); + fs.chmod(Path::new("/tmp/perms.txt"), 0o755).await.unwrap(); + let meta = fs.stat(Path::new("/tmp/perms.txt")).await.unwrap(); + assert_eq!(meta.mode, 0o755); +} + +#[wasm_bindgen_test] +async fn indexeddb_fs_posix_wrapper() { + let backend = IndexedDbFs::new(db_name("posix")); + let fs = Arc::new(PosixFs::new(backend)); + + // Create parent dir first — IndexedDB fs doesn't auto-create parents + fs.mkdir(Path::new("/tmp"), false).await.unwrap(); + + // POSIX semantics: write -> read roundtrip + fs.write_file(Path::new("/tmp/posix.txt"), b"posix") + .await + .unwrap(); + let content = fs.read_file(Path::new("/tmp/posix.txt")).await.unwrap(); + assert_eq!(content, b"posix"); + + // POSIX semantics: cannot write to a directory + fs.mkdir(Path::new("/tmp/dir"), false).await.unwrap(); + let result = fs.write_file(Path::new("/tmp/dir"), b"x").await; + assert!(result.is_err(), "writing to a directory should fail"); + + // verify_filesystem_requirements smoke test + bashkit::verify_filesystem_requirements(&*fs).await.unwrap(); +} + +// --------------------------------------------------------------------------- +// HTTP client tests +// --------------------------------------------------------------------------- + +use bashkit::{HttpClient, HttpHandler, HttpResponse, Method, NetworkAllowlist}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +struct MockHandler { + status: u16, + headers: Vec<(String, String)>, + body: Vec, +} + +#[async_trait::async_trait] +impl HttpHandler for MockHandler { + async fn request( + &self, + _method: &str, + _url: &str, + _body: Option<&[u8]>, + _headers: &[(String, String)], + ) -> Result { + Ok(HttpResponse { + status: self.status, + headers: self.headers.clone(), + body: self.body.clone(), + }) + } +} + +struct EchoHandler; + +#[async_trait::async_trait] +impl HttpHandler for EchoHandler { + async fn request( + &self, + method: &str, + url: &str, + body: Option<&[u8]>, + headers: &[(String, String)], + ) -> Result { + let mut body_out = Vec::new(); + body_out.extend_from_slice(format!("{} {}\n", method, url).as_bytes()); + for (k, v) in headers { + body_out.extend_from_slice(format!("{}:{}\n", k, v).as_bytes()); + } + if let Some(b) = body { + body_out.extend_from_slice(b); + } + Ok(HttpResponse { + status: 200, + headers: vec![("Content-Type".to_string(), "text/plain".to_string())], + body: body_out, + }) + } +} + +#[wasm_bindgen_test] +fn http_client_new() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + assert_eq!(client.max_response_bytes(), 10 * 1024 * 1024); +} + +#[wasm_bindgen_test] +fn http_client_with_timeout() { + let client = HttpClient::with_timeout(NetworkAllowlist::allow_all(), Duration::from_secs(60)); + assert_eq!(client.max_response_bytes(), 10 * 1024 * 1024); +} + +#[wasm_bindgen_test] +fn http_client_with_config() { + let client = + HttpClient::with_config(NetworkAllowlist::allow_all(), Duration::from_secs(5), 1024); + assert_eq!(client.max_response_bytes(), 1024); +} + +#[wasm_bindgen_test] +async fn http_blocked_by_empty_allowlist() { + let client = HttpClient::new(NetworkAllowlist::new()); + let result = client.get("https://example.com").await; + assert!(result.is_err()); + let msg = result.unwrap_err().to_string(); + assert!( + msg.contains("access denied"), + "expected access denied, got: {}", + msg + ); +} + +#[wasm_bindgen_test] +async fn http_allowed_by_allowlist() { + let mut client = HttpClient::new(NetworkAllowlist::new().allow("https://example.com")); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + let result = client.get("https://example.com").await; + assert!(result.is_ok()); + let resp = result.unwrap(); + assert_eq!(resp.status, 200); + assert_eq!(resp.body, b"ok"); +} + +#[wasm_bindgen_test] +async fn http_blocked_by_allowlist() { + let mut client = HttpClient::new(NetworkAllowlist::new().allow("https://allowed.com")); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + let result = client.get("https://blocked.com").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("access denied")); +} + +#[wasm_bindgen_test] +async fn http_get_method() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client.get("https://example.com/path").await.unwrap(); + assert_eq!(result.status, 200); + let text = result.body_string(); + assert!(text.starts_with("GET https://example.com/path")); +} + +#[wasm_bindgen_test] +async fn http_post_method() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client + .post("https://example.com/path", Some(b"hello")) + .await + .unwrap(); + assert_eq!(result.status, 200); + let text = result.body_string(); + assert!(text.starts_with("POST https://example.com/path")); + assert!(text.ends_with("hello")); +} + +#[wasm_bindgen_test] +async fn http_put_method() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client + .put("https://example.com/path", Some(b"world")) + .await + .unwrap(); + assert_eq!(result.status, 200); + let text = result.body_string(); + assert!(text.starts_with("PUT https://example.com/path")); + assert!(text.ends_with("world")); +} + +#[wasm_bindgen_test] +async fn http_delete_method() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client.delete("https://example.com/path").await.unwrap(); + assert_eq!(result.status, 200); + assert!( + result + .body_string() + .starts_with("DELETE https://example.com/path") + ); +} + +#[wasm_bindgen_test] +async fn http_head_method() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 204, + headers: vec![("X-Test".to_string(), "1".to_string())], + body: vec![], + })); + let result = client.head("https://example.com/path").await.unwrap(); + assert_eq!(result.status, 204); + assert!(result.headers.iter().any(|(k, _)| k == "X-Test")); +} + +#[wasm_bindgen_test] +async fn http_request_with_headers() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client + .request_with_headers( + Method::Get, + "https://example.com", + None, + &[("Authorization".to_string(), "Bearer token".to_string())], + ) + .await + .unwrap(); + let text = result.body_string(); + assert!(text.contains("Authorization:Bearer token")); +} + +#[wasm_bindgen_test] +async fn http_request_with_timeout() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + let result = client + .request_with_timeout(Method::Get, "https://example.com", None, &[], Some(5)) + .await; + assert!(result.is_ok()); +} + +#[wasm_bindgen_test] +async fn http_request_with_timeouts() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + let result = client + .request_with_timeouts( + Method::Get, + "https://example.com", + None, + &[], + Some(5), + Some(2), + ) + .await; + assert!(result.is_ok()); +} + +#[wasm_bindgen_test] +async fn http_response_body_string() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"hello world".to_vec(), + })); + let result = client.get("https://example.com").await.unwrap(); + assert_eq!(result.body_string(), "hello world"); +} + +#[wasm_bindgen_test] +async fn http_response_is_success() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 201, + headers: vec![], + body: vec![], + })); + let result = client.get("https://example.com").await.unwrap(); + assert!(result.is_success()); +} + +#[wasm_bindgen_test] +async fn http_response_is_not_success() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 404, + headers: vec![], + body: b"not found".to_vec(), + })); + let result = client.get("https://example.com").await.unwrap(); + assert!(!result.is_success()); +} + +#[wasm_bindgen_test] +async fn http_max_response_bytes_enforced() { + let mut client = + HttpClient::with_config(NetworkAllowlist::allow_all(), Duration::from_secs(30), 4); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"too-large".to_vec(), + })); + let result = client.get("https://example.com").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("response too large") + ); +} + +#[wasm_bindgen_test] +async fn http_before_http_hook() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let count = Arc::new(AtomicUsize::new(0)); + let count_clone = count.clone(); + client.set_before_http(vec![Box::new(move |mut event| { + count_clone.fetch_add(1, Ordering::SeqCst); + event + .headers + .push(("X-Hook".to_string(), "fired".to_string())); + bashkit::hooks::HookAction::Continue(event) + })]); + let result = client.get("https://example.com").await.unwrap(); + assert_eq!(count.load(Ordering::SeqCst), 1); + assert!(result.body_string().contains("X-Hook:fired")); +} + +#[wasm_bindgen_test] +async fn http_after_http_hook() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![("X-Original".to_string(), "1".to_string())], + body: b"ok".to_vec(), + })); + let count = Arc::new(AtomicUsize::new(0)); + let count_clone = count.clone(); + client.set_after_http(vec![Box::new(move |event| { + count_clone.fetch_add(1, Ordering::SeqCst); + bashkit::hooks::HookAction::Continue(event) + })]); + let _ = client.get("https://example.com").await.unwrap(); + assert_eq!(count.load(Ordering::SeqCst), 1); +} + +#[wasm_bindgen_test] +async fn http_before_http_hook_can_cancel() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + client.set_before_http(vec![Box::new(|_event| { + bashkit::hooks::HookAction::Cancel("nope".to_string()) + })]); + let result = client.get("https://example.com").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("cancelled")); +} + +#[wasm_bindgen_test] +async fn http_get_blocks_private_ip() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + let result = client.get("http://10.0.0.1/secret").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("private IP")); +} + +#[wasm_bindgen_test] +async fn http_get_blocks_loopback() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + let result = client.get("http://127.0.0.1/").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("private IP")); +} + +#[wasm_bindgen_test] +async fn http_get_allows_public_via_handler() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + let result = client.get("https://example.com/").await; + assert!(result.is_ok()); +} + +#[wasm_bindgen_test] +async fn http_get_rejects_no_host() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + let result = client.get("file:///etc/passwd").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("no host")); +} + +#[wasm_bindgen_test] +async fn http_get_rejects_invalid_url() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + let result = client.get("definitely::not::a::url").await; + assert!(result.is_err()); +} + +#[wasm_bindgen_test] +async fn http_request_with_headers_blocked_by_allowlist() { + let client = HttpClient::new(NetworkAllowlist::new()); + let result = client + .request_with_headers(Method::Get, "https://example.com", None, &[]) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("access denied")); +} + +#[wasm_bindgen_test] +async fn http_request_with_timeout_blocked_by_allowlist() { + let client = HttpClient::new(NetworkAllowlist::new()); + let result = client + .request_with_timeout(Method::Get, "https://example.com", None, &[], Some(5)) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("access denied")); +} + +#[wasm_bindgen_test] +async fn http_before_http_hook_cannot_bypass_allowlist() { + let mut client = HttpClient::new(NetworkAllowlist::new().allow("https://allowed.com")); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + client.set_before_http(vec![Box::new(|mut event| { + event.url = "https://blocked.com".to_string(); + bashkit::hooks::HookAction::Continue(event) + })]); + let result = client + .request_with_headers(Method::Get, "https://allowed.com", None, &[]) + .await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("access denied")); +} + +#[wasm_bindgen_test] +async fn http_empty_body_ok() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client.post("https://example.com", None).await.unwrap(); + assert_eq!(result.status, 200); +} + +#[wasm_bindgen_test] +async fn http_handler_receives_headers() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let result = client + .request_with_headers( + Method::Get, + "https://example.com", + None, + &[ + ("Accept".to_string(), "application/json".to_string()), + ("X-Custom".to_string(), "value".to_string()), + ], + ) + .await + .unwrap(); + let text = result.body_string(); + assert!(text.contains("Accept:application/json")); + assert!(text.contains("X-Custom:value")); +} + +#[wasm_bindgen_test] +async fn http_multiple_before_hooks_fire_in_order() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(EchoHandler)); + let order = Arc::new(AtomicUsize::new(0)); + let order1 = order.clone(); + let order2 = order.clone(); + client.set_before_http(vec![ + Box::new(move |event| { + order1.fetch_add(1, Ordering::SeqCst); + bashkit::hooks::HookAction::Continue(event) + }), + Box::new(move |event| { + order2.fetch_add(10, Ordering::SeqCst); + bashkit::hooks::HookAction::Continue(event) + }), + ]); + let _ = client.get("https://example.com").await.unwrap(); + assert_eq!(order.load(Ordering::SeqCst), 11); +} + +#[wasm_bindgen_test] +async fn http_multiple_after_hooks_fire_in_order() { + let mut client = HttpClient::new(NetworkAllowlist::allow_all()); + client.set_handler(Box::new(MockHandler { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + })); + let order = Arc::new(AtomicUsize::new(0)); + let order1 = order.clone(); + let order2 = order.clone(); + client.set_after_http(vec![ + Box::new(move |event| { + order1.fetch_add(1, Ordering::SeqCst); + bashkit::hooks::HookAction::Continue(event) + }), + Box::new(move |event| { + order2.fetch_add(10, Ordering::SeqCst); + bashkit::hooks::HookAction::Continue(event) + }), + ]); + let _ = client.get("https://example.com").await.unwrap(); + assert_eq!(order.load(Ordering::SeqCst), 11); +} + +#[wasm_bindgen_test] +async fn http_v6_loopback_blocked() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + let result = client.get("http://[::1]/").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("private IP")); +} + +#[wasm_bindgen_test] +async fn http_v4_mapped_v6_blocked() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + let result = client.get("http://[::ffff:10.0.0.1]/").await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("private IP")); +} diff --git a/crates/bashkit/Cargo.toml b/crates/bashkit/Cargo.toml index ad117c13a..7a694ae92 100644 --- a/crates/bashkit/Cargo.toml +++ b/crates/bashkit/Cargo.toml @@ -35,6 +35,7 @@ fancy-regex = { workspace = true } # Date/Time chrono = { workspace = true } +web-time = { workspace = true } # HTTP client (for curl/wget) - optional, enabled with http_client feature reqwest = { workspace = true, optional = true } @@ -107,12 +108,22 @@ zapcode-core = { version = "1.5", optional = true } # Pulls a multi-MB transitive dep tree only when the `sqlite` feature is active. turso_core = { workspace = true, optional = true } +# IndexedDB backend for wasm32 browser persistence (optional) +rexie = { version = "0.6", optional = true } +wasm-bindgen = { version = "0.2", optional = true } +serde-wasm-bindgen = { version = "0.6", optional = true } + +# WASM HTTP client dependencies (browser fetch API) +js-sys = { version = "0.3", optional = true } +wasm-bindgen-futures = { version = "0.4", optional = true } +web-sys = { version = "0.3", optional = true, features = ["Headers", "Request", "RequestInit", "RequestMode", "Response", "Window", "AbortController", "AbortSignal"] } + [features] default = [] # Enable jq builtin via embedded jaq interpreter # Usage: cargo build --features jq jq = ["dep:jaq-core", "dep:jaq-std", "dep:jaq-json"] -http_client = ["reqwest", "rustls"] +http_client = [] # Enable Ed25519 request signing per RFC 9421 / web-bot-auth profile bot-auth = ["http_client", "dep:ed25519-dalek", "dep:rand", "dep:zeroize"] # Enable fail points for security/fault injection testing @@ -155,6 +166,9 @@ sqlite = ["dep:turso_core", "tokio/rt-multi-thread"] realfs = [] # Enable native-extension interop contracts such as bashkit::interop::fs. interop = ["tokio/rt-multi-thread"] +# Enable IndexedDB backend for wasm32 browser persistence. +# Usage: cargo build --target wasm32-unknown-unknown --features indexeddb +indexeddb = ["dep:rexie", "dep:wasm-bindgen", "dep:serde-wasm-bindgen"] [package.metadata.docs.rs] all-features = true @@ -246,6 +260,17 @@ required-features = ["sqlite"] name = "sqlite_workflow" required-features = ["sqlite"] -# Additional tokio features needed only on native (not WASM) -[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +# Native-only: reqwest + rustls are always available on non-WASM targets +[target.'cfg(not(target_family = "wasm"))'.dependencies] tokio = { version = "1", features = ["rt-multi-thread", "fs"] } +reqwest = { workspace = true } +rustls = { workspace = true } + +# WASM-only: browser fetch API deps always available on WASM targets +[target.'cfg(target_family = "wasm")'.dependencies] +js-sys = "0.3" +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" +web-sys = { version = "0.3", features = ["Headers", "Request", "RequestInit", "RequestMode", "Response", "Window", "AbortController", "AbortSignal"] } + + diff --git a/crates/bashkit/src/builtins/archive.rs b/crates/bashkit/src/builtins/archive.rs index e020287ed..d3f7f075d 100644 --- a/crates/bashkit/src/builtins/archive.rs +++ b/crates/bashkit/src/builtins/archive.rs @@ -332,7 +332,7 @@ async fn add_file_to_tar( // Mtime (12 bytes, octal) let mtime = metadata .modified - .duration_since(std::time::UNIX_EPOCH) + .duration_since(crate::time::UNIX_EPOCH) .map(|d| d.as_secs()) .unwrap_or(0); write_octal(&mut header[136..148], mtime, 11); @@ -403,7 +403,7 @@ fn add_directory_to_tar<'a>( // Mtime let mtime = metadata .modified - .duration_since(std::time::UNIX_EPOCH) + .duration_since(crate::time::UNIX_EPOCH) .map(|d| d.as_secs()) .unwrap_or(0); write_octal(&mut header[136..148], mtime, 11); diff --git a/crates/bashkit/src/builtins/curl.rs b/crates/bashkit/src/builtins/curl.rs index b5f439548..03ab0cbb8 100644 --- a/crates/bashkit/src/builtins/curl.rs +++ b/crates/bashkit/src/builtins/curl.rs @@ -513,8 +513,8 @@ async fn execute_curl_request( let multipart_body: Option> = if !form_fields.is_empty() { let boundary = format!( "----bashkit{:016x}", - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) + crate::time::SystemTime::now() + .duration_since(crate::time::UNIX_EPOCH) .map(|d| d.as_nanos()) .unwrap_or(0) ); @@ -1740,7 +1740,7 @@ mod tests { async fn set_modified_time( &self, path: &std::path::Path, - time: std::time::SystemTime, + time: crate::time::SystemTime, ) -> Result<()> { self.inner.set_modified_time(path, time).await } diff --git a/crates/bashkit/src/builtins/date.rs b/crates/bashkit/src/builtins/date.rs index c9c348516..bde821d58 100644 --- a/crates/bashkit/src/builtins/date.rs +++ b/crates/bashkit/src/builtins/date.rs @@ -439,7 +439,7 @@ impl Builtin for Date { let path = resolve_path(ctx.cwd, file); match ctx.fs.stat(&path).await { Ok(meta) => { - dt_utc = meta.modified.into(); + dt_utc = crate::time::to_chrono_utc(meta.modified); epoch_input = false; } Err(_) => { diff --git a/crates/bashkit/src/builtins/fileops.rs b/crates/bashkit/src/builtins/fileops.rs index 59eb6d813..6028554f6 100644 --- a/crates/bashkit/src/builtins/fileops.rs +++ b/crates/bashkit/src/builtins/fileops.rs @@ -2,10 +2,10 @@ // Decision: touch delegates mtime changes to the filesystem layer so `touch` // and `touch -t` stay consistent across in-memory, overlay, and realfs backends. +use crate::time::SystemTime; use async_trait::async_trait; use chrono::{Datelike, Local, LocalResult, NaiveDate, TimeZone}; use std::path::Path; -use std::time::SystemTime; use super::limits::MKTEMP_MAX_ATTEMPTS; use super::{Builtin, Context, resolve_path}; @@ -376,7 +376,7 @@ fn parse_touch_timestamp(raw: &str) -> std::result::Result { LocalResult::None => return Err(format!("touch: invalid date format '{}'\n", raw)), }; - Ok(local.into()) + Ok(crate::time::from_chrono(local)) } #[async_trait] @@ -1126,7 +1126,8 @@ mod tests { assert_eq!(result.exit_code, 0); let metadata = fs.stat(&file).await.unwrap(); - let modified: DateTime = metadata.modified.into(); + let modified: DateTime = + crate::time::to_chrono_utc(metadata.modified).with_timezone(&Local); assert_eq!(modified.year(), 2026); assert_eq!(modified.month(), 4); assert_eq!(modified.day(), 6); diff --git a/crates/bashkit/src/builtins/inspect.rs b/crates/bashkit/src/builtins/inspect.rs index 9b871a76c..af01243f3 100644 --- a/crates/bashkit/src/builtins/inspect.rs +++ b/crates/bashkit/src/builtins/inspect.rs @@ -417,7 +417,7 @@ fn format_permissions(metadata: &crate::fs::Metadata) -> String { fn default_stat_format(name: &str, metadata: &crate::fs::Metadata) -> String { let modified = metadata .modified - .duration_since(std::time::UNIX_EPOCH) + .duration_since(crate::time::UNIX_EPOCH) .map(|d| d.as_secs()) .unwrap_or(0); diff --git a/crates/bashkit/src/builtins/ls/find.rs b/crates/bashkit/src/builtins/ls/find.rs index 725df468f..66a72b700 100644 --- a/crates/bashkit/src/builtins/ls/find.rs +++ b/crates/bashkit/src/builtins/ls/find.rs @@ -611,7 +611,7 @@ fn find_printf_format(fmt: &str, display_path: &str, metadata: &crate::fs::Metad if i < chars.len() && chars[i] == '@' { let secs = metadata .modified - .duration_since(std::time::UNIX_EPOCH) + .duration_since(crate::time::UNIX_EPOCH) .ok() .map(|d| d.as_secs()) .unwrap_or(0); diff --git a/crates/bashkit/src/builtins/ls/list.rs b/crates/bashkit/src/builtins/ls/list.rs index f06804b21..87fca18be 100644 --- a/crates/bashkit/src/builtins/ls/list.rs +++ b/crates/bashkit/src/builtins/ls/list.rs @@ -441,7 +441,7 @@ pub(super) fn format_long_entry(name: &str, metadata: &crate::fs::Metadata, huma // Format modified time let modified = metadata .modified - .duration_since(std::time::UNIX_EPOCH) + .duration_since(crate::time::UNIX_EPOCH) .map(|d| { let secs = d.as_secs(); // Simple date formatting: YYYY-MM-DD HH:MM diff --git a/crates/bashkit/src/builtins/mod.rs b/crates/bashkit/src/builtins/mod.rs index 7069a5626..889dd0b4c 100644 --- a/crates/bashkit/src/builtins/mod.rs +++ b/crates/bashkit/src/builtins/mod.rs @@ -394,7 +394,7 @@ pub struct ExecutionExtensions { #[derive(Debug, Clone)] pub(crate) struct ExecutionDeadline { #[allow(dead_code)] - started_at: std::time::Instant, + started_at: crate::time::Instant, timeout: std::time::Duration, } @@ -402,7 +402,7 @@ impl ExecutionDeadline { /// Create a deadline anchored at the start of `Bash::exec*`. pub(crate) fn new(timeout: std::time::Duration) -> Self { Self { - started_at: std::time::Instant::now(), + started_at: crate::time::Instant::now(), timeout, } } diff --git a/crates/bashkit/src/builtins/python.rs b/crates/bashkit/src/builtins/python.rs index 90d561159..27f594c4b 100644 --- a/crates/bashkit/src/builtins/python.rs +++ b/crates/bashkit/src/builtins/python.rs @@ -18,6 +18,7 @@ //! //! Supports: `python -c "code"`, `python script.py`, stdin piping. +use crate::time::Duration; use async_trait::async_trait; use chrono::{Datelike, Timelike}; use monty::{ @@ -30,7 +31,6 @@ use std::future::Future; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; -use std::time::Duration; use super::{Builtin, Context, resolve_path}; use crate::error::Result; @@ -794,7 +794,7 @@ async fn handle_os_call( Ok(meta) => { let mtime = meta .modified - .duration_since(std::time::UNIX_EPOCH) + .duration_since(crate::time::UNIX_EPOCH) .map(|d| d.as_secs_f64()) .unwrap_or(0.0); let stat_obj = match meta.file_type { diff --git a/crates/bashkit/src/builtins/rg/mod.rs b/crates/bashkit/src/builtins/rg/mod.rs index 86b4d04ca..83ff3dbed 100644 --- a/crates/bashkit/src/builtins/rg/mod.rs +++ b/crates/bashkit/src/builtins/rg/mod.rs @@ -7035,7 +7035,7 @@ mod tests { fs_trait.write_file(p, content).await.unwrap(); } for (path, secs) in mtimes { - let time = std::time::UNIX_EPOCH + std::time::Duration::from_secs(*secs); + let time = crate::time::UNIX_EPOCH + crate::time::Duration::from_secs(*secs); fs_trait .set_modified_time(Path::new(path), time) .await @@ -12490,7 +12490,7 @@ mod tests { } for (path, secs) in case.mtimes { let host_path = tempdir.path().join(path.trim_start_matches('/')); - let time = std::time::UNIX_EPOCH + std::time::Duration::from_secs(*secs); + let time = crate::time::UNIX_EPOCH + crate::time::Duration::from_secs(*secs); std::fs::File::open(host_path) .expect("open timed rg fixture file") .set_modified(time) diff --git a/crates/bashkit/src/builtins/shuf.rs b/crates/bashkit/src/builtins/shuf.rs index ad3959d10..75c39fc7c 100644 --- a/crates/bashkit/src/builtins/shuf.rs +++ b/crates/bashkit/src/builtins/shuf.rs @@ -364,8 +364,8 @@ struct SmallRng { impl SmallRng { fn seed_from_now() -> Self { - let nanos = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) + let nanos = crate::time::SystemTime::now() + .duration_since(crate::time::UNIX_EPOCH) .map(|d| d.as_nanos() as u64) .unwrap_or(0x123_4567_89AB_CDEF); // SystemTime::now can return a tiny duration in tight loops; XOR diff --git a/crates/bashkit/src/builtins/sleep.rs b/crates/bashkit/src/builtins/sleep.rs index fa03a2be7..d4315cf81 100644 --- a/crates/bashkit/src/builtins/sleep.rs +++ b/crates/bashkit/src/builtins/sleep.rs @@ -95,7 +95,7 @@ mod tests { #[tokio::test] async fn test_sleep_zero() { - let start = std::time::Instant::now(); + let start = crate::time::Instant::now(); let result = run_sleep(&["0"]).await; let elapsed = start.elapsed(); @@ -105,7 +105,7 @@ mod tests { #[tokio::test] async fn test_sleep_fractional() { - let start = std::time::Instant::now(); + let start = crate::time::Instant::now(); let result = run_sleep(&["0.1"]).await; let elapsed = start.elapsed(); diff --git a/crates/bashkit/src/builtins/sqlite/engine.rs b/crates/bashkit/src/builtins/sqlite/engine.rs index 32882bcc3..275f28682 100644 --- a/crates/bashkit/src/builtins/sqlite/engine.rs +++ b/crates/bashkit/src/builtins/sqlite/engine.rs @@ -15,9 +15,9 @@ //! Both expose the same query API. The builtin layer above is agnostic to //! which backend is active. +use crate::time::Instant; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::Instant; use turso_core::{Connection, Database, IO, MemoryIO, Numeric, OpenFlags, StepResult, Value}; diff --git a/crates/bashkit/src/fs/backend.rs b/crates/bashkit/src/fs/backend.rs index 4d8c25782..953c037dd 100644 --- a/crates/bashkit/src/fs/backend.rs +++ b/crates/bashkit/src/fs/backend.rs @@ -92,9 +92,9 @@ //! //! See `examples/custom_backend.rs` for a complete working example. +use crate::time::SystemTime; use async_trait::async_trait; use std::path::{Path, PathBuf}; -use std::time::SystemTime; use super::limits::{FsLimits, FsUsage}; use super::traits::{DirEntry, Metadata}; diff --git a/crates/bashkit/src/fs/indexeddb.rs b/crates/bashkit/src/fs/indexeddb.rs new file mode 100644 index 000000000..5e4dbcad8 --- /dev/null +++ b/crates/bashkit/src/fs/indexeddb.rs @@ -0,0 +1,845 @@ +//! IndexedDB filesystem backend for wasm32. +//! +//! [`IndexedDbFs`] implements [`FsBackend`] using the browser's IndexedDB API +//! via the `rexie` crate. It persists files and directories across page reloads +//! in browser environments. +//! +//! # Usage +//! +//! ```rust,ignore +//! use bashkit::{FsBackend, PosixFs, IndexedDbFs}; +//! use std::sync::Arc; +//! +//! let backend = IndexedDbFs::new("bashkit_fs"); +//! let fs = Arc::new(PosixFs::new(backend)); +//! ``` +//! +//! # Safety +//! +//! This module uses [`AssertSend`] to wrap futures that contain `wasm_bindgen` +//! closure types. On `wasm32-unknown-unknown` there is only a single thread, so +//! asserting `Send` is sound. The module is gated to `wasm32` via `cfg` when the +//! `indexeddb` feature is enabled. + +use crate::time::{Duration, SystemTime, UNIX_EPOCH}; +use async_trait::async_trait; +use rexie::{ObjectStore, Rexie, TransactionMode}; +use serde::{Deserialize, Serialize}; +use std::future::Future; +use std::io::{Error as IoError, ErrorKind}; +use std::path::{Path, PathBuf}; +use std::pin::Pin; +use std::task::{Context, Poll}; +use wasm_bindgen::JsValue; + +use super::backend::FsBackend; +use super::normalize_path; +use super::traits::{DirEntry, Metadata}; +use crate::error::Result; + +const STORE_NAME: &str = "entries"; + +/// Wrapper that asserts a future is `Send`. +/// +/// # Safety +/// +/// On `wasm32-unknown-unknown` there is only one thread, so all types are +/// effectively `Send`. This wrapper is only used within the IndexedDB backend +/// which is compiled exclusively for that target. +struct AssertSend(F); + +unsafe impl Send for AssertSend {} +unsafe impl Sync for AssertSend {} + +impl Future for AssertSend { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: We are projecting from Pin<&mut AssertSend> to Pin<&mut F>. + // AssertSend is a newtype wrapper with the same memory layout. + unsafe { self.map_unchecked_mut(|s| &mut s.0).poll(cx) } + } +} + +/// Wrap a future so that it satisfies `Send` bounds. +/// +/// This is a synchronous constructor — it immediately wraps `f` in +/// [`AssertSend`] and returns it, so the caller's async generator never +/// holds the unwrapped `f` across an await point. +fn run(f: F) -> AssertSend { + AssertSend(f) +} + +/// Stored representation of a filesystem entry in IndexedDB. +#[derive(Clone, Debug, Serialize, Deserialize)] +struct DbEntry { + path: String, + kind: DbEntryKind, + content: Option>, + mode: u32, + modified: f64, + created: f64, + target: Option, + size: u64, +} + +/// Kind of filesystem entry. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +enum DbEntryKind { + File, + Directory, + Symlink, +} + +/// IndexedDB filesystem backend. +/// +/// Stores files, directories, and symlinks in the browser's IndexedDB. +/// Each operation opens the database, performs the work, and closes it. +/// This avoids `Send`/`Sync` issues with `rexie`'s internal closure types. +#[derive(Clone, Debug)] +pub struct IndexedDbFs { + db_name: String, +} + +impl IndexedDbFs { + /// Create a new IndexedDB filesystem with the given database name. + pub fn new(db_name: impl Into) -> Self { + Self { + db_name: db_name.into(), + } + } + + fn now_ms() -> f64 { + let now = SystemTime::now(); + let dur = now.duration_since(UNIX_EPOCH).unwrap_or_default(); + dur.as_millis() as f64 + } + + fn system_time_to_ms(time: SystemTime) -> f64 { + let dur = time.duration_since(UNIX_EPOCH).unwrap_or_default(); + dur.as_millis() as f64 + } + + fn ms_to_system_time(ms: f64) -> SystemTime { + UNIX_EPOCH + Duration::from_millis(ms.max(0.0) as u64) + } + + fn entry_to_metadata(entry: &DbEntry) -> Metadata { + use super::traits::FileType; + let file_type = match entry.kind { + DbEntryKind::File => FileType::File, + DbEntryKind::Directory => FileType::Directory, + DbEntryKind::Symlink => FileType::Symlink, + }; + Metadata { + file_type, + size: entry.size, + mode: entry.mode, + modified: Self::ms_to_system_time(entry.modified), + created: Self::ms_to_system_time(entry.created), + } + } + + fn is_direct_child(parent: &Path, child_path: &str) -> Option { + let parent_str = parent.to_str()?; + let child = Path::new(child_path); + if child.parent()?.as_os_str() == parent_str { + child.file_name()?.to_str().map(|s| s.to_string()) + } else { + None + } + } + + /// Open the IndexedDB and ensure the root directory `/` exists. + async fn open_db(db_name: &str) -> Result { + let db = Rexie::builder(db_name) + .version(1) + .add_object_store(ObjectStore::new(STORE_NAME).key_path("path")) + .build() + .await + .map_err(|e| IoError::other(format!("indexeddb open: {e}")))?; + + // Invisible root node — ensure `/` always exists + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let root_key: JsValue = "/".into(); + if store + .get(root_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .is_none() + { + let root = DbEntry { + path: "/".to_string(), + kind: DbEntryKind::Directory, + content: None, + mode: 0o755, + modified: Self::now_ms(), + created: Self::now_ms(), + target: None, + size: 0, + }; + let js = serde_wasm_bindgen::to_value(&root) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + } + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb commit: {e}")))?; + Ok(db) + } +} + +fn path_to_js(path: &Path) -> std::io::Result { + path.to_str() + .ok_or_else(|| IoError::other("non-UTF-8 path")) + .map(|s| s.into()) +} + +fn path_to_string(path: &Path) -> std::io::Result { + path.to_str() + .ok_or_else(|| IoError::other("non-UTF-8 path")) + .map(|s| s.to_string()) +} + +#[async_trait] +impl FsBackend for IndexedDbFs { + async fn read(&self, path: &Path) -> Result> { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadOnly) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let js_value = store + .get(js_key) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + + let entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + + Ok(entry.content.unwrap_or_default()) + }) + .await + } + + async fn write(&self, path: &Path, content: &[u8]) -> Result<()> { + let path = normalize_path(path); + let content = content.to_vec(); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let now = Self::now_ms(); + + let content_len = content.len() as u64; + let entry = if let Some(js_value) = store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + { + let mut existing: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + existing.content = Some(content); + existing.modified = now; + existing.size = content_len; + existing + } else { + DbEntry { + path: path_to_string(&path)?, + kind: DbEntryKind::File, + content: Some(content), + mode: 0o644, + modified: now, + created: now, + target: None, + size: content_len, + } + }; + + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .put(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb put: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn append(&self, path: &Path, content: &[u8]) -> Result<()> { + let path = normalize_path(path); + let content = content.to_vec(); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let mut existing_content = Vec::new(); + let mut mode = 0o644; + let mut created = Self::now_ms(); + + if let Some(js_value) = store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + { + let existing: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + if let Some(content) = existing.content { + existing_content = content; + } + mode = existing.mode; + created = existing.created; + } + + existing_content.extend_from_slice(&content); + + let entry = DbEntry { + path: path_to_string(&path)?, + kind: DbEntryKind::File, + content: Some(existing_content.clone()), + mode, + modified: Self::now_ms(), + created, + target: None, + size: existing_content.len() as u64, + }; + + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .put(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb put: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn mkdir(&self, path: &Path, recursive: bool) -> Result<()> { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let now = Self::now_ms(); + + if recursive { + let mut current = PathBuf::from("/"); + for component in path.components().skip(1) { + current.push(component); + let js_key: JsValue = path_to_js(¤t)?; + if store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .is_none() + { + let entry = DbEntry { + path: path_to_string(¤t)?, + kind: DbEntryKind::Directory, + content: None, + mode: 0o755, + modified: now, + created: now, + target: None, + size: 0, + }; + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + } + } + } else { + let js_key: JsValue = path_to_js(&path)?; + if store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .is_none() + { + let entry = DbEntry { + path: path_to_string(&path)?, + kind: DbEntryKind::Directory, + content: None, + mode: 0o755, + modified: now, + created: now, + target: None, + size: 0, + }; + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + } + } + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn remove(&self, path: &Path, recursive: bool) -> Result<()> { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let existing = store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))?; + + if existing.is_none() { + return Err(IoError::from(ErrorKind::NotFound).into()); + } + + if recursive { + let path_s = path_to_string(&path)?; + let prefix = format!("{}/", path_s); + let all = store + .get_all(None, None) + .await + .map_err(|e| IoError::other(format!("indexeddb get_all: {e}")))?; + for js_value in all { + let entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + if entry.path == path_s || entry.path.starts_with(&prefix) { + store + .delete(entry.path.into()) + .await + .map_err(|e| IoError::other(format!("indexeddb delete: {e}")))?; + } + } + } else { + store + .delete(js_key) + .await + .map_err(|e| IoError::other(format!("indexeddb delete: {e}")))?; + } + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn stat(&self, path: &Path) -> Result { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadOnly) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let js_value = store + .get(js_key) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + + let entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + + Ok(Self::entry_to_metadata(&entry)) + }) + .await + } + + async fn read_dir(&self, path: &Path) -> Result> { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadOnly) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + if let Some(js_value) = store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + { + let entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + if !matches!(entry.kind, DbEntryKind::Directory) { + return Err(IoError::from(ErrorKind::NotFound).into()); + } + } else { + return Err(IoError::from(ErrorKind::NotFound).into()); + } + + let mut entries = Vec::new(); + let all = store + .get_all(None, None) + .await + .map_err(|e| IoError::other(format!("indexeddb get_all: {e}")))?; + for js_value in all { + let entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + if let Some(name) = Self::is_direct_child(&path, &entry.path) { + entries.push(DirEntry { + name, + metadata: Self::entry_to_metadata(&entry), + }); + } + } + + Ok(entries) + }) + .await + } + + async fn exists(&self, path: &Path) -> Result { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadOnly) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let existing = store + .get(js_key) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))?; + Ok(existing.is_some()) + }) + .await + } + + async fn rename(&self, from: &Path, to: &Path) -> Result<()> { + let from = normalize_path(from); + let to = normalize_path(to); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let from_js: JsValue = path_to_js(&from)?; + let js_value = store + .get(from_js.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + let mut entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + + store + .delete(from_js) + .await + .map_err(|e| IoError::other(format!("indexeddb delete: {e}")))?; + + let from_s = path_to_string(&from)?; + let to_s = path_to_string(&to)?; + let from_prefix = format!("{}/", from_s); + let to_prefix = format!("{}/", to_s); + let all = store + .get_all(None, None) + .await + .map_err(|e| IoError::other(format!("indexeddb get_all: {e}")))?; + for js_value in all { + let mut child: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + if child.path.starts_with(&from_prefix) { + let new_path = to_prefix.clone() + &child.path[from_prefix.len()..]; + store + .delete(child.path.clone().into()) + .await + .map_err(|e| IoError::other(format!("indexeddb delete: {e}")))?; + child.path = new_path; + let js_child = serde_wasm_bindgen::to_value(&child) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js_child, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + } + } + + entry.path = path_to_string(&to)?; + entry.modified = Self::now_ms(); + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn copy(&self, from: &Path, to: &Path) -> Result<()> { + let from = normalize_path(from); + let to = normalize_path(to); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let from_js: JsValue = path_to_js(&from)?; + let js_value = store + .get(from_js) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + let mut entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + + entry.path = path_to_string(&to)?; + entry.created = Self::now_ms(); + entry.modified = Self::now_ms(); + + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn symlink(&self, target: &Path, link: &Path) -> Result<()> { + let target = normalize_path(target); + let link = normalize_path(link); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let now = Self::now_ms(); + let entry = DbEntry { + path: path_to_string(&link)?, + kind: DbEntryKind::Symlink, + content: None, + mode: 0o777, + modified: now, + created: now, + target: Some(path_to_string(&target)?), + size: 0, + }; + + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .add(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb add: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn read_link(&self, path: &Path) -> Result { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadOnly) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let js_value = store + .get(js_key) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + + let entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + + match entry.target { + Some(target) => Ok(PathBuf::from(target)), + None => Err(IoError::other("not a symlink").into()), + } + }) + .await + } + + async fn chmod(&self, path: &Path, mode: u32) -> Result<()> { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let js_value = store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + + let mut entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + entry.mode = mode; + + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .put(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb put: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } + + async fn set_modified_time(&self, path: &Path, time: SystemTime) -> Result<()> { + let path = normalize_path(path); + let db_name = self.db_name.clone(); + run(async move { + let db = IndexedDbFs::open_db(&db_name).await?; + let tx = db + .transaction(&[STORE_NAME], TransactionMode::ReadWrite) + .map_err(|e| IoError::other(format!("indexeddb tx: {e}")))?; + let store = tx + .store(STORE_NAME) + .map_err(|e| IoError::other(format!("indexeddb store: {e}")))?; + + let js_key: JsValue = path_to_js(&path)?; + let js_value = store + .get(js_key.clone()) + .await + .map_err(|e| IoError::other(format!("indexeddb get: {e}")))? + .ok_or_else(|| IoError::from(ErrorKind::NotFound))?; + + let mut entry: DbEntry = serde_wasm_bindgen::from_value(js_value) + .map_err(|e| IoError::other(format!("deserialize: {e}")))?; + entry.modified = Self::system_time_to_ms(time); + + let js_entry = serde_wasm_bindgen::to_value(&entry) + .map_err(|e| IoError::other(format!("serialize: {e}")))?; + store + .put(&js_entry, None) + .await + .map_err(|e| IoError::other(format!("indexeddb put: {e}")))?; + + tx.done() + .await + .map_err(|e| IoError::other(format!("indexeddb tx done: {e}")))?; + Ok(()) + }) + .await + } +} diff --git a/crates/bashkit/src/fs/memory.rs b/crates/bashkit/src/fs/memory.rs index 8ea11181f..ed9a32e97 100644 --- a/crates/bashkit/src/fs/memory.rs +++ b/crates/bashkit/src/fs/memory.rs @@ -38,12 +38,12 @@ // while holding lock). This is intentional - corrupted state should not propagate. #![allow(clippy::unwrap_used)] +use crate::time::SystemTime; use async_trait::async_trait; use std::collections::HashMap; use std::io::{Error as IoError, ErrorKind}; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; use super::limits::{FsLimits, FsUsage}; use super::traits::{DirEntry, FileSystem, FileSystemExt, FileType, Metadata}; diff --git a/crates/bashkit/src/fs/mod.rs b/crates/bashkit/src/fs/mod.rs index 6c6578855..901c5d90e 100644 --- a/crates/bashkit/src/fs/mod.rs +++ b/crates/bashkit/src/fs/mod.rs @@ -302,7 +302,7 @@ //! use std::path::{Path, PathBuf}; //! use std::collections::HashMap; //! use std::sync::RwLock; -//! use std::time::SystemTime; +//! use crate::time::SystemTime; //! //! /// A simple custom filesystem example //! pub struct SimpleFs { @@ -397,6 +397,8 @@ //! ``` mod backend; +#[cfg(feature = "indexeddb")] +mod indexeddb; mod limits; mod memory; mod mountable; @@ -409,6 +411,8 @@ mod search; mod traits; pub use backend::FsBackend; +#[cfg(feature = "indexeddb")] +pub use indexeddb::IndexedDbFs; pub use limits::{FsLimitExceeded, FsLimits, FsUsage}; pub use memory::{InMemoryFs, LazyLoader, VfsSnapshot}; pub use mountable::MountableFs; @@ -424,9 +428,9 @@ pub use search::{ pub use traits::{DirEntry, FileSystem, FileSystemExt, FileType, Metadata, fs_errors}; use crate::error::Result; +use crate::time::SystemTime; use std::io::{Error as IoError, ErrorKind}; use std::path::{Path, PathBuf}; -use std::time::SystemTime; /// Filesystem implementation for logic-only shells. /// diff --git a/crates/bashkit/src/fs/mountable.rs b/crates/bashkit/src/fs/mountable.rs index 4c1bd3c55..333b75d92 100644 --- a/crates/bashkit/src/fs/mountable.rs +++ b/crates/bashkit/src/fs/mountable.rs @@ -7,12 +7,12 @@ // while holding lock). This is intentional - corrupted state should not propagate. #![allow(clippy::unwrap_used)] +use crate::time::SystemTime; use async_trait::async_trait; use std::collections::BTreeMap; use std::io::Error as IoError; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; use super::limits::{FsLimits, FsUsage}; use super::traits::{DirEntry, FileSystem, FileSystemExt, FileType, Metadata}; @@ -420,8 +420,8 @@ impl FileSystem for MountableFs { file_type: FileType::Directory, size: 0, mode: 0o755, - modified: std::time::SystemTime::now(), - created: std::time::SystemTime::now(), + modified: crate::time::SystemTime::now(), + created: crate::time::SystemTime::now(), }, }); } diff --git a/crates/bashkit/src/fs/overlay.rs b/crates/bashkit/src/fs/overlay.rs index 4d6935b07..d59419322 100644 --- a/crates/bashkit/src/fs/overlay.rs +++ b/crates/bashkit/src/fs/overlay.rs @@ -31,12 +31,12 @@ // while holding lock). This is intentional - corrupted state should not propagate. #![allow(clippy::unwrap_used)] +use crate::time::SystemTime; use async_trait::async_trait; use std::collections::HashSet; use std::io::{Error as IoError, ErrorKind}; use std::path::{Path, PathBuf}; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; use super::limits::{FsLimits, FsUsage}; use super::memory::InMemoryFs; diff --git a/crates/bashkit/src/fs/posix.rs b/crates/bashkit/src/fs/posix.rs index 45c7c96a6..2967f4e4a 100644 --- a/crates/bashkit/src/fs/posix.rs +++ b/crates/bashkit/src/fs/posix.rs @@ -46,11 +46,11 @@ //! //! See [`FsBackend`](super::FsBackend) for how to implement a backend. +use crate::time::SystemTime; use async_trait::async_trait; use std::io::Error as IoError; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::SystemTime; use super::backend::FsBackend; use super::limits::{FsLimits, FsUsage}; diff --git a/crates/bashkit/src/fs/readonly.rs b/crates/bashkit/src/fs/readonly.rs index 4658344cd..02449fd83 100644 --- a/crates/bashkit/src/fs/readonly.rs +++ b/crates/bashkit/src/fs/readonly.rs @@ -4,11 +4,11 @@ //! embedder wants a session that can inspect data but cannot persist or stage //! any filesystem changes, including copies into the in-memory VFS. +use crate::time::SystemTime; use async_trait::async_trait; use std::io::{Error as IoError, ErrorKind}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::SystemTime; use super::limits::{FsLimits, FsUsage}; use super::traits::{DirEntry, FileSystem, FileSystemExt, Metadata}; diff --git a/crates/bashkit/src/fs/realfs.rs b/crates/bashkit/src/fs/realfs.rs index 66a4b3853..fd5f88c15 100644 --- a/crates/bashkit/src/fs/realfs.rs +++ b/crates/bashkit/src/fs/realfs.rs @@ -65,10 +65,10 @@ //! bashkit --mount-rw /path/to/out:/mnt/out -c 'echo hi > /mnt/out/result.txt' //! ``` +use crate::time::SystemTime; use async_trait::async_trait; use std::io::{Error as IoError, ErrorKind}; use std::path::{Path, PathBuf}; -use std::time::SystemTime; use super::backend::FsBackend; use super::limits::{FsLimits, FsUsage}; diff --git a/crates/bashkit/src/fs/traits.rs b/crates/bashkit/src/fs/traits.rs index 25275af2d..fcbd5f9eb 100644 --- a/crates/bashkit/src/fs/traits.rs +++ b/crates/bashkit/src/fs/traits.rs @@ -40,10 +40,10 @@ //! } //! ``` +use crate::time::SystemTime; use async_trait::async_trait; use std::io::{Error as IoError, ErrorKind}; use std::path::Path; -use std::time::SystemTime; use super::limits::{FsLimits, FsUsage}; use crate::error::Result; diff --git a/crates/bashkit/src/interop/fs.rs b/crates/bashkit/src/interop/fs.rs index 241447e39..12e7f16c6 100644 --- a/crates/bashkit/src/interop/fs.rs +++ b/crates/bashkit/src/interop/fs.rs @@ -1,6 +1,7 @@ // Decision: cross-addon filesystem interop uses only a versioned repr(C) // handle + vtable. Rust trait objects stay inside the exporting addon. +use crate::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::{ DirEntry, Error as BashError, FileSystem, FileSystemExt, FileType, Metadata, Result as BashResult, async_trait, @@ -14,7 +15,6 @@ use std::slice; use std::str; use std::sync::Arc; use std::sync::mpsc::sync_channel; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::runtime::{Builder, Runtime}; pub const BASHKIT_FS_ABI_VERSION_V1: u32 = 1; diff --git a/crates/bashkit/src/interpreter/mod.rs b/crates/bashkit/src/interpreter/mod.rs index 9ae0c7d27..5624bd5a6 100644 --- a/crates/bashkit/src/interpreter/mod.rs +++ b/crates/bashkit/src/interpreter/mod.rs @@ -3708,7 +3708,7 @@ impl Interpreter { /// User and system CPU time are always reported as 0. /// This is a documented incompatibility with bash. async fn execute_time(&mut self, time_cmd: &TimeCommand) -> Result { - use std::time::Instant; + use crate::time::Instant; let start = Instant::now(); @@ -5307,7 +5307,7 @@ impl Interpreter { let trace_start = if self.trace.mode() != crate::trace::TraceMode::Off { self.trace .command_start(name, &args, self.cwd.to_string_lossy().as_ref()); - Some(std::time::Instant::now()) + Some(crate::time::Instant::now()) } else { None }; @@ -13745,7 +13745,7 @@ mod tests { #[tokio::test] async fn test_extglob_no_hang() { - use std::time::{Duration, Instant}; + use crate::time::{Duration, Instant}; let start = Instant::now(); let result = run_script( r#"shopt -s extglob; [[ "aaaaaaaaaaaa" == +(a|aa) ]] && echo yes || echo no"#, diff --git a/crates/bashkit/src/lib.rs b/crates/bashkit/src/lib.rs index d19a6fe90..ef3bbf8e1 100644 --- a/crates/bashkit/src/lib.rs +++ b/crates/bashkit/src/lib.rs @@ -428,6 +428,7 @@ mod snapshot; /// invariants enforced (TM-INF-013, TM-INF-016, TM-INF-022). #[doc(hidden)] pub mod testing; +pub mod time; /// Tool contract for LLM integration pub mod tool; /// Reusable tool primitives: ToolDef, ToolArgs, ToolImpl, exec types. @@ -447,6 +448,8 @@ pub use clap; #[cfg(feature = "http_client")] pub use credential::Credential; pub use error::{Error, Result}; +#[cfg(feature = "indexeddb")] +pub use fs::IndexedDbFs; pub use fs::{ DirEntry, FileSystem, FileSystemExt, FileType, FsBackend, FsLimitExceeded, FsLimits, FsUsage, InMemoryFs, LazyLoader, Metadata, MountableFs, OverlayFs, PosixFs, ReadOnlyFs, @@ -484,7 +487,7 @@ pub use scripted_tool::{ pub use tool_def::{AsyncToolExec, SyncToolExec, ToolImpl}; #[cfg(feature = "http_client")] -pub use network::{HttpClient, HttpHandler}; +pub use network::{HttpClient, HttpHandler, Method}; /// Re-exported network response type for custom HTTP handler implementations. #[cfg(feature = "http_client")] @@ -853,7 +856,7 @@ impl Bash { // Load persisted history on first exec (no-op if already loaded) self.interpreter.load_history().await; - let exec_start = std::time::Instant::now(); + let exec_start = crate::time::Instant::now(); // THREAT[TM-DOS-057]: Wrap execution with timeout to prevent sleep/blocking bypass. // Only the native path arms the tokio timeout; wasm has no reliable timer driver. #[cfg(not(target_family = "wasm"))] diff --git a/crates/bashkit/src/network/bot_auth.rs b/crates/bashkit/src/network/bot_auth.rs index 8cee8ce00..9ebfc07b5 100644 --- a/crates/bashkit/src/network/bot_auth.rs +++ b/crates/bashkit/src/network/bot_auth.rs @@ -18,11 +18,11 @@ //! .with_validity_secs(300); //! ``` +use crate::time::{SystemTime, UNIX_EPOCH}; use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; use ed25519_dalek::{Signer, SigningKey, VerifyingKey}; use rand::Rng; use sha2::{Digest, Sha256}; -use std::time::{SystemTime, UNIX_EPOCH}; use zeroize::Zeroize; /// Configuration for Web Bot Authentication. diff --git a/crates/bashkit/src/network/client.rs b/crates/bashkit/src/network/client.rs index ed5bc2b70..52eabafba 100644 --- a/crates/bashkit/src/network/client.rs +++ b/crates/bashkit/src/network/client.rs @@ -3,76 +3,51 @@ //! Provides a virtual HTTP client that respects the allowlist with //! security mitigations for common HTTP attacks. //! +//! # Platform Support +//! +//! - **Native** (default): Uses `reqwest` with a private-IP-filtering DNS +//! resolver, streaming response bodies, and per-request timeout clients. +//! See `client::native` for the native transport implementation. +//! - **WASM** (`target_family = "wasm")`: Uses the browser's `fetch` API +//! via `web_sys` and `wasm_bindgen_futures`. DNS checks are limited to +//! literal-IP blocking because the browser does not expose raw socket or +//! DNS APIs. Timeouts are enforced with `AbortController`. +//! See `client::wasm` for the WASM transport implementation. +//! //! # Security Mitigations //! //! This module mitigates the following threats (see `specs/threat-model.md`): //! //! - **TM-NET-008**: Large response DoS → `max_response_bytes` limit (10MB default) -//! - **TM-NET-009**: Connection hang → connect timeout (10s) +//! - **TM-NET-009**: Connection hang → connect timeout (10s) *(native only)* //! - **TM-NET-010**: Slowloris attack → read timeout (30s) -//! - **TM-NET-011**: Redirect bypass → `Policy::none()` disables auto-redirect -//! - **TM-NET-012**: Chunked encoding bomb → streaming size check -//! - **TM-NET-013**: Gzip/compression bomb → auto-decompression disabled +//! - **TM-NET-011**: Redirect bypass → no auto-redirect on native; browser CORS +//! and same-origin policy provide additional defense on WASM. +//! - **TM-NET-012**: Chunked encoding bomb → streaming size check *(native)* / +//! `Content-Length` pre-check + array-buffer limit *(WASM)* +//! - **TM-NET-013**: Gzip/compression bomb → auto-decompression disabled *(native)* //! - **TM-NET-014**: DNS rebind via redirect → manual redirect requires allowlist check -//! - **TM-NET-015**: Host proxy leakage → `.no_proxy()` ignores host `HTTP_PROXY`/`HTTPS_PROXY` +//! - **TM-NET-015**: Host proxy leakage → `.no_proxy()` ignores host `HTTP_PROXY`/`HTTPS_PROXY` *(native)* //! - **TM-NET-002 (TOCTOU)**: DNS rebinding between pre-resolve check and actual connect → -//! private-IP filtering installed as reqwest's DNS resolver, so the connection path itself -//! refuses to dial any private/reserved IP, even if DNS answers diverge between checks. +//! private-IP filtering installed as reqwest's DNS resolver on native. WASM relies on +//! the browser's same-origin policy and the allowlist pre-check. -use reqwest::Client; -use reqwest::dns::{Name, Resolve, Resolving}; -use std::net::SocketAddr; -use std::sync::{Arc, OnceLock}; +#[cfg(not(target_family = "wasm"))] +use std::sync::OnceLock; use std::time::Duration; use super::allowlist::{NetworkAllowlist, UrlMatch, is_private_ip}; use crate::error::{Error, Result}; -/// THREAT[TM-NET-002 TOCTOU]: DNS resolver wrapper that rejects any -/// hostname whose addresses include a private/reserved IP at connect time. -/// -/// The pre-resolve check in `enforce_url_security` cannot bind the validated -/// IP to the actual connection because reqwest re-resolves the hostname when -/// `send()` runs. An attacker controlling DNS for an allowed hostname can -/// answer with a public IP during the check and then a private/internal IP -/// during the connect ("DNS rebinding"). Installing this resolver in the -/// reqwest client moves the policy onto the connection path, so the connect -/// itself refuses to dial private addresses regardless of pre-check timing. -struct PrivateIpFilteringResolver; - -impl Resolve for PrivateIpFilteringResolver { - fn resolve(&self, name: Name) -> Resolving { - Box::pin(async move { - let host = name.as_str().to_string(); - // Use port 0 in the lookup; reqwest documents that explicit URL - // ports override the resolved port, and otherwise scheme-default - // ports are substituted. Port 0 is a valid placeholder. - let lookup_target = format!("{}:0", host); - let resolved = tokio::net::lookup_host(lookup_target.as_str()) - .await - .map_err(|e| Box::new(e) as Box)?; - - let addrs: Vec = resolved.collect(); - let mut filtered: Vec = Vec::with_capacity(addrs.len()); - for addr in addrs { - if !is_private_ip(&addr.ip()) { - filtered.push(addr); - } - } +#[cfg(not(target_family = "wasm"))] +mod native; - if filtered.is_empty() { - let msg = format!( - "access denied: '{}' resolves only to private/reserved IPs (SSRF protection)", - host - ); - return Err(msg.into()); - } +#[cfg(target_family = "wasm")] +mod wasm; - let iter: Box + Send> = Box::new(filtered.into_iter()); - Ok(iter) - }) - } -} +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- /// Default maximum response body size (10 MB) pub const DEFAULT_MAX_RESPONSE_BYTES: usize = 10 * 1024 * 1024; @@ -86,6 +61,10 @@ pub const MAX_TIMEOUT_SECS: u64 = 600; /// Minimum allowed timeout (1 second) - prevents instant timeouts that waste resources pub const MIN_TIMEOUT_SECS: u64 = 1; +// --------------------------------------------------------------------------- +// HttpHandler trait +// --------------------------------------------------------------------------- + /// Trait for custom HTTP request handling. /// /// Embedders can implement this trait to intercept, proxy, log, cache, @@ -98,14 +77,12 @@ pub const MIN_TIMEOUT_SECS: u64 = 1; /// /// # Default /// -/// When no custom handler is set, `HttpClient` uses `reqwest` directly, -/// with a private-IP-filtering DNS resolver installed on the connector -/// that rejects private IPs at connect time. This catches DNS rebinding -/// that happens between the precheck and the actual TCP connect. +/// When no custom handler is set, `HttpClient` uses the platform +/// default (`reqwest` on native, `fetch` on WASM). /// /// # SSRF responsibility for handlers (TM-NET-023, #1570) /// -/// **Custom HTTP handlers DO NOT inherit reqwest's connect-time IP +/// **Custom HTTP handlers DO NOT inherit the platform's connect-time IP /// filter.** The DNS precheck bashkit runs is best-effort and is /// vulnerable to a rebind window between the precheck and the moment /// the handler opens its own socket. If a handler performs real network @@ -114,7 +91,7 @@ pub const MIN_TIMEOUT_SECS: u64 = 1; /// connecting, or constrain its egress at a lower layer. The internal /// classifier `bashkit::network::allowlist::is_private_ip` (re-exported /// at `bashkit::network::is_private_ip` when used from inside this -/// crate) is the same one the default reqwest path uses. Handlers that +/// crate) is the same one the default native path uses. Handlers that /// only consult fixtures or in-memory state (mocks, test doubles) have /// no exposure here. #[async_trait::async_trait] @@ -134,31 +111,37 @@ pub trait HttpHandler: Send + Sync { ) -> std::result::Result; } -/// HTTP client with allowlist-based access control. -/// -/// # Security Features -/// -/// - URL allowlist enforcement -/// - Response size limits to prevent memory exhaustion -/// - Configurable timeouts to prevent hanging -/// - No automatic redirect following (to prevent allowlist bypass) -pub struct HttpClient { - client: OnceLock>, - allowlist: NetworkAllowlist, - default_timeout: Duration, - /// Maximum response body size in bytes - max_response_bytes: usize, - /// Optional custom HTTP handler for request interception - handler: Option>, - /// Optional bot-auth config for transparent request signing - #[cfg(feature = "bot-auth")] - bot_auth: Option, - /// Interceptor hooks fired before each HTTP request - before_http: Vec>, - /// Interceptor hooks fired after each HTTP response - after_http: Vec>, +// --------------------------------------------------------------------------- +// Response +// --------------------------------------------------------------------------- + +/// HTTP response +#[derive(Debug)] +pub struct Response { + /// HTTP status code + pub status: u16, + /// Response headers (key-value pairs) + pub headers: Vec<(String, String)>, + /// Response body + pub body: Vec, +} + +impl Response { + /// Get the body as a UTF-8 string (lossy) + pub fn body_string(&self) -> String { + String::from_utf8_lossy(&self.body).into_owned() + } + + /// Check if the response was successful (2xx status) + pub fn is_success(&self) -> bool { + (200..300).contains(&self.status) + } } +// --------------------------------------------------------------------------- +// Method +// --------------------------------------------------------------------------- + /// HTTP request method #[derive(Debug, Clone, Copy, PartialEq)] pub enum Method { @@ -171,18 +154,7 @@ pub enum Method { } impl Method { - fn as_reqwest(self) -> reqwest::Method { - match self { - Method::Get => reqwest::Method::GET, - Method::Post => reqwest::Method::POST, - Method::Put => reqwest::Method::PUT, - Method::Delete => reqwest::Method::DELETE, - Method::Head => reqwest::Method::HEAD, - Method::Patch => reqwest::Method::PATCH, - } - } - - fn as_str(self) -> &'static str { + pub(crate) fn as_str(self) -> &'static str { match self { Method::Get => "GET", Method::Post => "POST", @@ -194,27 +166,36 @@ impl Method { } } -/// HTTP response -#[derive(Debug)] -pub struct Response { - /// HTTP status code - pub status: u16, - /// Response headers (key-value pairs) - pub headers: Vec<(String, String)>, - /// Response body - pub body: Vec, -} - -impl Response { - /// Get the body as a UTF-8 string (lossy) - pub fn body_string(&self) -> String { - String::from_utf8_lossy(&self.body).into_owned() - } +// --------------------------------------------------------------------------- +// HttpClient +// --------------------------------------------------------------------------- - /// Check if the response was successful (2xx status) - pub fn is_success(&self) -> bool { - (200..300).contains(&self.status) - } +/// HTTP client with allowlist-based access control. +/// +/// # Security Features +/// +/// - URL allowlist enforcement +/// - Response size limits to prevent memory exhaustion +/// - Configurable timeouts to prevent hanging +/// - No automatic redirect following (to prevent allowlist bypass) +pub struct HttpClient { + #[cfg(not(target_family = "wasm"))] + client: OnceLock>, + #[cfg(target_family = "wasm")] + _wasm_marker: std::marker::PhantomData<()>, + allowlist: NetworkAllowlist, + default_timeout: Duration, + /// Maximum response body size in bytes + max_response_bytes: usize, + /// Optional custom HTTP handler for request interception + handler: Option>, + /// Optional bot-auth config for transparent request signing + #[cfg(feature = "bot-auth")] + bot_auth: Option, + /// Interceptor hooks fired before each HTTP request + before_http: Vec>, + /// Interceptor hooks fired after each HTTP response + after_http: Vec>, } impl HttpClient { @@ -250,7 +231,10 @@ impl HttpClient { max_response_bytes: usize, ) -> Self { Self { + #[cfg(not(target_family = "wasm"))] client: OnceLock::new(), + #[cfg(target_family = "wasm")] + _wasm_marker: std::marker::PhantomData, allowlist, default_timeout: timeout, max_response_bytes, @@ -265,7 +249,7 @@ impl HttpClient { /// Set a custom HTTP handler for request interception. /// /// The handler is called after the URL allowlist check, so the security - /// boundary stays in bashkit. The default reqwest-based handler is used + /// boundary stays in bashkit. The default platform handler is used /// when no custom handler is set. pub fn set_handler(&mut self, handler: Box) { self.handler = Some(handler); @@ -364,16 +348,6 @@ impl HttpClient { } } - fn client(&self) -> Result<&Client> { - let block_private = self.allowlist.is_blocking_private_ips(); - let client = self - .client - .get_or_init(|| build_client(self.default_timeout, None, block_private)); - client - .as_ref() - .map_err(|err| Error::Internal(format!("failed to build HTTP client: {err}"))) - } - /// Make a GET request. pub async fn get(&self, url: &str) -> Result { self.request(Method::Get, url, None).await @@ -429,17 +403,9 @@ impl HttpClient { /// THREAT[TM-NET-002/004/023]: Pre-resolve DNS and block private IPs. /// - /// Returns `Err` for malformed URLs and for URLs with no host - /// component — the previous fail-open behaviour let those slip - /// through to the connect path. DNS lookup *errors* still - /// short-circuit to `Ok(())` (fail-open) because failing closed - /// here breaks any caller that intentionally targets an unresolved - /// hostname before a `before_http` hook rewrites or cancels the - /// request. The primary mitigation for the rebind / fail-open - /// window is the trait-level requirement on `HttpHandler` (see - /// #1570) and the connect-time `PrivateIpFilteringResolver` on the - /// default reqwest path. Direct-IP and successful-resolution paths - /// remain fail-closed. + /// On native this performs a full DNS lookup and blocks private resolves. + /// On WASM only literal IPs are checked; hostname resolution is deferred + /// to the browser, which applies same-origin policy and CORS. pub(crate) async fn check_private_ip(&self, url: &str) -> Result<()> { let parsed = url::Url::parse(url) .map_err(|e| Error::Network(format!("invalid URL for SSRF precheck: {e}")))?; @@ -448,7 +414,12 @@ impl HttpClient { "access denied: URL has no host (SSRF protection)".to_string(), )); }; - if let Ok(ip) = host.parse::() { + // Strip brackets from IPv6 literals so they parse correctly. + let ip_str = host + .strip_prefix('[') + .and_then(|h| h.strip_suffix(']')) + .unwrap_or(host); + if let Ok(ip) = ip_str.parse::() { if is_private_ip(&ip) { return Err(Error::Network(format!( "access denied: {} is a private IP (SSRF protection)", @@ -457,24 +428,30 @@ impl HttpClient { } return Ok(()); } - let port = parsed - .port() - .unwrap_or(if parsed.scheme() == "https" { 443 } else { 80 }); - let addr = format!("{}:{}", host, port); - let Ok(addrs) = tokio::net::lookup_host(&addr).await else { - // DNS lookup failed — fall through. See the function-level - // doc for why this stays fail-open. - return Ok(()); - }; - for a in addrs { - if is_private_ip(&a.ip()) { - return Err(Error::Network(format!( - "access denied: {} resolves to private IP {} (SSRF protection)", - host, - a.ip() - ))); + + // Native: perform DNS precheck. + #[cfg(not(target_family = "wasm"))] + { + let port = parsed + .port() + .unwrap_or(if parsed.scheme() == "https" { 443 } else { 80 }); + let addr = format!("{}:{}", host, port); + let Ok(addrs) = tokio::net::lookup_host(&addr).await else { + // DNS lookup failed — fall through. See the function-level + // doc for why this stays fail-open. + return Ok(()); + }; + for a in addrs { + if is_private_ip(&a.ip()) { + return Err(Error::Network(format!( + "access denied: {} resolves to private IP {} (SSRF protection)", + host, + a.ip() + ))); + } } } + Ok(()) } @@ -534,6 +511,7 @@ impl HttpClient { let method_str = method.as_str(); let mut all_headers: Vec<(String, String)> = headers.to_vec(); all_headers.extend(signing_headers); + #[cfg(not(target_family = "wasm"))] let response = tokio::time::timeout( self.default_timeout, handler.request(method_str, url, body, &all_headers), @@ -541,6 +519,11 @@ impl HttpClient { .await .map_err(|_| Error::Network("operation timed out".to_string()))? .map_err(Error::Network)?; + #[cfg(target_family = "wasm")] + let response = handler + .request(method_str, url, body, &all_headers) + .await + .map_err(Error::Network)?; if response.body.len() > self.max_response_bytes { return Err(Error::Network(format!( "response too large: {} bytes (max: {} bytes)", @@ -556,90 +539,16 @@ impl HttpClient { return Ok(response); } - // Build request - let mut request = self.client()?.request(method.as_reqwest(), url); - - // Add custom headers - for (name, value) in headers { - request = request.header(name.as_str(), value.as_str()); - } - - // Add bot-auth signing headers - for (name, value) in &signing_headers { - request = request.header(name.as_str(), value.as_str()); - } - - if let Some(body_data) = body { - request = request.body(body_data.to_vec()); - } - - // Send request - let response = request - .send() - .await - .map_err(|e| Error::network_sanitized("request failed", &e))?; - - // Extract response data - let status = response.status().as_u16(); - let resp_headers: Vec<(String, String)> = response - .headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect(); - - // Fire after_http hooks - self.fire_after_http(crate::hooks::HttpResponseEvent { - url: url.to_string(), - status, - headers: resp_headers.clone(), - }); - - // Check Content-Length header to fail fast on large responses - if let Some(content_length) = response.content_length() - && usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes - { - return Err(Error::Network(format!( - "response too large: {} bytes (max: {} bytes)", - content_length, self.max_response_bytes - ))); - } - - // Read body with size limit enforcement - // We stream the response to avoid loading huge responses into memory - let body = self.read_body_with_limit(response).await?; - - Ok(Response { - status, - headers: resp_headers, + self.send_request( + method, + url, body, - }) - } - - /// Read response body with size limit enforcement. - /// - /// This streams the response to avoid allocating memory for oversized responses. - async fn read_body_with_limit(&self, response: reqwest::Response) -> Result> { - use futures_util::StreamExt; - - let mut body = Vec::new(); - let mut stream = response.bytes_stream(); - - while let Some(chunk_result) = stream.next().await { - let chunk = chunk_result - .map_err(|e| Error::network_sanitized("failed to read response chunk", &e))?; - - // Check if adding this chunk would exceed the limit - if body.len() + chunk.len() > self.max_response_bytes { - return Err(Error::Network(format!( - "response too large: exceeded {} bytes limit", - self.max_response_bytes - ))); - } - - body.extend_from_slice(&chunk); - } - - Ok(body) + headers, + signing_headers, + self.default_timeout, + None, + ) + .await } /// Make a HEAD request to get headers without body. @@ -756,6 +665,7 @@ impl HttpClient { let method_str = method.as_str(); let mut all_headers: Vec<(String, String)> = headers.to_vec(); all_headers.extend(signing_headers); + #[cfg(not(target_family = "wasm"))] let response = tokio::time::timeout( request_timeout, handler.request(method_str, url, body, &all_headers), @@ -763,6 +673,11 @@ impl HttpClient { .await .map_err(|_| Error::Network("operation timed out".to_string()))? .map_err(Error::Network)?; + #[cfg(target_family = "wasm")] + let response = handler + .request(method_str, url, body, &all_headers) + .await + .map_err(Error::Network)?; if response.body.len() > self.max_response_bytes { return Err(Error::Network(format!( "response too large: {} bytes (max: {} bytes)", @@ -778,144 +693,62 @@ impl HttpClient { return Ok(response); } - // Use the custom timeout client if any timeout is specified, otherwise use default client - let client = if timeout_secs.is_some() || connect_timeout_secs.is_some() { - // Connect timeout: use explicit connect_timeout, or derive from overall timeout, or use default 10s - let connect_timeout = connect_timeout_secs.map_or_else( - || std::cmp::min(request_timeout, Duration::from_secs(10)), - |s| Duration::from_secs(clamp_timeout(s)), - ); - build_client( - request_timeout, - Some(connect_timeout), - self.allowlist.is_blocking_private_ips(), - ) - .map_err(|e| Error::network_sanitized("failed to create client", &e))? - } else { - self.client()?.clone() - }; - - // Build request - let mut request = client.request(method.as_reqwest(), url); - - // Add custom headers - for (name, value) in headers { - request = request.header(name.as_str(), value.as_str()); - } - - // Add bot-auth signing headers - for (name, value) in &signing_headers { - request = request.header(name.as_str(), value.as_str()); - } - - if let Some(body_data) = body { - request = request.body(body_data.to_vec()); - } + self.send_request( + method, + url, + body, + headers, + signing_headers, + request_timeout, + connect_timeout_secs.map(|s| Duration::from_secs(clamp_timeout(s))), + ) + .await + } - // Send request - let response = request.send().await.map_err(|e| { - // Check if this was a timeout error - if e.is_timeout() { - Error::Network("operation timed out".to_string()) - } else { - Error::network_sanitized("request failed", &e) - } - })?; + // ----------------------------------------------------------------------- + // Platform-specific transport + // ----------------------------------------------------------------------- - // Extract response data - let status = response.status().as_u16(); - let resp_headers: Vec<(String, String)> = response - .headers() - .iter() - .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) - .collect(); + #[cfg(target_family = "wasm")] + pub(crate) async fn send_request( + &self, + method: Method, + url: &str, + body: Option<&[u8]>, + headers: &[(String, String)], + signing_headers: Vec<(String, String)>, + timeout: Duration, + _connect_timeout: Option, + ) -> Result { + let response = wasm::send_request( + self.max_response_bytes, + method, + url, + body, + headers, + signing_headers, + timeout, + ) + .await?; // Fire after_http hooks self.fire_after_http(crate::hooks::HttpResponseEvent { url: url.to_string(), - status, - headers: resp_headers.clone(), + status: response.status, + headers: response.headers.clone(), }); - // Check Content-Length header to fail fast on large responses - if let Some(content_length) = response.content_length() - && usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes - { - return Err(Error::Network(format!( - "response too large: {} bytes (max: {} bytes)", - content_length, self.max_response_bytes - ))); - } - - // Read body with size limit enforcement - let body = self.read_body_with_limit(response).await?; - - Ok(Response { - status, - headers: resp_headers, - body, - }) + Ok(response) } } -/// Install the rustls `ring` crypto provider as the process-wide default. -/// -/// We pair reqwest's `rustls-no-provider` feature with an explicit `ring` -/// install so the dep tree contains zero C-compiled crypto (no aws-lc-sys). -/// That keeps cross-compiled wheel builds (notably aarch64 manylinux, where -/// the cross sysroot is missing `AT_HWCAP2`) green and removes a class of -/// toolchain-specific build failures. -/// -/// Idempotent: safe to call from multiple call sites and across crates. -/// `install_default` errors if a provider is already installed (e.g. set by -/// the embedder); we treat that as success because *some* provider is now -/// active, which is all rustls needs. -fn install_default_crypto_provider() { - use std::sync::Once; - static INIT: Once = Once::new(); - INIT.call_once(|| { - let _ = rustls::crypto::ring::default_provider().install_default(); - }); -} - -fn build_client( - timeout: Duration, - connect_timeout: Option, - block_private_ips: bool, -) -> std::result::Result { - install_default_crypto_provider(); - let mut builder = Client::builder() - .timeout(timeout) - .connect_timeout(connect_timeout.unwrap_or(Duration::from_secs(10))) - .user_agent("bashkit/0.1.2") - // Disable automatic redirects to prevent allowlist bypass via redirect - // Scripts can follow redirects manually if needed - .redirect(reqwest::redirect::Policy::none()) - // Disable automatic decompression to prevent zip bomb attacks - // and match real curl behavior (which requires --compressed flag) - // With decompression enabled, a 10KB gzip could expand to 10GB - .no_gzip() - .no_brotli() - .no_deflate() - // THREAT[TM-NET-015]: Ignore host proxy env vars (HTTP_PROXY, HTTPS_PROXY, ALL_PROXY) - // to prevent sandboxed HTTP traffic from being redirected through a host proxy - .no_proxy(); - - // THREAT[TM-NET-002 TOCTOU]: install a DNS resolver that filters private IPs - // at connect time, so DNS rebinding cannot slip a private address past the - // pre-resolve check. - if block_private_ips { - builder = builder.dns_resolver(Arc::new(PrivateIpFilteringResolver)); - } - - builder.build().map_err(|e| e.to_string()) -} +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; - use std::time::Duration as StdDuration; - use tokio::time::sleep; struct StaticHandler { response: Response, @@ -938,49 +771,19 @@ mod tests { } } - struct SlowHandler { - delay: StdDuration, - } - - #[async_trait::async_trait] - impl HttpHandler for SlowHandler { - async fn request( - &self, - _method: &str, - _url: &str, - _body: Option<&[u8]>, - _headers: &[(String, String)], - ) -> std::result::Result { - sleep(self.delay).await; - Ok(Response { - status: 200, - headers: vec![], - body: b"ok".to_vec(), - }) - } - } - #[tokio::test] async fn test_blocked_by_empty_allowlist() { let client = HttpClient::new(NetworkAllowlist::new()); + #[cfg(not(target_family = "wasm"))] assert!(client.client.get().is_none()); let result = client.get("https://example.com").await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("access denied")); + #[cfg(not(target_family = "wasm"))] assert!(client.client.get().is_none()); } - #[test] - fn test_default_client_initializes_on_first_use() { - let client = HttpClient::new(NetworkAllowlist::allow_all()); - assert!(client.client.get().is_none()); - - client.client().expect("client"); - - assert!(client.client.get().is_some()); - } - #[tokio::test] async fn test_blocked_by_allowlist() { let allowlist = NetworkAllowlist::new().allow("https://allowed.com"); @@ -1007,12 +810,9 @@ mod tests { let allowlist = NetworkAllowlist::new().allow("https://blocked.com"); let client = HttpClient::new(allowlist); - // Should use default client (not blocked by allowlist here, but blocked.com not actually accessible) - // This just verifies the code path with None timeout works let result = client .request_with_timeout(Method::Get, "https://blocked.example.com", None, &[], None) .await; - // Should fail with access denied (not in allowlist) assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("access denied")); } @@ -1022,7 +822,6 @@ mod tests { let allowlist = NetworkAllowlist::new().allow("https://allowed.com"); let client = HttpClient::new(allowlist); - // Test with invalid URL let result = client .request_with_timeout(Method::Get, "not-a-url", None, &[], Some(10)) .await; @@ -1033,7 +832,6 @@ mod tests { async fn test_request_with_timeouts_both_params() { let client = HttpClient::new(NetworkAllowlist::new()); - // Both timeouts specified - should still check allowlist first let result = client .request_with_timeouts( Method::Get, @@ -1052,7 +850,6 @@ mod tests { async fn test_request_with_timeouts_connect_only() { let client = HttpClient::new(NetworkAllowlist::new()); - // Only connect timeout specified let result = client .request_with_timeouts(Method::Get, "https://example.com", None, &[], None, Some(5)) .await; @@ -1062,125 +859,13 @@ mod tests { #[test] fn test_u64_to_usize_no_truncation() { - // On 64-bit: fits fine. On 32-bit: saturates to usize::MAX rather than truncating. - let large: u64 = 5_368_709_120; // 5GB + let large: u64 = 5_368_709_120; let result = usize::try_from(large).unwrap_or(usize::MAX); - // Should never silently become a smaller value assert!(result >= large.min(usize::MAX as u64) as usize); } - #[test] - fn test_build_client_uses_no_proxy() { - // Verify build_client succeeds — the .no_proxy() call ensures - // host HTTP_PROXY/HTTPS_PROXY env vars are ignored (TM-NET-015). - let client = build_client(Duration::from_secs(30), None, true); - assert!(client.is_ok(), "build_client should succeed with no_proxy"); - } - - #[test] - fn test_build_client_installs_ring_crypto_provider() { - // Regression: with reqwest's `rustls-no-provider` feature, rustls panics - // on first TLS handshake unless a default crypto provider is installed. - // build_client must install the ring provider via the `Once` guard so - // every code path (default client + per-request timeout client) is safe. - // The dep tree must NOT include aws-lc-sys/aws-lc-rs (verified by - // `cargo tree -i aws-lc-sys` returning no match). - let _ = build_client(Duration::from_secs(30), None, true); - // A provider is now installed process-wide. `install_default` returns - // Err on the second call — that's our invariant: the first install - // succeeded. - let second_install = rustls::crypto::ring::default_provider().install_default(); - assert!( - second_install.is_err(), - "build_client must install a default crypto provider before \ - returning, otherwise the first HTTPS request panics" - ); - } - - #[test] - fn test_install_default_crypto_provider_is_idempotent() { - // Multiple invocations must not panic; the `Once` guard ensures only - // the first call attempts an install. - install_default_crypto_provider(); - install_default_crypto_provider(); - install_default_crypto_provider(); - } - - #[tokio::test] - async fn test_private_ip_filtering_resolver_rejects_loopback() { - // THREAT[TM-NET-002]: regression for DNS-rebinding TOCTOU. The pre-resolve - // check is best-effort and is now backed by a connection-time resolver - // that refuses to dial private/reserved IPs even when DNS answers - // change between the security check and `send()`. - // - // `localhost` always resolves to a loopback address (127.0.0.1 / ::1). - // The filter must reject it, proving the policy is enforced on the path - // reqwest actually uses to connect. - let resolver = PrivateIpFilteringResolver; - let name: Name = "localhost".parse().expect("valid DNS name"); - let result = resolver.resolve(name).await; - assert!( - result.is_err(), - "localhost must be rejected by the private-IP-filtering resolver" - ); - let err = result.err().unwrap().to_string(); - assert!( - err.contains("private/reserved"), - "error must mention SSRF protection, got: {err}" - ); - } - - #[tokio::test] - async fn test_private_ip_filtering_resolver_filters_private_from_mixed() { - // If a hostname resolved to a mix of public and private IPs, only the - // public addresses must reach reqwest's connection logic. Simulate by - // resolving a public-DNS name we don't actually depend on for the - // network test — we just verify the filtering logic in isolation. - // - // We construct synthetic addresses to drive the filter directly, - // because relying on third-party DNS in unit tests is flaky. - use std::net::{IpAddr, Ipv4Addr}; - let public: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 0); - let private: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 0); - let metadata: SocketAddr = - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)), 0); - let mixed = vec![public, private, metadata]; - let kept: Vec = mixed - .into_iter() - .filter(|a| !is_private_ip(&a.ip())) - .collect(); - assert_eq!(kept, vec![public]); - } - - #[tokio::test] - async fn test_default_client_rejects_loopback_via_resolver() { - // End-to-end regression: build the same reqwest client the runtime uses - // (private-IP filtering enabled) and try to dial a hostname that - // resolves only to loopback. The resolver must short-circuit the - // connection attempt with an SSRF-style error rather than dialing. - let allowlist = NetworkAllowlist::new().allow("http://localhost"); - let client = HttpClient::new(allowlist); - let result = client.get("http://localhost").await; - assert!( - result.is_err(), - "request to a loopback hostname must be refused" - ); - let msg = result.err().unwrap().to_string(); - assert!( - msg.contains("private") - || msg.contains("SSRF") - || msg.contains("reserved") - || msg.contains("access denied"), - "expected SSRF-protection error, got: {msg}" - ); - } - #[tokio::test] async fn test_check_private_ip_fails_closed_on_invalid_url() { - // Regression for #1570 (TM-NET-023): malformed URLs previously - // returned Ok(()) from the precheck. The fail-closed contract is - // exercised directly against `check_private_ip` to avoid relying - // on the allowlist's parser short-circuiting first. let client = HttpClient::new(NetworkAllowlist::allow_all()); let result = client.check_private_ip("definitely::not::a::url").await; assert!(result.is_err(), "malformed URL must trip the SSRF precheck"); @@ -1193,8 +878,6 @@ mod tests { #[tokio::test] async fn test_check_private_ip_fails_closed_on_no_host() { - // Regression for #1570: URLs without a host component used to slip - // through. Now they are rejected. let client = HttpClient::new(NetworkAllowlist::allow_all()); let result = client.check_private_ip("file:///etc/passwd").await; assert!(result.is_err(), "host-less URL must trip the precheck"); @@ -1207,8 +890,6 @@ mod tests { #[tokio::test] async fn test_check_private_ip_blocks_literal_private_ip() { - // Direct IP form: no DNS, deterministic — the existing direct-IP - // branch must still reject 10.0.0.1. let client = HttpClient::new(NetworkAllowlist::allow_all()); let result = client.check_private_ip("http://10.0.0.1/").await; assert!(result.is_err()); @@ -1221,8 +902,6 @@ mod tests { #[tokio::test] async fn test_check_private_ip_blocks_metadata_via_v4_mapped_v6() { - // Belt-and-braces with TM-NET-022 (#1569): IPv4-mapped IPv6 form of - // AWS metadata must also fail closed. let client = HttpClient::new(NetworkAllowlist::allow_all()); let result = client .check_private_ip("http://[::ffff:169.254.169.254]/") @@ -1303,29 +982,6 @@ mod tests { assert!(result.unwrap_err().to_string().contains("access denied")); } - #[tokio::test] - async fn test_custom_handler_enforces_request_timeout() { - let mut client = HttpClient::with_config( - NetworkAllowlist::allow_all(), - Duration::from_secs(30), - DEFAULT_MAX_RESPONSE_BYTES, - ); - client.set_handler(Box::new(SlowHandler { - delay: StdDuration::from_millis(1200), - })); - - let result = client - .request_with_timeouts(Method::Get, "https://example.com", None, &[], Some(1), None) - .await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("operation timed out") - ); - } - // Note: Integration tests that actually make network requests // should be in a separate test file and marked with #[ignore] // to avoid network dependencies in unit tests. diff --git a/crates/bashkit/src/network/client/native.rs b/crates/bashkit/src/network/client/native.rs new file mode 100644 index 000000000..f97709a6d --- /dev/null +++ b/crates/bashkit/src/network/client/native.rs @@ -0,0 +1,381 @@ +//! Native HTTP transport using `reqwest`. +//! +//! This submodule is compiled only on non-WASM targets and provides the +//! native `send_request` implementation backed by `reqwest` with a +//! private-IP filtering DNS resolver. + +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; + +use reqwest::Client; +use reqwest::dns::{Name, Resolve, Resolving}; + +use super::{HttpClient, Method, Response}; +use crate::error::{Error, Result}; +use crate::network::allowlist::is_private_ip; + +// --------------------------------------------------------------------------- +// DNS resolver that rejects private IPs at connect time. +// --------------------------------------------------------------------------- + +/// THREAT[TM-NET-002 TOCTOU]: DNS resolver wrapper that rejects any +/// hostname whose addresses include a private/reserved IP at connect time. +struct PrivateIpFilteringResolver; + +impl Resolve for PrivateIpFilteringResolver { + fn resolve(&self, name: Name) -> Resolving { + Box::pin(async move { + let host = name.as_str().to_string(); + let lookup_target = format!("{}:0", host); + let resolved = tokio::net::lookup_host(lookup_target.as_str()) + .await + .map_err(|e| Box::new(e) as Box)?; + + let addrs: Vec = resolved.collect(); + let mut filtered: Vec = Vec::with_capacity(addrs.len()); + for addr in addrs { + if !is_private_ip(&addr.ip()) { + filtered.push(addr); + } + } + + if filtered.is_empty() { + let msg = format!( + "access denied: '{}' resolves only to private/reserved IPs (SSRF protection)", + host + ); + return Err(msg.into()); + } + + let iter: Box + Send> = Box::new(filtered.into_iter()); + Ok(iter) + }) + } +} + +// --------------------------------------------------------------------------- +// Method extension +// --------------------------------------------------------------------------- + +impl Method { + pub(crate) fn as_reqwest(self) -> reqwest::Method { + match self { + Method::Get => reqwest::Method::GET, + Method::Post => reqwest::Method::POST, + Method::Put => reqwest::Method::PUT, + Method::Delete => reqwest::Method::DELETE, + Method::Head => reqwest::Method::HEAD, + Method::Patch => reqwest::Method::PATCH, + } + } +} + +// --------------------------------------------------------------------------- +// HttpClient native transport +// --------------------------------------------------------------------------- + +impl HttpClient { + fn client(&self) -> Result<&Client> { + let block_private = self.allowlist.is_blocking_private_ips(); + let client = self + .client + .get_or_init(|| build_client(self.default_timeout, None, block_private)); + client + .as_ref() + .map_err(|err| Error::Internal(format!("failed to build HTTP client: {err}"))) + } + + #[allow(clippy::too_many_arguments)] + pub(crate) async fn send_request( + &self, + method: Method, + url: &str, + body: Option<&[u8]>, + headers: &[(String, String)], + signing_headers: Vec<(String, String)>, + timeout: Duration, + connect_timeout: Option, + ) -> Result { + // Use the custom timeout client if any timeout is specified, otherwise use default client + let client = if timeout != self.default_timeout || connect_timeout.is_some() { + let connect_timeout = + connect_timeout.unwrap_or_else(|| std::cmp::min(timeout, Duration::from_secs(10))); + build_client( + timeout, + Some(connect_timeout), + self.allowlist.is_blocking_private_ips(), + ) + .map_err(|e| Error::network_sanitized("failed to create client", &e))? + } else { + self.client()?.clone() + }; + + // Build request + let mut request = client.request(method.as_reqwest(), url); + + // Add custom headers + for (name, value) in headers { + request = request.header(name.as_str(), value.as_str()); + } + + // Add bot-auth signing headers + for (name, value) in &signing_headers { + request = request.header(name.as_str(), value.as_str()); + } + + if let Some(body_data) = body { + request = request.body(body_data.to_vec()); + } + + // Send request + let response = request.send().await.map_err(|e| { + if e.is_timeout() { + Error::Network("operation timed out".to_string()) + } else { + Error::network_sanitized("request failed", &e) + } + })?; + + // Extract response data + let status = response.status().as_u16(); + let resp_headers: Vec<(String, String)> = response + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) + .collect(); + + // Fire after_http hooks + self.fire_after_http(crate::hooks::HttpResponseEvent { + url: url.to_string(), + status, + headers: resp_headers.clone(), + }); + + // Check Content-Length header to fail fast on large responses + if let Some(content_length) = response.content_length() + && usize::try_from(content_length).unwrap_or(usize::MAX) > self.max_response_bytes + { + return Err(Error::Network(format!( + "response too large: {} bytes (max: {} bytes)", + content_length, self.max_response_bytes + ))); + } + + // Read body with size limit enforcement + let body = self.read_body_with_limit(response).await?; + + Ok(Response { + status, + headers: resp_headers, + body, + }) + } + + /// Read response body with size limit enforcement. + /// + /// This streams the response to avoid allocating memory for oversized responses. + async fn read_body_with_limit(&self, response: reqwest::Response) -> Result> { + use futures_util::StreamExt; + + let mut body = Vec::new(); + let mut stream = response.bytes_stream(); + + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result + .map_err(|e| Error::network_sanitized("failed to read response chunk", &e))?; + + // Check if adding this chunk would exceed the limit + if body.len() + chunk.len() > self.max_response_bytes { + return Err(Error::Network(format!( + "response too large: exceeded {} bytes limit", + self.max_response_bytes + ))); + } + + body.extend_from_slice(&chunk); + } + + Ok(body) + } +} + +// --------------------------------------------------------------------------- +// Client builder and crypto provider +// --------------------------------------------------------------------------- + +/// Install the rustls `ring` crypto provider as the process-wide default. +fn install_default_crypto_provider() { + use std::sync::Once; + static INIT: Once = Once::new(); + INIT.call_once(|| { + let _ = rustls::crypto::ring::default_provider().install_default(); + }); +} + +fn build_client( + timeout: Duration, + connect_timeout: Option, + block_private_ips: bool, +) -> std::result::Result { + install_default_crypto_provider(); + let mut builder = Client::builder() + .timeout(timeout) + .connect_timeout(connect_timeout.unwrap_or(Duration::from_secs(10))) + .user_agent("bashkit/0.1.2") + // Disable automatic redirects to prevent allowlist bypass via redirect + .redirect(reqwest::redirect::Policy::none()) + // Disable automatic decompression to prevent zip bomb attacks + .no_gzip() + .no_brotli() + .no_deflate() + // THREAT[TM-NET-015]: Ignore host proxy env vars + .no_proxy(); + + // THREAT[TM-NET-002 TOCTOU): install a DNS resolver that filters private IPs + if block_private_ips { + builder = builder.dns_resolver(Arc::new(PrivateIpFilteringResolver)); + } + + builder.build().map_err(|e| e.to_string()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration as StdDuration; + + use super::super::{DEFAULT_MAX_RESPONSE_BYTES, HttpHandler, NetworkAllowlist}; + + struct SlowHandler { + delay: StdDuration, + } + + #[async_trait::async_trait] + impl HttpHandler for SlowHandler { + async fn request( + &self, + _method: &str, + _url: &str, + _body: Option<&[u8]>, + _headers: &[(String, String)], + ) -> std::result::Result { + tokio::time::sleep(self.delay).await; + Ok(Response { + status: 200, + headers: vec![], + body: b"ok".to_vec(), + }) + } + } + + #[test] + fn test_default_client_initializes_on_first_use() { + let client = HttpClient::new(NetworkAllowlist::allow_all()); + assert!(client.client.get().is_none()); + + client.client().expect("client"); + + assert!(client.client.get().is_some()); + } + + #[test] + fn test_build_client_uses_no_proxy() { + let client = build_client(Duration::from_secs(30), None, true); + assert!(client.is_ok(), "build_client should succeed with no_proxy"); + } + + #[test] + fn test_build_client_installs_ring_crypto_provider() { + let _ = build_client(Duration::from_secs(30), None, true); + let second_install = rustls::crypto::ring::default_provider().install_default(); + assert!( + second_install.is_err(), + "build_client must install a default crypto provider" + ); + } + + #[test] + fn test_install_default_crypto_provider_is_idempotent() { + install_default_crypto_provider(); + install_default_crypto_provider(); + install_default_crypto_provider(); + } + + #[tokio::test] + async fn test_private_ip_filtering_resolver_rejects_loopback() { + let resolver = PrivateIpFilteringResolver; + let name: Name = "localhost".parse().expect("valid DNS name"); + let result = resolver.resolve(name).await; + assert!( + result.is_err(), + "localhost must be rejected by the private-IP-filtering resolver" + ); + let err = result.err().unwrap().to_string(); + assert!( + err.contains("private/reserved"), + "error must mention SSRF protection, got: {err}" + ); + } + + #[tokio::test] + async fn test_private_ip_filtering_resolver_filters_private_from_mixed() { + use std::net::{IpAddr, Ipv4Addr}; + let public: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 0); + let private: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 0); + let metadata: SocketAddr = + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)), 0); + let mixed = vec![public, private, metadata]; + let kept: Vec = mixed + .into_iter() + .filter(|a| !is_private_ip(&a.ip())) + .collect(); + assert_eq!(kept, vec![public]); + } + + #[tokio::test] + async fn test_default_client_rejects_loopback_via_resolver() { + let allowlist = NetworkAllowlist::new().allow("http://localhost"); + let client = HttpClient::new(allowlist); + let result = client.get("http://localhost").await; + assert!( + result.is_err(), + "request to a loopback hostname must be refused" + ); + let msg = result.err().unwrap().to_string(); + assert!( + msg.contains("private") + || msg.contains("SSRF") + || msg.contains("reserved") + || msg.contains("access denied"), + "expected SSRF-protection error, got: {msg}" + ); + } + + #[tokio::test] + async fn test_custom_handler_enforces_request_timeout() { + let mut client = HttpClient::with_config( + NetworkAllowlist::allow_all(), + Duration::from_secs(30), + DEFAULT_MAX_RESPONSE_BYTES, + ); + client.set_handler(Box::new(SlowHandler { + delay: StdDuration::from_millis(1200), + })); + + let result = client + .request_with_timeouts(Method::Get, "https://example.com", None, &[], Some(1), None) + .await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("operation timed out") + ); + } +} diff --git a/crates/bashkit/src/network/client/wasm.rs b/crates/bashkit/src/network/client/wasm.rs new file mode 100644 index 000000000..243983fef --- /dev/null +++ b/crates/bashkit/src/network/client/wasm.rs @@ -0,0 +1,190 @@ +//! WASM-specific HTTP transport using the browser `fetch` API. +//! +//! This module is compiled only on `target_family = "wasm"` and provides the +//! `send_request` implementation backed by `web_sys::fetch` and +//! `wasm_bindgen_futures::JsFuture`. +//! +//! # Limitations vs Native +//! +//! - No custom DNS resolver (browser handles resolution). Same-origin policy +//! and CORS provide additional SSRF defense. +//! - No separate connect timeout (`fetch` does not expose this). +//! - No proxy controls (browser handles proxies). +//! - Response bodies are read via `array_buffer()` rather than streaming. +//! Size limits are enforced after the full body is received. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; + +use wasm_bindgen::JsCast; +use wasm_bindgen_futures::JsFuture; +use web_sys::{AbortController, Request, RequestInit, RequestMode, Response}; + +use super::{Method, Response as HttpResponse}; +use crate::error::{Error, Result}; + +/// Wrapper that asserts a future is `Send`. +/// +/// # Safety +/// +/// On `wasm32-unknown-unknown` there is only one thread, so all types are +/// effectively `Send`. This wrapper is only used within the WASM HTTP client +/// which is compiled exclusively for that target. +struct AssertSend(F); + +unsafe impl Send for AssertSend {} +unsafe impl Sync for AssertSend {} + +impl Future for AssertSend { + type Output = F::Output; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: We are projecting from Pin<&mut AssertSend> to Pin<&mut F>. + // AssertSend is a newtype wrapper with the same memory layout. + unsafe { self.map_unchecked_mut(|s| &mut s.0).poll(cx) } + } +} + +/// Format a `JsValue` error for network error reporting. +/// +/// Extracts the human-readable message without dumping the full +/// `JsValue` debug representation (which includes a wasm stack trace). +fn js_err_str(e: &wasm_bindgen::JsValue) -> String { + e.as_string() + .or_else(|| js_sys::Reflect::get(e, &"message".into()).ok()?.as_string()) + .unwrap_or_else(|| "unknown error".to_string()) +} + +/// Execute an HTTP request via the browser `fetch` API. +pub(crate) fn send_request( + max_response_bytes: usize, + method: Method, + url: &str, + body: Option<&[u8]>, + headers: &[(String, String)], + signing_headers: Vec<(String, String)>, + timeout: Duration, +) -> impl Future> + Send + Sync { + let url = url.to_string(); + let headers = headers.to_vec(); + AssertSend(async move { + let abort_controller = AbortController::new().map_err(|e| { + Error::Internal(format!( + "failed to create abort controller: {}", + js_err_str(&e) + )) + })?; + + let opts = RequestInit::new(); + opts.set_method(method.as_str()); + opts.set_mode(RequestMode::Cors); + opts.set_signal(Some(&abort_controller.signal())); + + if let Some(body_data) = body { + let array = js_sys::Uint8Array::from(body_data); + opts.set_body(&array); + } + + let request = Request::new_with_str_and_init(&url, &opts) + .map_err(|e| Error::Network(format!("failed to build request: {}", js_err_str(&e))))?; + + let req_headers = request.headers(); + for (name, value) in &headers { + req_headers + .set(name, value) + .map_err(|e| Error::Network(format!("failed to set header: {}", js_err_str(&e))))?; + } + for (name, value) in &signing_headers { + req_headers.set(name, value).map_err(|e| { + Error::Network(format!("failed to set signing header: {}", js_err_str(&e))) + })?; + } + + // Set up timeout via abort controller + setTimeout + let window = web_sys::window() + .ok_or_else(|| Error::Internal("no window object available".to_string()))?; + let timeout_ms = timeout.as_millis() as i32; + let abort_for_timeout = abort_controller.clone(); + let timeout_closure = wasm_bindgen::closure::Closure::once_into_js(move || { + abort_for_timeout.abort(); + }); + let timeout_id = window + .set_timeout_with_callback_and_timeout_and_arguments_0( + timeout_closure.as_ref().unchecked_ref(), + timeout_ms, + ) + .map_err(|e| Error::Internal(format!("failed to set timeout: {}", js_err_str(&e))))?; + + let fetch_promise = window.fetch_with_request(&request); + let resp_value = match JsFuture::from(fetch_promise).await { + Ok(v) => v, + Err(e) => { + window.clear_timeout_with_handle(timeout_id); + let msg = js_err_str(&e); + if msg.contains("AbortError") || msg.contains("abort") { + return Err(Error::Network("operation timed out".to_string())); + } + return Err(Error::network_sanitized("request failed", &msg)); + } + }; + + window.clear_timeout_with_handle(timeout_id); + + let response: Response = resp_value + .dyn_into() + .map_err(|e| Error::Internal(format!("invalid response type: {}", js_err_str(&e))))?; + + let status = response.status(); + let resp_headers = response.headers(); + let mut header_pairs = Vec::new(); + if let Ok(Some(iter)) = js_sys::try_iter(&resp_headers) { + for entry in iter { + let entry = entry.map_err(|e| { + Error::Internal(format!("header entry error: {}", js_err_str(&e))) + })?; + if let Ok(array) = entry.dyn_into::() { + if array.length() >= 2 { + let name = array.get(0).as_string().unwrap_or_default(); + let value = array.get(1).as_string().unwrap_or_default(); + header_pairs.push((name, value)); + } + } + } + } + + // Read body + let body = match response.array_buffer() { + Ok(promise) => { + let body_value = JsFuture::from(promise).await.map_err(|e| { + let msg = js_err_str(&e); + Error::network_sanitized("failed to read response body", &msg) + })?; + let array_buffer: js_sys::ArrayBuffer = body_value.dyn_into().map_err(|e| { + Error::Internal(format!("invalid body type: {}", js_err_str(&e))) + })?; + js_sys::Uint8Array::new(&array_buffer).to_vec() + } + Err(e) => { + return Err(Error::Network(format!( + "failed to read response body: {}", + js_err_str(&e) + ))); + } + }; + + if body.len() > max_response_bytes { + return Err(Error::Network(format!( + "response too large: {} bytes (max: {} bytes)", + body.len(), + max_response_bytes + ))); + } + + Ok(HttpResponse { + status, + headers: header_pairs, + body, + }) + }) +} diff --git a/crates/bashkit/src/parser/mod.rs b/crates/bashkit/src/parser/mod.rs index 6207edadd..535526a1f 100644 --- a/crates/bashkit/src/parser/mod.rs +++ b/crates/bashkit/src/parser/mod.rs @@ -26,7 +26,7 @@ pub use span::{Position, Span}; use crate::error::{Error, Result}; use crate::limits::LimitExceeded; -use std::time::{Duration, Instant}; +use crate::time::{Duration, Instant}; /// Default maximum AST depth (matches ExecutionLimits default) const DEFAULT_MAX_AST_DEPTH: usize = 100; diff --git a/crates/bashkit/src/scripted_tool/execute.rs b/crates/bashkit/src/scripted_tool/execute.rs index 6477cbb5e..6248d5408 100644 --- a/crates/bashkit/src/scripted_tool/execute.rs +++ b/crates/bashkit/src/scripted_tool/execute.rs @@ -226,7 +226,7 @@ impl Tool for ScriptedTool { let req = tool_request_from_value(self.locale(), args)?; let tool = self.clone(); Ok(ToolExecution::new(move |stream_sender| async move { - let start = std::time::Instant::now(); + let start = crate::time::Instant::now(); let response = tool.run_request_with_stream(req, stream_sender).await; tool_output_from_response(response, start.elapsed()) })) @@ -400,7 +400,7 @@ mod tests { }, ) .build(); - let start = std::time::Instant::now(); + let start = crate::time::Instant::now(); let resp = tool .execute(ToolRequest { @@ -429,7 +429,7 @@ mod tests { }, ) .build(); - let start = std::time::Instant::now(); + let start = crate::time::Instant::now(); let resp = tool .execute_with_status( diff --git a/crates/bashkit/src/time.rs b/crates/bashkit/src/time.rs new file mode 100644 index 000000000..d774593e8 --- /dev/null +++ b/crates/bashkit/src/time.rs @@ -0,0 +1,32 @@ +//! Platform-compatible time types. +//! +//! On native targets this re-exports `std::time` directly. +//! On `wasm32-unknown-unknown` it uses `web_time` so that +//! `SystemTime::now()` and `Instant::now()` work in the browser instead +//! of panicking. + +#[cfg(target_family = "wasm")] +pub use web_time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +#[cfg(not(target_family = "wasm"))] +pub use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +/// Convert a [`chrono::DateTime`] to our platform-compatible [`SystemTime`]. +/// +/// `chrono`'s `From` impls only cover `std::time::SystemTime`, so this helper +/// bridges the gap on WASM where we use `web_time::SystemTime`. +pub fn from_chrono(dt: chrono::DateTime) -> SystemTime { + let secs = dt.timestamp(); + let nanos = dt.timestamp_subsec_nanos(); + UNIX_EPOCH + Duration::from_secs(secs as u64) + Duration::from_nanos(nanos as u64) +} + +/// Convert our platform-compatible [`SystemTime`] to a [`chrono::DateTime`]. +/// +/// `chrono`'s `From` impls only cover `std::time::SystemTime`, so this helper +/// bridges the gap on WASM where we use `web_time::SystemTime`. +pub fn to_chrono_utc(st: SystemTime) -> chrono::DateTime { + let duration = st.duration_since(UNIX_EPOCH).unwrap_or_default(); + chrono::DateTime::from_timestamp(duration.as_secs() as i64, duration.subsec_nanos()) + .unwrap_or(chrono::DateTime::UNIX_EPOCH) +} diff --git a/crates/bashkit/src/tool.rs b/crates/bashkit/src/tool.rs index 827e856eb..c1e01a18f 100644 --- a/crates/bashkit/src/tool.rs +++ b/crates/bashkit/src/tool.rs @@ -985,7 +985,7 @@ impl Tool for BashTool { let tool = self.clone(); Ok(ToolExecution::new(move |stream_sender| async move { - let start = std::time::Instant::now(); + let start = crate::time::Instant::now(); let response = tool.run_request_with_stream(req, stream_sender).await; tool_output_from_response(response, start.elapsed()) }))