diff --git a/Cargo.lock b/Cargo.lock index 9b3f3068..fa8da5a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -902,9 +902,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "bzip2" @@ -2813,6 +2813,7 @@ dependencies = [ "arc-swap", "bincode", "borsh 1.5.7", + "bytes", "clap", "crossbeam-channel", "dashmap", @@ -2823,6 +2824,7 @@ dependencies = [ "lazy_static", "libc", "log", + "mio", "prometheus", "prost 0.13.5", "prost-types 0.13.5", @@ -2936,9 +2938,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.176" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libloading" @@ -3202,13 +3204,14 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", + "log", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 13eeab80..952fb410 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,10 @@ members = ["examples", "jito_protos", "proxy"] resolver = "2" +[profile.debug-release] +inherits = "release" +debug = true + [workspace.package] version = "0.2.12-triton" description = "Fast path to receive shreds from Jito, forwarding to local consumers. See https://docs.jito.wtf/lowlatencytxnfeed/ for details." @@ -14,6 +18,7 @@ ahash = "0.8" arc-swap = "1.6" bincode = "1.3.3" borsh = "1.5.3" +bytes = "1.11.0" clap = { version = "4", features = ["derive", "env"] } crossbeam-channel = "0.5.8" dashmap = "5" @@ -24,6 +29,7 @@ jito-protos = { path = "jito_protos" } lazy_static = "1.4.0" libc = "0.2" log = "0.4" +mio = "1.1.1" prost = "0.13" prost-types = "0.13" prometheus = "0.14.0" diff --git a/data-sample.txt b/data-sample.txt new file mode 100644 index 00000000..568c811a --- /dev/null +++ b/data-sample.txt @@ -0,0 +1,75 @@ +set 1: + +shredstream_recv_interval_usec_bucket{le="1"} 41 + +shredstream_recv_interval_usec_bucket{le="5"} 63 + +shredstream_recv_interval_usec_bucket{le="10"} 125 + +shredstream_recv_interval_usec_bucket{le="25"} 47676 + +shredstream_recv_interval_usec_bucket{le="50"} 104430 + +shredstream_recv_interval_usec_bucket{le="100"} 162673 + +shredstream_recv_interval_usec_bucket{le="200"} 190777 + +shredstream_recv_interval_usec_bucket{le="500"} 205050 + +shredstream_recv_interval_usec_bucket{le="1000"} 210046 + +shredstream_recv_interval_usec_bucket{le="2000"} 212204 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 214080 + + + +set 2: + +shredstream_recv_interval_usec_bucket{le="1"} 0 + +shredstream_recv_interval_usec_bucket{le="5"} 0 + +shredstream_recv_interval_usec_bucket{le="10"} 22 + +shredstream_recv_interval_usec_bucket{le="25"} 864700 + +shredstream_recv_interval_usec_bucket{le="50"} 1059516 + +shredstream_recv_interval_usec_bucket{le="100"} 1334130 + +shredstream_recv_interval_usec_bucket{le="200"} 1473381 + +shredstream_recv_interval_usec_bucket{le="500"} 1545124 + +shredstream_recv_interval_usec_bucket{le="1000"} 1569639 + +shredstream_recv_interval_usec_bucket{le="2000"} 1580383 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 1589948 + + + +set 3 : + +shredstream_recv_interval_usec_bucket{le="1"} 0 + +shredstream_recv_interval_usec_bucket{le="5"} 0 + +shredstream_recv_interval_usec_bucket{le="10"} 2 + +shredstream_recv_interval_usec_bucket{le="25"} 129306 + +shredstream_recv_interval_usec_bucket{le="50"} 159982 + +shredstream_recv_interval_usec_bucket{le="100"} 202752 + +shredstream_recv_interval_usec_bucket{le="200"} 225469 + +shredstream_recv_interval_usec_bucket{le="500"} 238215 + +shredstream_recv_interval_usec_bucket{le="1000"} 242741 + +shredstream_recv_interval_usec_bucket{le="2000"} 244727 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 246249 \ No newline at end of file diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 5f8defe0..e7a9b353 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -6,11 +6,22 @@ authors = { workspace = true } homepage = { workspace = true } edition = { workspace = true } +[[bin]] +name = "triton-shredproxy" +path = "src/main2.rs" + +[[bin]] +name = "jito-shredstream-proxy" +path = "src/main.rs" + + + [dependencies] ahash = { workspace = true } arc-swap = { workspace = true } bincode = { workspace = true } borsh = { workspace = true } +bytes = { workspace = true } clap = { workspace = true } crossbeam-channel = { workspace = true } dashmap = { workspace = true } @@ -21,6 +32,7 @@ jito-protos = { workspace = true } lazy_static = { workspace = true } log = { workspace = true } libc = { workspace = true } +mio = { workspace = true } prometheus = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/proxy/src/forwarder.rs b/proxy/src/forwarder.rs index 92f9a508..6bc9ef1d 100644 --- a/proxy/src/forwarder.rs +++ b/proxy/src/forwarder.rs @@ -14,8 +14,8 @@ use crossbeam_channel::{Receiver, RecvError}; use dashmap::DashMap; use itertools::Itertools; use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; -use log::{debug, error, info, warn}; use libc; +use log::{debug, error, info, warn}; use prost::Message; use socket2::{Domain, Protocol, Socket, Type}; use solana_client::client_error::reqwest; @@ -35,15 +35,13 @@ use solana_streamer::{ use tokio::sync::broadcast::Sender; use crate::{ - ShredstreamProxyError, deshred::{self, ComparableShred, ShredsStateTracker}, prom::{ - observe_dedup_time, observe_send_packet_count, observe_send_duration, - observe_recv_interval, observe_recv_packet_count, - inc_packets_received, inc_packets_deduped, inc_packets_forwarded, - inc_packets_forward_failed, inc_packets_by_source, + inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, + inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, + observe_recv_packet_count, observe_send_duration, observe_send_packet_count, }, - resolve_hostname_port, + resolve_hostname_port, ShredstreamProxyError, }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -169,7 +167,6 @@ pub fn start_forwarder_threads( let reconstruct_tx = reconstruct_tx.clone(); let exit = exit.clone(); - let send_thread = Builder::new() .name(format!("ssPxyTx_{thread_id}")) .spawn(move || { @@ -259,8 +256,7 @@ pub fn start_forwarder_threads( /// /// Try to create an IPv6 UDP socket bound to the given address. -/// -fn try_create_ipv6_socket(addr: SocketAddr) -> Result { +pub fn try_create_ipv6_socket(addr: SocketAddr) -> Result { let ipv6_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; ipv6_socket.set_multicast_hops_v6(IP_MULTICAST_TTL)?; ipv6_socket.bind(&addr.into())?; @@ -280,16 +276,14 @@ fn recv_from_channel_and_send_multiple_dest( reconstruct_tx: &crossbeam_channel::Sender, debug_trace_shred: bool, metrics: &ShredMetrics, -) -> Result<(), ShredstreamProxyError> +) -> Result<(), ShredstreamProxyError> where F: Fn(IpAddr, SocketAddr) -> bool, { let packet_batch = maybe_packet_batch.map_err(ShredstreamProxyError::RecvError)?; let trace_shred_received_time = SystemTime::now(); let batch_len = packet_batch.len() as u64; - metrics - .received - .fetch_add(batch_len, Ordering::Relaxed); + metrics.received.fetch_add(batch_len, Ordering::Relaxed); inc_packets_received(batch_len); observe_recv_packet_count(batch_len as f64); debug!( @@ -310,7 +304,9 @@ where &mut packet_batch_vec, ); let t_dedup_usecs = t.elapsed().as_micros() as u64; - metrics.dedup_time_spent.fetch_add(t_dedup_usecs, Ordering::Relaxed); + metrics + .dedup_time_spent + .fetch_add(t_dedup_usecs, Ordering::Relaxed); observe_dedup_time(t_dedup_usecs as f64); inc_packets_deduped(num_deduped); @@ -326,10 +322,12 @@ where *discarded += is_discarded as u64; *not_discarded += (!is_discarded) as u64; }) - .or_insert_with(|| { - (is_discarded as u64, (!is_discarded) as u64) - }); - let status = if is_discarded { "discarded" } else { "forwarded" }; + .or_insert_with(|| (is_discarded as u64, (!is_discarded) as u64)); + let status = if is_discarded { + "discarded" + } else { + "forwarded" + }; inc_packets_by_source(&addr.to_string(), status, 1); }); }); @@ -358,10 +356,14 @@ where .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); metrics.send_batch_count.fetch_add(1, Ordering::Relaxed); const MAX_IOV: usize = libc::UIO_MAXIOV as usize; - let max_iov_count = packets_with_dest.len() / MAX_IOV; + let max_iov_count = packets_with_dest.len() / MAX_IOV; let unsaturated_iov_count = packets_with_dest.len() % MAX_IOV; - metrics.saturated_iov_count.fetch_add(max_iov_count as u64, Ordering::Relaxed); - metrics.unsaturated_iov_count.fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); + metrics + .saturated_iov_count + .fetch_add(max_iov_count as u64, Ordering::Relaxed); + metrics + .unsaturated_iov_count + .fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); observe_send_packet_count(packets_with_dest.len() as f64); match batch_send(send_socket, &packets_with_dest) { Ok(_) => { @@ -387,7 +389,9 @@ where } } let t_send_usecs = t.elapsed().as_micros() as u64; - metrics.batch_send_time_spent.fetch_add(t_send_usecs, Ordering::Relaxed); + metrics + .batch_send_time_spent + .fetch_add(t_send_usecs, Ordering::Relaxed); observe_send_duration(t_send_usecs as f64); }); @@ -645,7 +649,11 @@ impl ShredMetrics { datapoint_info!( "shredstream_proxy-sendmmsg_iov_metrics", - ("max_iov_count", self.saturated_iov_count.load(Ordering::Relaxed), i64), + ( + "max_iov_count", + self.saturated_iov_count.load(Ordering::Relaxed), + i64 + ), ( "unsaturated_iov_count", self.unsaturated_iov_count.load(Ordering::Relaxed), @@ -654,12 +662,16 @@ impl ShredMetrics { ); datapoint_info!( - "shredstream_proxy-batch_send_metrics", + "shredstream_proxy-batch_send_metrics", ( - "send_batch_size_sum", self.send_batch_size_sum.load(Ordering::Relaxed), i64 + "send_batch_size_sum", + self.send_batch_size_sum.load(Ordering::Relaxed), + i64 ), ( - "send_batch_count", self.send_batch_count.load(Ordering::Relaxed), i64 + "send_batch_count", + self.send_batch_count.load(Ordering::Relaxed), + i64 ) ); @@ -677,7 +689,6 @@ impl ShredMetrics { ), ); - if self.enabled_grpc_service { datapoint_info!( "shredstream_proxy-service_metrics", diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 520bba51..9ba7365d 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -31,6 +31,9 @@ mod multicast_config; mod server; mod token_authenticator; mod prom; +mod recv_mmsg; +mod mem; +mod triton_forwarder; #[cfg(not(target_env = "msvc"))] use tikv_jemallocator::Jemalloc; diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs new file mode 100644 index 00000000..62d53db1 --- /dev/null +++ b/proxy/src/main2.rs @@ -0,0 +1,494 @@ +use std::{ + collections::HashMap, io::{self, Error, ErrorKind}, net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs}, num::NonZeroUsize, panic, path::{Path, PathBuf}, str::FromStr, sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, thread::{self, sleep, spawn, JoinHandle}, time::Duration +}; + +use arc_swap::ArcSwap; +use clap::{arg, Parser}; +use crossbeam_channel::{Receiver, RecvError, Sender}; +use log::*; +use signal_hook::consts::{SIGINT, SIGTERM}; +use solana_client::client_error::{reqwest, ClientError}; +use solana_ledger::shred::Shred; +use solana_metrics::set_host_id; +use solana_sdk::{clock::Slot, signature::read_keypair_file}; +use solana_streamer::streamer::StreamerReceiveStats; +use thiserror::Error; +use tokio::runtime::Runtime; +use tonic::Status; + +use crate::{ + forwarder::ShredMetrics, multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_socket_on_device}, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig +}; +pub mod deshred; +pub mod forwarder; +pub mod heartbeat; +pub mod multicast_config; +pub mod server; +pub mod token_authenticator; +pub mod prom; +pub mod recv_mmsg; +pub mod mem; +pub mod triton_forwarder; + +use triton_forwarder::{PktRecvMemSizing}; + +#[derive(Clone, Debug, Parser)] +#[clap(author, version, about, long_about = None)] +// https://docs.rs/clap/latest/clap/_derive/_cookbook/git_derive/index.html +struct Args { + #[command(subcommand)] + shredstream_args: ProxySubcommands, +} + +#[derive(Clone, Debug, clap::Subcommand)] +enum ProxySubcommands { + /// Requests shreds from Jito and sends to all destinations. + Shredstream(ShredstreamArgs), + + /// Does not request shreds from Jito. Sends anything received on `src-bind-addr`:`src-bind-port` to all destinations. + ForwardOnly(CommonArgs), +} + +#[derive(clap::Args, Clone, Debug)] +struct ShredstreamArgs { + /// Address for Jito Block Engine. + /// See https://jito-labs.gitbook.io/mev/searcher-resources/block-engine#connection-details + #[arg(long, env)] + block_engine_url: String, + + /// Manual override for auth service address. For internal use. + #[arg(long, env)] + auth_url: Option, + + /// Path to keypair file used to authenticate with the backend. + #[arg(long, env)] + auth_keypair: PathBuf, + + /// Desired regions to receive heartbeats from. + /// Receives `n` different streams. Requires at least 1 region, comma separated. + #[arg(long, env, value_delimiter = ',', required(true))] + desired_regions: Vec, + + #[clap(flatten)] + common_args: CommonArgs, +} + +#[derive(clap::Args, Clone, Debug)] +struct CommonArgs { + /// Address where Shredstream proxy listens. + #[arg(long, env, default_value_t = IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)))] + src_bind_addr: IpAddr, + + /// Port where Shredstream proxy listens. Use `0` for random ephemeral port. + #[arg(long, env, default_value_t = 20_000)] + src_bind_port: u16, + + /// Multicast IP to listen for shreds. If none provided, attempts to + /// parse multicast routes for the device specified by `--multicast-device` + /// via `ip --json route show dev `. + #[arg(long, env)] + multicast_bind_ip: Option, + + /// Network device to use for multicast route discovery and interface selection. + /// Example: `eth0`, `en0`, or `doublezero1`. + #[arg(long, env, default_value = "doublezero1")] + multicast_device: String, + + /// Port to receive multicast shreds + #[arg(long, env, default_value_t = 20001)] + multicast_subscribe_port: u16, + + /// Static set of IP:Port where Shredstream proxy forwards shreds to, comma separated. + /// Eg. `127.0.0.1:8001,10.0.0.1:8001`. + // Note: store the original string, so we can do hostname resolution when refreshing destinations + #[arg(long, env, value_delimiter = ',', value_parser = resolve_hostname_port)] + dest_ip_ports: Vec<(SocketAddr, String)>, + + /// Http JSON endpoint to dynamically get IPs for Shredstream proxy to forward shreds. + /// Endpoints are then set-union with `dest-ip-ports`. + #[arg(long, env)] + endpoint_discovery_url: Option, + + /// Port to send shreds to for hosts fetched via `endpoint-discovery-url`. + /// Port can be found using `scripts/get_tvu_port.sh`. + /// See https://jito-labs.gitbook.io/mev/searcher-services/shredstream#running-shredstream + #[arg(long, env)] + discovered_endpoints_port: Option, + + /// Interval between logging stats to stdout and influx + #[arg(long, env, default_value_t = 15_000)] + metrics_report_interval_ms: u64, + + /// Logs trace shreds to stdout and influx + #[arg(long, env, default_value_t = false)] + debug_trace_shred: bool, + + /// Public IP address to use. + /// Overrides value fetched from `ifconfig.me`. + #[arg(long, env)] + public_ip: Option, + + /// Number of threads to use. Defaults to use up to 4. + #[arg(long, env)] + num_threads: Option, + + /// + /// The multicast group (ip addr) to join for receiving shreds. + /// Multicast groups supports IPv4 and IPv6. + #[arg(long, env)] + triton_multicast_group: Option, + /// The interface to bind to for triton multicast. + /// If IPV6 is used, this argument must be provided. + /// If ipv4, then optional (listen on all interfaces if not provided). + #[arg(long, env)] + triton_multicast_bind_interface: Option, + + /// + /// The port to listen on for triton multicast. + /// If not provided, defaults to 8002. + /// NOTE: this port must match the port used by the triton multicast sender. + #[arg(long, env)] + triton_multicast_subscription_port: Option, + + /// Address to bind prometheus metrics server to. If not provided, prometheus server is disabled. + #[arg(long, env)] + prometheus_bind_addr: Option, + + /// Number of tiles dedicated to receiving packets. If not provided, defaults to number of CPU cores is 1. + #[arg(long, env)] + num_pkt_recv_tile: Option, + + /// Number of tiles dedicated to forwarding packets. If not provided, defaults to number of CPU cores is 1. + #[arg(long, env)] + num_pkt_fwd_tile: Option, + + /// Memory sizing for EACH packet receiver, uses t-shirt size convention (xs (default),s,m,l,xl,2xl,3xl,4xl,5xl). Each size increase double the memory, starting at 128MiB for x-small. + #[arg(long, env)] + pkt_recv_channel_memsize: Option, + + /// Use hugepage memory for pkt recv tiles shared memory. + #[arg(long, env, default_value_t = false)] + hugepage: bool, +} + +#[derive(Debug, Error)] +pub enum ShredstreamProxyError { + #[error("TonicError {0}")] + TonicError(#[from] tonic::transport::Error), + #[error("GrpcError {0}")] + GrpcError(#[from] Status), + #[error("ReqwestError {0}")] + ReqwestError(#[from] reqwest::Error), + #[error("SerdeJsonError {0}")] + SerdeJsonError(#[from] serde_json::Error), + #[error("RpcError {0}")] + RpcError(#[from] ClientError), + #[error("BlockEngineConnectionError {0}")] + BlockEngineConnectionError(#[from] BlockEngineConnectionError), + #[error("RecvError {0}")] + RecvError(#[from] RecvError), + #[error("IoError {0}")] + IoError(#[from] io::Error), + #[error("Shutdown")] + Shutdown, +} + +fn resolve_hostname_port(hostname_port: &str) -> io::Result<(SocketAddr, String)> { + let socketaddr = hostname_port.to_socket_addrs()?.next().ok_or_else(|| { + Error::new( + ErrorKind::AddrNotAvailable, + format!("Could not find destination {hostname_port}"), + ) + })?; + + Ok((socketaddr, hostname_port.to_string())) +} + +/// Returns public-facing IPV4 address +pub fn get_public_ip() -> reqwest::Result { + info!("Requesting public ip from ifconfig.me..."); + let client = reqwest::blocking::Client::builder() + .local_address(IpAddr::V4(Ipv4Addr::UNSPECIFIED)) + .build()?; + let response = client.get("https://ifconfig.me/ip").send()?.text()?; + let public_ip = IpAddr::from_str(&response).unwrap(); + info!("Retrieved public ip: {public_ip:?}"); + + Ok(public_ip) +} + +// Creates a channel that gets a message every time `SIGINT` is signalled. +fn shutdown_notifier(exit: Arc) -> io::Result<(Sender<()>, Receiver<()>)> { + let (s, r) = crossbeam_channel::bounded(256); + let mut signals = signal_hook::iterator::Signals::new([SIGINT, SIGTERM])?; + + let s_thread = s.clone(); + thread::spawn(move || { + for _ in signals.forever() { + exit.store(true, Ordering::SeqCst); + // send shutdown signal multiple times since crossbeam doesn't have broadcast channels + // each thread will consume a shutdown signal + for _ in 0..256 { + if s_thread.send(()).is_err() { + break; + } + } + } + }); + + Ok((s, r)) +} + +pub type ReconstructedShredsMap = HashMap>>; +fn main() -> Result<(), ShredstreamProxyError> { + env_logger::builder().init(); + let prom_registry = prometheus::Registry::new(); + prom::register_metrics(&prom_registry); + let all_args: Args = Args::parse(); + let shredstream_args = all_args.shredstream_args.clone(); + // common args + let args = match all_args.shredstream_args { + ProxySubcommands::Shredstream(x) => x.common_args, + ProxySubcommands::ForwardOnly(x) => x, + }; + + + let num_pkt_recv_tiles = args.num_pkt_recv_tile + .map(|x| x.get()) + .unwrap_or(args.num_threads.unwrap_or(1)); + + let num_pkt_fwd_tiles = args.num_pkt_fwd_tile + .map(|x| x.get()) + .unwrap_or(args.num_threads.unwrap_or(1)); + + set_host_id(hostname::get()?.into_string().unwrap()); + if (args.endpoint_discovery_url.is_none() && args.discovered_endpoints_port.is_some()) + || (args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_none()) + { + return Err(ShredstreamProxyError::IoError(io::Error::new(ErrorKind::InvalidInput, "Invalid arguments provided, dynamic endpoints requires both --endpoint-discovery-url and --discovered-endpoints-port."))); + } + if args.endpoint_discovery_url.is_none() + && args.discovered_endpoints_port.is_none() + && args.dest_ip_ports.is_empty() + { + return Err(ShredstreamProxyError::IoError(io::Error::new(ErrorKind::InvalidInput, "No destinations found. You must provide values for --dest-ip-ports or --endpoint-discovery-url."))); + } + + let exit = Arc::new(AtomicBool::new(false)); + let (shutdown_sender, shutdown_receiver) = + shutdown_notifier(exit.clone()).expect("Failed to set up signal handler"); + let panic_hook = panic::take_hook(); + { + let exit = exit.clone(); + panic::set_hook(Box::new(move |panic_info| { + exit.store(true, Ordering::SeqCst); + let _ = shutdown_sender.send(()); + error!("exiting process"); + sleep(Duration::from_secs(1)); + // invoke the default handler and exit the process + panic_hook(panic_info); + })); + } + + let metrics = Arc::new(ShredMetrics::new(false)); + + + let mut thread_handles = vec![]; + if let ProxySubcommands::Shredstream(args) = shredstream_args { + let runtime = Runtime::new()?; + if args.desired_regions.len() > 2 { + warn!( + "Too many regions requested, only regions: {:?} will be used", + &args.desired_regions[..2] + ); + } + let heartbeat_hdl = + start_heartbeat(args, &exit, &shutdown_receiver, runtime, metrics.clone()); + thread_handles.push(heartbeat_hdl); + } + + // share sockets between refresh and forwarder thread + let unioned_dest_sockets = Arc::new(ArcSwap::from_pointee( + args.dest_ip_ports + .iter() + .map(|x| x.0) + .collect::>(), + )); + + let forward_stats = Arc::new(StreamerReceiveStats::new("shredstream_proxy-listen_thread")); + let use_discovery_service = + args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_some(); + + let maybe_dz_multicast_socket_vec = create_multicast_socket_on_device( + &args.multicast_device, + args.multicast_subscribe_port, + args.multicast_bind_ip, + ) + .inspect(|mcast_socket| info!("Multicast listeners found: {mcast_socket:?}.")); + + let maybe_triton_multicast_config = match args.triton_multicast_group { + Some(multicast_group) => { + let device_ifname = args.triton_multicast_bind_interface.clone().ok_or_else(|| { + io::Error::new( + ErrorKind::InvalidInput, + "'triton-multicast-bind-interface' is required if 'triton-multicast-group' is set", + ) + })?; + let subscription_port = args.triton_multicast_subscription_port.ok_or_else(|| { + io::Error::new( + ErrorKind::InvalidInput, + "'triton-multicast-subscription-port' is required if 'triton-multicast-group' is set", + ) + })?; + match multicast_group { + IpAddr::V4(ipv4) => { + Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { + multicast_ip: ipv4, + device_ifname: device_ifname, + subscription_port, + })) + } + IpAddr::V6(ipv6) => { + Some(TritonMulticastConfig::Ipv6(TritonMulticastConfigV6 { + multicast_ip: ipv6, + device_ifname: device_ifname, + subscription_port, + })) + } + } + } + None => None, + }; + + let pkt_recv_tile_mem_config = PktRecvTileMemConfig { + memory_size: args.pkt_recv_channel_memsize.unwrap_or_default(), + hugepage: args.hugepage, + ..Default::default() + }; + let proxy_th = { + let exit = Arc::clone(&exit); + let pkt_recv_stats = forward_stats.clone(); + let pkt_fwd_stats = metrics.clone(); + let unioned_dest_sockets = Arc::clone(&unioned_dest_sockets); + std::thread::Builder::new() + .name("tritonProxyMain".to_string()) + .spawn(move || { + triton_forwarder::run_proxy_system( + pkt_recv_tile_mem_config, + unioned_dest_sockets, + maybe_triton_multicast_config, + args.src_bind_addr, + args.src_bind_port, + num_pkt_recv_tiles, + num_pkt_fwd_tiles, + FECSetRoutingStrategy, + exit, + pkt_recv_stats, + pkt_fwd_stats, + maybe_dz_multicast_socket_vec.unwrap_or_default(), + ); + }) + .expect("tritonProxyMain") + }; + + thread_handles.push(proxy_th); + + let report_metrics_thread = { + let exit = exit.clone(); + spawn(move || { + while !exit.load(Ordering::Relaxed) { + sleep(Duration::from_secs(1)); + forward_stats.report(); + } + }) + }; + thread_handles.push(report_metrics_thread); + + let metrics_hdl = triton_forwarder::start_forwarder_accessory_thread( + metrics.clone(), + args.metrics_report_interval_ms, + shutdown_receiver.clone(), + exit.clone(), + ); + thread_handles.push(metrics_hdl); + if use_discovery_service { + let refresh_handle = forwarder::start_destination_refresh_thread( + args.endpoint_discovery_url.unwrap(), + args.discovered_endpoints_port.unwrap(), + args.dest_ip_ports, + unioned_dest_sockets, + shutdown_receiver.clone(), + exit.clone(), + ); + thread_handles.push(refresh_handle); + } + + if let Some(prom_bind_addr) = args.prometheus_bind_addr { + let prom_hdl = prom::spawn_prometheus_server( + prom_bind_addr, + prom_registry, + shutdown_receiver.clone() + ); + thread_handles.push(prom_hdl); + } + + info!( + "Shredstream started, listening on {}:{}/udp.", + args.src_bind_addr, args.src_bind_port + ); + + + + for thread in thread_handles { + thread.join().expect("thread panicked"); + } + + info!( + "Exiting Shredstream, {} received , {} sent successfully, {} failed, {} duplicate shreds.", + metrics.agg_received_cumulative.load(Ordering::Relaxed), + metrics + .agg_success_forward_cumulative + .load(Ordering::Relaxed), + metrics.agg_fail_forward_cumulative.load(Ordering::Relaxed), + metrics.duplicate_cumulative.load(Ordering::Relaxed), + ); + Ok(()) +} + +fn start_heartbeat( + args: ShredstreamArgs, + exit: &Arc, + shutdown_receiver: &Receiver<()>, + runtime: Runtime, + metrics: Arc, +) -> JoinHandle<()> { + let auth_keypair = Arc::new( + read_keypair_file(Path::new(&args.auth_keypair)).unwrap_or_else(|e| { + panic!( + "Unable to parse keypair file. Ensure that file {:?} is readable. Error: {e}", + args.auth_keypair + ) + }), + ); + + heartbeat::heartbeat_loop_thread( + args.block_engine_url.clone(), + args.auth_url.unwrap_or(args.block_engine_url), + auth_keypair, + args.desired_regions, + SocketAddr::new( + args.common_args + .public_ip + .unwrap_or_else(|| get_public_ip().unwrap()), + args.common_args.src_bind_port, + ), + runtime, + "shredstream_proxy".to_string(), + metrics, + shutdown_receiver.clone(), + exit.clone(), + ) +} diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs new file mode 100644 index 00000000..65d060c2 --- /dev/null +++ b/proxy/src/mem.rs @@ -0,0 +1,839 @@ +use std::{ + hint::spin_loop, + sync::{ + Arc, atomic::{AtomicI32, AtomicUsize, Ordering} + }, time::Duration, cell::UnsafeCell, +}; + +use bytes::{buf::UninitSlice, Buf, BufMut}; + +#[derive(Debug, thiserror::Error)] +#[error("allocation error")] +pub struct AllocError; + +#[repr(C)] +pub struct SharedMem { + pub ptr: *mut u8, + len: usize, + closed: bool +} + +pub fn try_alloc_shared_mem( + num_items: usize, + capacity: usize, + huge: bool, +) -> Result<*mut u8, AllocError> { + // assert!(align.is_power_of_two(), "alignment must be a power of two"); + assert!( + capacity.is_power_of_two(), + "capacity must be a power of two" + ); + let total_len = capacity * num_items; + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + total_len, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED | libc::MAP_ANONYMOUS | if huge { libc::MAP_HUGETLB } else { 0 }, + -1, + 0, + ) + }; + + if std::ptr::eq(ptr, libc::MAP_FAILED) { + return Err(AllocError); + } + + // zero initialize the memory + unsafe { + std::ptr::write_bytes(ptr as *mut u8, 0, total_len); + } + + Ok(ptr as *mut u8) +} + +impl SharedMem { + pub fn new(element_size: usize, capacity: usize, huge: bool) -> Result { + let ptr = try_alloc_shared_mem(element_size, capacity, huge)?; + let len = capacity * element_size; + + Ok(Self { ptr, len, closed: false }) + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn dealloc(mut self) { + if self.closed { + return; + } + self.closed = true; + unsafe { + libc::munmap(self.ptr as *mut libc::c_void, self.len); + } + } +} + +impl Drop for SharedMem { + fn drop(&mut self) { + if self.closed { + return; + } + self.closed = true; + unsafe { + libc::munmap(self.ptr as *mut libc::c_void, self.len); + } + } +} + +#[derive(Debug)] +#[repr(C, align(16))] +pub struct FrameDesc { + pub ptr: *mut u8, + pub frame_size: usize, + pub shmem_idx: usize, +} + +unsafe impl Send for FrameDesc {} + +#[derive(Debug)] +#[repr(C, align(32))] +pub struct FrameBufMut { + ptr: *mut u8, + desc: FrameDesc, +} + +unsafe impl Send for FrameBufMut {} + +#[derive(Debug)] +#[repr(C, align(32))] +pub struct FrameBuf { + curr_ptr: *mut u8, + len: usize, + desc: FrameDesc, +} + +impl FrameBuf { + #[inline] + pub fn len(&self) -> usize { + let end = unsafe { self.desc.ptr.add(self.len) }; + (end as usize) - (self.curr_ptr as usize) + } + + #[inline] + pub fn into_inner(self) -> FrameDesc { + self.desc + } + + #[inline] + pub unsafe fn detach_desc(&self) -> FrameDesc { + FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + shmem_idx: self.desc.shmem_idx, + } + } + + pub unsafe fn unsafe_clone(&self) -> Self { + Self { + curr_ptr: self.curr_ptr, + len: self.len, + desc: FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + shmem_idx: self.desc.shmem_idx, + }, + } + } + + pub unsafe fn unsafe_subslice_clone(&self, offset: usize, len: usize) -> Self { + assert!(offset + len <= self.len()); + Self { + curr_ptr: self.curr_ptr.add(offset), + len, + desc: FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + shmem_idx: self.desc.shmem_idx, + }, + } + } +} + +impl AsRef<[u8]> for FrameBuf { + fn as_ref(&self) -> &[u8] { + self.chunk() + } +} + +unsafe impl Send for FrameBuf {} + +impl From for FrameBuf { + fn from(buf_mut: FrameBufMut) -> Self { + let len = (buf_mut.ptr as usize) + .checked_sub(buf_mut.desc.ptr as usize) + .expect("FrameBufMut pointer underflow"); + assert!( + len <= buf_mut.desc.frame_size, + "FrameBufMut pointer out of bounds" + ); + Self { + curr_ptr: buf_mut.desc.ptr, + len, + desc: buf_mut.desc, + } + } +} + +impl FrameDesc { + pub fn as_mut_buf(&self) -> FrameBufMut { + FrameBufMut { + ptr: self.ptr, + desc: FrameDesc { + ptr: self.ptr, + frame_size: self.frame_size, + shmem_idx: self.shmem_idx, + }, + } + } +} + +impl From for FrameBufMut { + fn from(desc: FrameDesc) -> Self { + Self { ptr: desc.ptr, desc } + } +} + +impl FrameBufMut { + #[inline] + pub fn base(&self) -> *mut u8 { + self.desc.ptr + } + + #[inline] + pub fn capacity(&self) -> usize { + self.desc.frame_size + } + + + #[inline] + fn end_ptr(&self) -> *const u8 { + unsafe { self.desc.ptr.add(self.capacity()) } + } + + #[inline] + fn frame_offset(&self) -> usize { + let frame_offset = (self.ptr as usize) + .checked_sub(self.desc.ptr as usize) + .expect("FrameBufMut pointer underflow"); + assert!( + frame_offset <= self.desc.frame_size, + "FrameBufMut pointer out of bounds" + ); + frame_offset + } + + #[inline] + pub unsafe fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } + + #[inline] + pub unsafe fn seek(&mut self, offset: usize) { + assert!(offset <= self.desc.frame_size, "seek offset out of bounds"); + let new_ptr = self.desc.ptr.add(offset); + let end_ptr = self.end_ptr(); + assert!(new_ptr as *const u8 <= end_ptr, "seek out of bounds"); + self.ptr = new_ptr; + } +} + +unsafe impl BufMut for FrameBufMut { + fn remaining_mut(&self) -> usize { + self.desc.frame_size - self.frame_offset() + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + let new_ptr = self.ptr.add(cnt); + assert!( + new_ptr as *const u8 <= self.end_ptr(), + "advance_mut out of bounds" + ); + self.ptr = new_ptr; + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + unsafe { UninitSlice::from_raw_parts_mut(self.ptr, self.remaining_mut()) } + } +} + +impl Buf for FrameBuf { + fn remaining(&self) -> usize { + let end = unsafe { self.desc.ptr.add(self.len) }; + (end as usize) - (self.curr_ptr as usize) + } + + fn chunk(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.curr_ptr, self.remaining()) } + } + + fn advance(&mut self, cnt: usize) { + let new_ptr = unsafe { self.curr_ptr.add(cnt) }; + let end = unsafe { self.desc.ptr.add(self.len) }; + assert!(new_ptr as *const u8 <= end, "advance out of bounds"); + self.curr_ptr = new_ptr; + } +} + +use std::{ptr, sync::atomic::AtomicBool}; + +// We wrap T to include a 'ready' flag for each slot +#[repr(C)] +struct Slot { + data: UnsafeCell>, + is_ready: AtomicBool, +} + +struct RingInner { + buf: *mut Slot, // Changed to Slot + capacity: usize, + mask: usize, + head: AtomicUsize, // Producer index (reserved) + tail: AtomicUsize, // Consumer index + futex_flag: AtomicI32, + shmem: Option, +} + +impl Drop for RingInner { + fn drop(&mut self) { + if let Some(shmem) = self.shmem.take() { + let mut tail = self.tail.load(Ordering::Acquire); + let head = self.head.load(Ordering::Acquire); + + // Drop initialized slots + while tail != head { + unsafe { + let slot = &mut *self.buf.add(tail & self.mask); + if slot.is_ready.load(Ordering::Acquire) { + ptr::drop_in_place((*slot.data.get()).as_mut_ptr()); + } + } + tail = tail.wrapping_add(1); + } + + drop(shmem); + } + } +} + +unsafe impl Send for RingInner {} +unsafe impl Sync for RingInner {} + +pub struct Tx { + inner: Arc>, +} + +impl Clone for Tx { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +pub struct Rx { + inner: Arc>, +} + +pub fn message_ring(capacity: usize) -> Result<(Tx, Rx), AllocError> { + let capacity = capacity.next_power_of_two(); + let size = std::mem::size_of::>(); + + // Allocate memory for Slots + let shmem = SharedMem::new(size, capacity, false)?; + let ptr = shmem.ptr as *mut Slot; + // Initialize the is_ready flags to false + for i in 0..capacity { + unsafe { + let slot_ptr = ptr.add(i); + ptr::write(&mut (*slot_ptr).is_ready, AtomicBool::new(false)); + } + } + + let inner = Arc::new(RingInner { + buf: ptr, + capacity, + mask: capacity - 1, + head: AtomicUsize::new(0), + tail: AtomicUsize::new(0), + futex_flag: AtomicI32::new(0), + shmem: Some(shmem), + }); + + Ok(( + Tx { + inner: Arc::clone(&inner), + }, + Rx { inner }, + )) +} + +impl Tx { + pub fn send(&self, value: T) -> Result<(), T> { + loop { + // 1. Load head and tail to check if the ring is full. + // head: Relaxed is okay here because it's only a hint for the CAS. + // tail: Acquire is REQUIRED to ensure we don't overwrite data + // the consumer hasn't finished reading yet. + let head = self.inner.head.load(Ordering::Relaxed); + let tail = self.inner.tail.load(Ordering::Acquire); + + // 2. Calculate occupancy with wrapping arithmetic. + let used = head.wrapping_sub(tail); + if used == self.inner.capacity { + return Err(value); + } + if used > self.inner.capacity { + // Inconsistent snapshot under concurrent updates; retry. + std::hint::spin_loop(); + continue; + } + + // 2. Claim a slot using Compare-and-Swap (CAS). + // We use SeqCst or AcqRel here to ensure that once we "win" this slot, + // we have a synchronized view of the memory. + if self + .inner + .head + .compare_exchange_weak( + head, + head.wrapping_add(1), + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_ok() + { + unsafe { + // 3. Calculate slot location. + let slot = &*self.inner.buf.add(head & self.inner.mask); + + // 4. Write the data into the MaybeUninit. + // We use .write() which is a wrapper for ptr::write. + ptr::write((*slot.data.get()).as_mut_ptr(), value); + + // 5. RELEASE the data to the consumer. + // This store ensures the data write above is visible to + // any thread that performs an Acquire load on is_ready. + slot.is_ready.store(true, Ordering::Release); + } + + // 6. Futex Wake Logic. + // If the consumer is sleeping (futex_flag == 0), we wake them. + // We use Release to ensure the flag update is visible. + if self.inner.futex_flag.swap(1, Ordering::Release) == 0 { + unsafe { + libc::syscall( + libc::SYS_futex, + &self.inner.futex_flag as *const AtomicI32, + libc::FUTEX_WAKE, + 1, // Wake 1 thread + ); + } + } + return Ok(()); + } + // If CAS failed, another producer grabbed 'head'. + // The loop will retry with the new head value. + std::hint::spin_loop(); + } + } +} + +impl Rx { + pub fn recv(&mut self) -> T { + self.recv_timeout_inner(None).expect("recv failed") + } + + pub fn recv_timeout(&mut self, duration: Duration) -> Option { + self.recv_timeout_inner(Some(duration)) + } + + fn recv_timeout_inner(&mut self, duration: Option) -> Option { + for _ in 0..999 { + if let Some(val) = self.try_recv() { + return Some(val); + } + spin_loop(); + } + + loop { + if let Some(val) = self.try_recv() { + return Some(val); + } + + self.inner.futex_flag.store(0, Ordering::SeqCst); + + if let Some(val) = self.try_recv() { + return Some(val); + } + + let timespec: Option = duration.map(|d| libc::timespec { + tv_sec: d.as_secs() as libc::time_t, + tv_nsec: d.subsec_nanos() as libc::c_long, + }); + + let timeout_ptr = match ×pec { + Some(ts) => ts as *const libc::timespec, + None => std::ptr::null(), + }; + + unsafe { + libc::syscall( + libc::SYS_futex, + &self.inner.futex_flag as *const AtomicI32, + libc::FUTEX_WAIT, + 0, + timeout_ptr, + ); + } + + if duration.is_some(){ + return self.try_recv(); + } + } + } + + pub fn try_recv(&mut self) -> Option { + let tail = self.inner.tail.load(Ordering::Relaxed); + + unsafe { + let slot = &*self.inner.buf.add(tail & self.inner.mask); + + // IMPORTANT: In MPSC, even if head > tail, the data at tail might + // not be written yet because the producer was interrupted. + if !slot.is_ready.load(Ordering::Acquire) { + return None; + } + + let val = ptr::read((*slot.data.get()).as_ptr()); + + // Reset the flag for the next time this slot is used + slot.is_ready.store(false, Ordering::Release); + + // Increment tail to free the slot + self.inner.tail.store(tail.wrapping_add(1), Ordering::Release); + Some(val) + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::HashSet, + sync::{Arc as StdArc, Barrier, atomic::AtomicUsize as StdAtomicUsize}, + thread, + }; + + use super::*; + + #[test] + fn test_mpsc_contention() { + let capacity = 1024; + let (tx, mut rx) = message_ring::(capacity).unwrap(); + + let num_producers = 4; + let msgs_per_producer = 1000; + let barrier = Arc::new(Barrier::new(num_producers + 1)); + let mut handles = Vec::new(); + + // Start Producers + for p in 0..num_producers { + let tx_clone = tx.clone(); + let b_clone = Arc::clone(&barrier); + handles.push(thread::spawn(move || { + b_clone.wait(); // Synchronize start + for i in 0..msgs_per_producer { + let val = p * 10000 + i; + while tx_clone.send(val).is_err() { + spin_loop(); // Wait if ring is full + } + } + })); + } + + barrier.wait(); // Start everyone at once + + let mut received = HashSet::new(); + let total_expected = num_producers * msgs_per_producer; + + for _ in 0..total_expected { + received.insert(rx.recv()); + } + + assert_eq!(received.len(), total_expected); + for h in handles { + h.join().unwrap(); + } + } + + #[test] + fn test_frame_buffer_lifecycle() { + let frame_size = 4096; + let capacity = 1; + // 1. Setup the memory pool + let mem = SharedMem::new(frame_size, capacity, false).unwrap(); + + // At this point, the fill_ring inside PagedAlignedMem logic + // should have been populated. Let's create our own handles for testing. + let (tx_fill, mut rx_fill) = message_ring::(capacity).unwrap(); + let (rx_tx, mut rx_rx) = message_ring::(capacity).unwrap(); + // Manually push frames into our test fill ring + for i in 0..capacity { + tx_fill + .send(FrameDesc { + ptr: unsafe { mem.ptr.add(i * frame_size) }, + frame_size, + shmem_idx: 0, + }) + .unwrap(); + } + + // 2. Simulate taking a frame from the pool + let desc = rx_fill.recv(); + let expected_ptr = desc.ptr; + println!("Received frame at ptr: {:p}", expected_ptr); + let mut buf = desc.as_mut_buf(); + assert_eq!(buf.remaining_mut(), 4096); + buf.put_u32(0xDEADBEEF); + assert_eq!(buf.remaining_mut(), 4092); + + rx_tx.send(desc).unwrap(); + + // 3. Verify the frame returned to the fill ring + let returned_desc = rx_rx.recv(); + assert_eq!(returned_desc.ptr, expected_ptr); + // 4. Verify the frame is zeroed out + } + + #[test] + fn test_blocking_recv() { + let (tx, mut rx) = message_ring::(16).unwrap(); + + let handle = thread::spawn(move || { + thread::sleep(std::time::Duration::from_millis(200)); + tx.send(42).unwrap(); + }); + + let start = std::time::Instant::now(); + let val = rx.recv(); // Should block for ~200ms + + assert_eq!(val, 42); + assert!(start.elapsed().as_millis() >= 200); + handle.join().unwrap(); + } + + #[test] + fn test_buf_and_bufmut_impls() { + let frame_size = 4096; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + shmem_idx: 0, + }; + + let mut buf_mut: FrameBufMut = desc.into(); + assert_eq!(buf_mut.remaining_mut(), 4096); + buf_mut.put_slice(&[1, 2, 3, 4]); + assert_eq!(buf_mut.remaining_mut(), 4092); + assert_eq!(buf_mut.chunk_mut().len(), 4092); + + let mut buf: FrameBuf = buf_mut.into(); + assert_eq!(buf.len(), 4); + assert_eq!(buf.remaining(), 4); + let chunk = buf.chunk(); + assert_eq!(chunk, &[1, 2, 3, 4]); + buf.advance(4); + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.len(), 0) + } + + #[test] + fn test_bufmut_non_power_of_two_frame_size() { + let frame_size = 1500; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + shmem_idx: 0, + }; + + let mut buf_mut: FrameBufMut = desc.into(); + assert_eq!(buf_mut.base(), shmem.ptr); + assert_eq!(buf_mut.remaining_mut(), frame_size); + + buf_mut.put_slice(&[1, 2, 3, 4, 5]); + assert_eq!(buf_mut.remaining_mut(), frame_size - 5); + + unsafe { buf_mut.seek(frame_size) }; + assert_eq!(buf_mut.remaining_mut(), 0); + } + + #[test] + fn test_send_returns_err_when_ring_is_full() { + let (tx, _rx) = message_ring::(1).unwrap(); + assert!(tx.send(1).is_ok()); + assert_eq!(tx.send(2), Err(2)); + } + + #[test] + fn test_recv_timeout_returns_none() { + let (_tx, mut rx) = message_ring::(8).unwrap(); + let start = std::time::Instant::now(); + let out = rx.recv_timeout(Duration::from_millis(30)); + assert_eq!(out, None); + assert!(start.elapsed() >= Duration::from_millis(25)); + } + + #[test] + fn test_recv_timeout_returns_value_before_deadline() { + let (tx, mut rx) = message_ring::(8).unwrap(); + let handle = thread::spawn(move || { + thread::sleep(Duration::from_millis(10)); + tx.send(99).unwrap(); + }); + + let out = rx.recv_timeout(Duration::from_millis(200)); + assert_eq!(out, Some(99)); + handle.join().unwrap(); + } + + #[test] + fn test_message_ring_zero_capacity_behaves_as_one() { + let (tx, mut rx) = message_ring::(0).unwrap(); + assert!(tx.send(1).is_ok()); + assert_eq!(tx.send(2), Err(2)); + assert_eq!(rx.recv(), 1); + } + + #[test] + #[should_panic(expected = "seek offset out of bounds")] + fn test_seek_panics_when_offset_exceeds_frame_size() { + let frame_size = 1500; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + shmem_idx: 0, + }; + let mut buf_mut: FrameBufMut = desc.into(); + unsafe { buf_mut.seek(frame_size + 1) }; + } + + #[test] + #[should_panic(expected = "advance_mut out of bounds")] + fn test_advance_mut_panics_when_exceeding_capacity() { + let frame_size = 256; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + shmem_idx: 0, + }; + let mut buf_mut: FrameBufMut = desc.into(); + unsafe { buf_mut.advance_mut(frame_size + 1) }; + } + + #[test] + #[should_panic(expected = "advance out of bounds")] + fn test_buf_advance_panics_when_exceeding_len() { + let frame_size = 256; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + shmem_idx: 0, + }; + + let mut buf_mut: FrameBufMut = desc.into(); + buf_mut.put_slice(&[1, 2, 3, 4]); + let mut buf: FrameBuf = buf_mut.into(); + buf.advance(5); + } + + #[test] + fn test_send_and_recv_work_across_counter_rollover() { + let (tx, mut rx) = message_ring::(8).unwrap(); + tx.inner.head.store(usize::MAX, Ordering::Relaxed); + tx.inner.tail.store(usize::MAX, Ordering::Relaxed); + + assert!(tx.send(1).is_ok()); + assert_eq!(rx.recv(), 1); + assert_eq!(tx.inner.head.load(Ordering::Relaxed), 0); + assert_eq!(tx.inner.tail.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_send_and_recv_with_backlog_across_counter_rollover() { + let (tx, mut rx) = message_ring::(2).unwrap(); + tx.inner.head.store(usize::MAX, Ordering::Relaxed); + tx.inner.tail.store(usize::MAX, Ordering::Relaxed); + + assert!(tx.send(10).is_ok()); // head wraps to 0 + assert!(tx.send(11).is_ok()); // queue full now + assert_eq!(tx.send(12), Err(12)); + assert_eq!(tx.inner.head.load(Ordering::Relaxed), 1); + + assert_eq!(rx.recv(), 10); + assert_eq!(rx.recv(), 11); + assert_eq!(tx.inner.tail.load(Ordering::Relaxed), 1); + + assert!(tx.send(13).is_ok()); + assert_eq!(rx.recv(), 13); + } + + #[test] + fn test_send_retries_when_used_exceeds_capacity_snapshot() { + let (tx, mut rx) = message_ring::(8).unwrap(); + tx.inner.head.store(5, Ordering::Relaxed); + tx.inner.tail.store(10, Ordering::Relaxed); // used = 5.wrapping_sub(10) > capacity + + let tx_sender = tx.clone(); + let (done_tx, done_rx) = std::sync::mpsc::channel(); + let sender = thread::spawn(move || { + let out = tx_sender.send(42); + done_tx.send(out).unwrap(); + }); + + thread::sleep(Duration::from_millis(10)); + tx.inner.tail.store(5, Ordering::Release); // restore consistent snapshot + + let send_result = done_rx + .recv_timeout(Duration::from_millis(200)) + .expect("send should complete after snapshot becomes consistent"); + assert!(send_result.is_ok()); + assert_eq!(rx.recv(), 42); + sender.join().unwrap(); + } + + #[test] + fn test_drop_drops_pending_items_once() { + struct DropCounter(StdArc); + impl Drop for DropCounter { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::Relaxed); + } + } + + let dropped = StdArc::new(StdAtomicUsize::new(0)); + { + let (tx, _rx) = message_ring::(8).unwrap(); + for _ in 0..3 { + assert!(tx.send(DropCounter(StdArc::clone(&dropped))).is_ok()); + } + } + assert_eq!(dropped.load(Ordering::Relaxed), 3); + } +} diff --git a/proxy/src/prom.rs b/proxy/src/prom.rs index 32f88e06..c53f3b7d 100644 --- a/proxy/src/prom.rs +++ b/proxy/src/prom.rs @@ -26,7 +26,7 @@ lazy_static::lazy_static! { static ref RECV_PACKET_COUNT_HIST: Histogram = Histogram::with_opts( HistogramOpts::new("shredstream_recv_packet_count", "Number of packets in incoming batch (before dedup)") - .buckets(vec![1.0, 2.0, 3.0, 5.0, 10.0, 20.0, 30.0, 40.0, 50.0, 64.0, 100.0]) + .buckets(vec![1.0, 5.0, 10.0, 20.0, 50.0, 64.0]) ).unwrap(); static ref PACKETS_RECEIVED_TOTAL: Counter = Counter::new( @@ -49,6 +49,26 @@ lazy_static::lazy_static! { Opts::new("shredstream_packets_by_source", "Packets per source IP"), &["addr", "status"] ).unwrap(); + + static ref ROUTING_DROP: Counter = Counter::new( + "shredstream_routing_drop_total", "Packets dropped due to routing issues" + ).unwrap(); + + static ref ROUTING_SEND: IntCounterVec = IntCounterVec::new( + Opts::new( + + "shredstream_routing_send_total", "Packets successfully routed to send queue" + ), + &["queue"] + ).unwrap(); +} + +pub fn inc_routing_drop() { + ROUTING_DROP.inc(); +} + +pub fn inc_routing_send>(queue_label: S) { + ROUTING_SEND.with_label_values(&[queue_label.as_ref()]).inc(); } pub fn observe_dedup_time(microseconds: f64) { @@ -102,6 +122,8 @@ pub fn register_metrics(registry: &prometheus::Registry) { registry.register(Box::new(PACKETS_FORWARDED_TOTAL.clone())).unwrap(); registry.register(Box::new(PACKETS_FORWARD_FAILED_TOTAL.clone())).unwrap(); registry.register(Box::new(PACKETS_BY_SOURCE.clone())).unwrap(); + registry.register(Box::new(ROUTING_DROP.clone())).unwrap(); + registry.register(Box::new(ROUTING_SEND.clone())).unwrap(); } diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs new file mode 100644 index 00000000..7a81e8eb --- /dev/null +++ b/proxy/src/recv_mmsg.rs @@ -0,0 +1,686 @@ +use std::{ + cmp, io, mem::{self, MaybeUninit, zeroed}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, num, os::fd::AsRawFd, sync::{atomic::{AtomicBool, Ordering}, Arc, Mutex}, time::{Duration, Instant} +}; + +use bytes::{Buf, BufMut}; +use itertools::izip; +use libc::{AF_INET, AF_INET6, MSG_DONTWAIT, iovec, mmsghdr, msghdr, sockaddr_storage}; +use log::error; +use mio::{Poll, Token, Waker}; +use socket2::socklen_t; +use solana_perf::packet::PACKETS_PER_BATCH; +use solana_sdk::packet::{Meta, PACKET_DATA_SIZE}; +use solana_streamer::{streamer::StreamerReceiveStats}; + +use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_received, inc_routing_drop, inc_routing_send, observe_recv_packet_count}}; + +pub trait PacketRoutingStrategy: Clone { + fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; +} + +#[inline] +fn hash_pair(x: u64, y: u32) -> u64 { + let mut h = x ^ ((y as u64) << 32); + h ^= h >> 33; + h = h.wrapping_mul(0xff51afd7ed558ccd); + h ^= h >> 33; + h +} + +#[derive(Debug, Clone)] +pub struct FECSetRoutingStrategy; + +impl PacketRoutingStrategy for FECSetRoutingStrategy { + fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option { + if num_dest == 0 { + return None; + } + let shred_buf = packet.buffer.chunk(); + let slot = solana_ledger::shred::wire::get_slot(shred_buf)?; + let fec = shred_buf.get(79..79 + 4)?; + let fec_bytes: [u8; 4] = fec.try_into().ok()?; + let fec_set_index = u32::from_le_bytes(fec_bytes); + let hash = hash_pair(slot, fec_set_index); + let dest = (hash as usize) % num_dest; + Some(dest) + } +} + +pub fn recv_loop( + sk_vec: Vec, + exit: &AtomicBool, + stats: &StreamerReceiveStats, + fill_rx: &mut Rx, + packet_tx_vec: &[Tx], + wake_slot: Arc>>>, + router: R, +) -> std::io::Result<()> +where + R: PacketRoutingStrategy, +{ + assert!(packet_tx_vec.len() > 0, "packet_tx_vec must have at least one destination"); + let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); + let mut frame_bufmut_vec = Vec::with_capacity(PACKETS_PER_BATCH); + let mut next_stats_report = Instant::now() + Duration::from_secs(1); + let mut router_dest_dist = vec![0usize; packet_tx_vec.len()]; + let mut router_dest_label_vec = Vec::with_capacity(packet_tx_vec.len()); + for idx in 0..packet_tx_vec.len() { + router_dest_label_vec.push(idx.to_string()); + } + let mut poll = Poll::new()?; + const WAKE_TOKEN: Token = Token(usize::MAX); + let mut events = mio::Events::with_capacity(sk_vec.len() + 1); + let wake_handle = Arc::new(Waker::new(poll.registry(), WAKE_TOKEN)?); + *wake_slot.lock().expect("recv wake slot lock poisoned") = Some(Arc::clone(&wake_handle)); + let mut empty_fill_backoff = 0u32; + + let mut mio_sockets: Vec = sk_vec + .iter() + .map(|sk| mio::net::UdpSocket::from_std(sk.try_clone().unwrap())) + .collect(); + // Initial registration of sockets + for (i, socket) in mio_sockets.iter_mut().enumerate() { + poll.registry().register( + socket, + mio::Token(i), + mio::Interest::READABLE, + )?; + } + + while !exit.load(Ordering::Relaxed) { + + // Events are always cleared before receiving new ones + let result = poll.poll(&mut events, None); + + match result { + Ok(_) => { } + Err(e) => { + if e.kind() != io::ErrorKind::TimedOut { + return Err(e); + } + } + } + + if next_stats_report.elapsed() > Duration::ZERO { + next_stats_report = Instant::now() + Duration::from_secs(1); + log::trace!( + "recv_loop: packets_count={}, packet_batches_count={}, full_packet_batches_count={}", + stats.packets_count.load(Ordering::Relaxed), + stats.packet_batches_count.load(Ordering::Relaxed), + stats.full_packet_batches_count.load(Ordering::Relaxed), + ); + } + // Check for exit signal, even if socket is busy + // (for instance the leader transaction socket) + if exit.load(Ordering::Relaxed) { + return Ok(()); + } + // We can't use a for-loop here because we need to be able to drain the readiness of each socket. + // Since each recv_from is bounded by a PACKETS_PER_BATCH, we may need to call recv_from multiple times per socket + // until we get a WouldBlock error. + let mut ev_iter = events.iter(); + let Some(mut ev) = ev_iter.next() else { + continue; + }; + 'drain_readiness_loop: while !exit.load(Ordering::Relaxed) { + if ev.token() == WAKE_TOKEN { + if exit.load(Ordering::Relaxed) { + return Ok(()); + } + match ev_iter.next() { + Some(next_ev) => { + ev = next_ev; + continue 'drain_readiness_loop; + } + None => break 'drain_readiness_loop, + } + } + + let sk_idx = ev.token().0; + let recv_sk = &sk_vec[sk_idx]; + + // Refill the frame buffers as much as we can, + 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { + let maybe_frame_buf = fill_rx.try_recv(); + match maybe_frame_buf { + Some(frame_desc) => { + let frame_bufmut = frame_desc.as_mut_buf(); + frame_bufmut_vec.push(frame_bufmut); + empty_fill_backoff = 0; + } + None => { + if frame_bufmut_vec.is_empty() { + if exit.load(Ordering::Relaxed) { + break 'fill_bufmut; + } + empty_fill_backoff = empty_fill_backoff.saturating_add(1); + if empty_fill_backoff <= 128 { + std::hint::spin_loop(); + } else { + std::thread::yield_now(); + } + break 'fill_bufmut; + } else { + break 'fill_bufmut; + } + } + } + } + + if frame_bufmut_vec.is_empty() { + // No available frame buffers to receive into, wait a bit + log::debug!("recv_loop: no available frame buffers to receive into"); + continue 'drain_readiness_loop; + } + + let result = recv_from(&mut frame_bufmut_vec, recv_sk, &mut packet_batch, &exit); + + match result { + Ok(len) => { + if len > 0 { + // observe_recv_interval(recv_interval.as_micros() as f64); + inc_packets_received(len as u64); + observe_recv_packet_count(len as f64); + let StreamerReceiveStats { + packets_count, + packet_batches_count, + full_packet_batches_count, + .. + } = stats; + + packets_count.fetch_add(len, Ordering::Relaxed); + packet_batches_count.fetch_add(1, Ordering::Relaxed); + if len == PACKETS_PER_BATCH { + full_packet_batches_count.fetch_add(1, Ordering::Relaxed); + } + packet_batch + .iter_mut() + .for_each(|p| p.meta_mut().set_from_staked_node(false)); + + 'packet_drain: for packet in packet_batch.drain(..) { + let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { + Some(idx) => idx, + None => { + log::trace!("Failed to route packet {:?}", packet); + let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); + frame_bufmut_vec.push(trashed_frame_bufmut); + inc_routing_drop(); + continue 'packet_drain; + } + }; + router_dest_dist[dest_idx] += 1; + let _ = &packet_tx_vec[dest_idx] + .send(packet) + .unwrap_or_else(|_packet| panic!("failed to send packet to {dest_idx} ring is full, distr:{:?}", router_dest_dist)); + inc_routing_send(&router_dest_label_vec[dest_idx]); + } + } + } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + // Only when we drained all events for this poll iteration, we process the next event or break + match ev_iter.next() { + Some(next_ev) => { + ev = next_ev; + continue 'drain_readiness_loop; + } + None => { + break 'drain_readiness_loop; + } + } + } else { + return Err(e); + } + } + } + } + } + Ok(()) +} + +pub fn recv_from( + available_frame_buf_vec: &mut Vec, + socket: &UdpSocket, + batch: &mut Vec, + exit: &AtomicBool, +) -> std::io::Result { + // let mut i: usize = 0; + //DOCUMENTED SIDE-EFFECT + //Performance out of the IO without poll + // * block on the socket until it's readable + // * set the socket to non blocking + // * read until it fails + // * set it back to blocking before returning + // socket.set_nonblocking(false)?; + let batch_capacity = batch.capacity(); + assert!(batch_capacity >= PACKETS_PER_BATCH); + + let mut i = 0; + + while !exit.load(Ordering::Relaxed) { + let npkts = match triton_recv_mmsg(socket, available_frame_buf_vec, batch) { + Ok(npkts) => npkts, + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + // Drain complete for now. Preserve packets already received in this call. + return Ok(i); + } + Err(e) => return Err(e), + }; + i += npkts; + if available_frame_buf_vec.is_empty() { + break; + } + if batch.len() >= batch_capacity { + break; + } + // Try to batch into big enough buffers + // will cause less re-shuffling later on. + if i >= PACKETS_PER_BATCH { + break; + } + } + Ok(i) +} + +#[derive(Debug)] +#[repr(C)] +pub struct TritonPacket { + pub buffer: FrameBuf, + pub meta: Meta, +} + +impl TritonPacket { + pub fn new(buffer: FrameBuf) -> Self { + Self { + buffer, + meta: Meta::default(), + } + } + + pub fn meta_mut(&mut self) -> &mut Meta { + &mut self.meta + } +} + +impl AsRef<[u8]> for TritonPacket { + fn as_ref(&self) -> &[u8] { + self.buffer.chunk() + } +} + +pub fn triton_recv_mmsg( + sock: &UdpSocket, + fill_buffers: &mut Vec, + packets: &mut Vec, +) -> io::Result { + const RECV_BURST_TARGET: usize = 32; + // Should never hit this, but bail if the caller didn't provide any Packets + // to receive into + if fill_buffers.is_empty() { + return Ok(0); + } + // Assert that there are no leftovers in packets. + const SOCKADDR_STORAGE_SIZE: socklen_t = mem::size_of::() as socklen_t; + let mut iovs = [MaybeUninit::uninit(); RECV_BURST_TARGET]; + let mut addrs = [MaybeUninit::zeroed(); RECV_BURST_TARGET]; + let mut hdrs = [MaybeUninit::uninit(); RECV_BURST_TARGET]; + let remaining_packets = packets.capacity() - packets.len(); + let sock_fd = sock.as_raw_fd(); + let count = cmp::min(iovs.len(), remaining_packets).min(fill_buffers.len()); + + if count == 0 { + return Ok(0); + } + + let mut frame_buffer_inflight_vec: [MaybeUninit; RECV_BURST_TARGET] = + std::array::from_fn(|_| MaybeUninit::uninit()); + + let mut frame_buffer_inflight_cnt = 0; + for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(count) { + let buffer = fill_buffers.pop().expect("insufficient fill buffers"); + assert!( + buffer.remaining_mut() >= PACKET_DATA_SIZE, + "fill buffer too small" + ); + let iov_base = unsafe { buffer.as_mut_ptr() as *mut libc::c_void }; + + iov.write(iovec { + iov_base: iov_base, + iov_len: PACKET_DATA_SIZE, + }); + + let msg_hdr = create_msghdr(addr, SOCKADDR_STORAGE_SIZE, iov); + + hdr.write(mmsghdr { + msg_len: 0, + msg_hdr, + }); + // Keep track of the in-flight frame buffers to avoid use-after-free + frame_buffer_inflight_vec[frame_buffer_inflight_cnt].write(buffer); + frame_buffer_inflight_cnt += 1; + } + + let mut ts = libc::timespec { + tv_sec: 1, + tv_nsec: 0, + }; + // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl + #[allow(clippy::useless_conversion)] + let nrecv = unsafe { + libc::recvmmsg( + sock_fd, + hdrs[0].assume_init_mut(), + count as u32, + MSG_DONTWAIT.try_into().unwrap(), + &mut ts, + ) + }; + let nrecv = if nrecv < 0 { + // On error, return all in-flight frame buffers back to the caller + for i in 0..frame_buffer_inflight_cnt { + let buffer = unsafe { frame_buffer_inflight_vec[i].assume_init_read() }; + fill_buffers.push(buffer); + } + return Err(io::Error::last_os_error()); + } else { + usize::try_from(nrecv).unwrap() + }; + + + for idx in 0..nrecv { + // SAFETY: `nrecv <= count` and we initialized `count` entries in `hdrs`. + let hdr_ref = unsafe { hdrs[idx].assume_init_ref() }; + // SAFETY: Same argument as above for `addrs`. + let addr_ref = unsafe { addrs[idx].assume_init_ref() }; + let mut filled_bufmut = unsafe { frame_buffer_inflight_vec[idx].assume_init_read() }; + unsafe { filled_bufmut.seek(hdr_ref.msg_len as usize); } + let filled_buf: FrameBuf = filled_bufmut.into(); + let mut pkt = TritonPacket { + buffer: filled_buf, + meta: Meta::default(), + }; + pkt.meta_mut().size = hdr_ref.msg_len as usize; + if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { + pkt.meta_mut().set_socket_addr(&addr); + } + packets.push(pkt); + } + + if nrecv != count { + log::debug!( + "triton_recv_mmsg: recvd {nrecv} packets, expected up to {count}. Remaining fill buffers returned to caller." + ); + } + + // // Return submitted buffers that were not filled by this syscall. + for in_flight in &mut frame_buffer_inflight_vec[nrecv..count] + { + let buffer = unsafe { in_flight.assume_init_read() }; + fill_buffers.push(buffer); + } + + for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) { + // SAFETY: We initialized `count` elements of each array above + // + // It may be that `packets.len() != NUM_RCVMMSGS`; thus, some elements + // in `iovs` / `addrs` / `hdrs` may not get initialized. So, we must + // manually drop `count` elements from each array instead of being able + // to convert [MaybeUninit] to [T] and letting `Drop` do the work + // for us when these items go out of scope at the end of the function + unsafe { + iov.assume_init_drop(); + addr.assume_init_drop(); + hdr.assume_init_drop(); + } + } + + Ok(nrecv) +} + +fn create_msghdr( + msg_name: &mut MaybeUninit, + msg_namelen: socklen_t, + iov: &mut MaybeUninit, +) -> msghdr { + // Cannot construct msghdr directly on musl + // See https://github.com/rust-lang/libc/issues/2344 for more info + let mut msg_hdr: msghdr = unsafe { zeroed() }; + msg_hdr.msg_name = msg_name.as_mut_ptr() as *mut _; + msg_hdr.msg_namelen = msg_namelen; + msg_hdr.msg_iov = iov.as_mut_ptr(); + msg_hdr.msg_iovlen = 1; + msg_hdr.msg_control = std::ptr::null::() as *mut _; + msg_hdr.msg_controllen = 0; + msg_hdr.msg_flags = 0; + msg_hdr +} + +fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option { + use libc::{sa_family_t, sockaddr_in, sockaddr_in6}; + const SOCKADDR_IN_SIZE: usize = std::mem::size_of::(); + const SOCKADDR_IN6_SIZE: usize = std::mem::size_of::(); + if addr.ss_family == AF_INET as sa_family_t + && hdr.msg_hdr.msg_namelen == SOCKADDR_IN_SIZE as socklen_t + { + // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L167-L172 + let addr = unsafe { &*(addr as *const _ as *const sockaddr_in) }; + return Some(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()), + u16::from_be(addr.sin_port), + ))); + } + if addr.ss_family == AF_INET6 as sa_family_t + && hdr.msg_hdr.msg_namelen == SOCKADDR_IN6_SIZE as socklen_t + { + // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L174-L189 + let addr = unsafe { &*(addr as *const _ as *const sockaddr_in6) }; + return Some(SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from(addr.sin6_addr.s6_addr), + u16::from_be(addr.sin6_port), + addr.sin6_flowinfo, + addr.sin6_scope_id, + ))); + } + error!( + "recvmmsg unexpected ss_family:{} msg_namelen:{}", + addr.ss_family, hdr.msg_hdr.msg_namelen + ); + None +} + +#[cfg(test)] +mod tests { + use super::*; + use libc::{sa_family_t, sockaddr_in, sockaddr_in6}; + use std::thread; + + #[test] + fn test_create_msghdr_fields() { + let mut addr = MaybeUninit::::zeroed(); + let mut iov = MaybeUninit::::uninit(); + let namelen = std::mem::size_of::() as socklen_t; + let hdr = create_msghdr(&mut addr, namelen, &mut iov); + + assert_eq!(hdr.msg_name, addr.as_mut_ptr() as *mut _); + assert_eq!(hdr.msg_namelen, namelen); + assert_eq!(hdr.msg_iov, iov.as_mut_ptr()); + assert_eq!(hdr.msg_iovlen, 1); + } + + #[test] + fn test_cast_socket_addr_ipv4() { + let ip = Ipv4Addr::new(1, 2, 3, 4); + let port = 12345u16; + + let mut storage: sockaddr_storage = unsafe { zeroed() }; + let sin = sockaddr_in { + sin_family: AF_INET as sa_family_t, + sin_port: port.to_be(), + sin_addr: libc::in_addr { + s_addr: u32::from_ne_bytes(ip.octets()), + }, + sin_zero: [0; 8], + }; + unsafe { + std::ptr::write(&mut storage as *mut _ as *mut sockaddr_in, sin); + } + + let mut hdr: mmsghdr = unsafe { zeroed() }; + hdr.msg_hdr.msg_namelen = std::mem::size_of::() as socklen_t; + + let out = cast_socket_addr(&storage, &hdr); + assert_eq!(out, Some(SocketAddr::V4(SocketAddrV4::new(ip, port)))); + } + + #[test] + fn test_cast_socket_addr_ipv6() { + let ip = Ipv6Addr::LOCALHOST; + let port = 54321u16; + let flowinfo = 7u32; + let scope_id = 9u32; + + let mut storage: sockaddr_storage = unsafe { zeroed() }; + let sin6 = sockaddr_in6 { + sin6_family: AF_INET6 as sa_family_t, + sin6_port: port.to_be(), + sin6_flowinfo: flowinfo, + sin6_addr: libc::in6_addr { + s6_addr: ip.octets(), + }, + sin6_scope_id: scope_id, + }; + unsafe { + std::ptr::write(&mut storage as *mut _ as *mut sockaddr_in6, sin6); + } + + let mut hdr: mmsghdr = unsafe { zeroed() }; + hdr.msg_hdr.msg_namelen = std::mem::size_of::() as socklen_t; + + let out = cast_socket_addr(&storage, &hdr); + assert_eq!( + out, + Some(SocketAddr::V6(SocketAddrV6::new(ip, port, flowinfo, scope_id))) + ); + } + + #[test] + fn test_cast_socket_addr_invalid_returns_none() { + let storage: sockaddr_storage = unsafe { zeroed() }; + let mut hdr: mmsghdr = unsafe { zeroed() }; + hdr.msg_hdr.msg_namelen = 0; + assert_eq!(cast_socket_addr(&storage, &hdr), None); + } + + #[test] + fn test_hash_pair_is_deterministic_and_mixes_inputs() { + let h1 = hash_pair(123, 456); + let h2 = hash_pair(123, 456); + let h3 = hash_pair(124, 456); + let h4 = hash_pair(123, 457); + assert_eq!(h1, h2); + assert_ne!(h1, h3); + assert_ne!(h1, h4); + } + + #[test] + fn test_triton_recv_mmsg_with_udp_socket() { + let recv_sock = match UdpSocket::bind("127.0.0.1:0") { + Ok(s) => s, + Err(e) if e.kind() == io::ErrorKind::PermissionDenied => return, + Err(e) => panic!("failed to bind recv socket: {e}"), + }; + let send_sock = match UdpSocket::bind("127.0.0.1:0") { + Ok(s) => s, + Err(e) if e.kind() == io::ErrorKind::PermissionDenied => return, + Err(e) => panic!("failed to bind send socket: {e}"), + }; + let recv_addr = recv_sock.local_addr().unwrap(); + let payload = b"hello-recv-mmsg"; + + send_sock.send_to(payload, recv_addr).unwrap(); + + let shmem = crate::mem::SharedMem::new(PACKET_DATA_SIZE, 1, false).unwrap(); + let mut fill_buffers = vec![FrameDesc { + ptr: shmem.ptr, + frame_size: PACKET_DATA_SIZE, + shmem_idx: 0, + } + .as_mut_buf()]; + let mut packets = Vec::::with_capacity(NUM_RCVMMSGS); + + let mut recv_count = 0usize; + for _ in 0..30 { + match triton_recv_mmsg(&recv_sock, &mut fill_buffers, &mut packets) { + Ok(n) => { + recv_count = n; + if n > 0 { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => panic!("triton_recv_mmsg failed: {e}"), + } + thread::sleep(Duration::from_millis(5)); + } + + assert_eq!(recv_count, 1); + assert_eq!(packets.len(), 1); + let pkt = &packets[0]; + assert_eq!(pkt.meta.size, payload.len()); + assert_eq!(&pkt.buffer.chunk()[..payload.len()], payload); + } + + #[test] + fn test_triton_recv_mmsg_returns_unused_buffers() { + let recv_sock = match UdpSocket::bind("127.0.0.1:0") { + Ok(s) => s, + Err(e) if e.kind() == io::ErrorKind::PermissionDenied => panic!("skipping test_triton_recv_mmsg_returns_unused_buffers due to lack of permissions to bind UDP socket"), + Err(e) => panic!("failed to bind recv socket: {e}"), + }; + let send_sock = match UdpSocket::bind("127.0.0.1:0") { + Ok(s) => s, + Err(e) if e.kind() == io::ErrorKind::PermissionDenied => panic!("skipping test_triton_recv_mmsg_returns_unused_buffers due to lack of permissions to bind UDP socket"), + Err(e) => panic!("failed to bind send socket: {e}"), + }; + let recv_addr = recv_sock.local_addr().unwrap(); + let payload = b"one-packet-only"; + send_sock.send_to(payload, recv_addr).unwrap(); + + let shmem = crate::mem::SharedMem::new(PACKET_DATA_SIZE, 2, false).unwrap(); + let mut fill_buffers = vec![ + FrameDesc { + ptr: shmem.ptr, + frame_size: PACKET_DATA_SIZE, + shmem_idx: 0, + } + .as_mut_buf(), + FrameDesc { + ptr: unsafe { shmem.ptr.add(PACKET_DATA_SIZE) }, + frame_size: PACKET_DATA_SIZE, + shmem_idx: 0, + } + .as_mut_buf(), + ]; + let mut packets = Vec::::with_capacity(NUM_RCVMMSGS); + + let mut recv_count = 0usize; + for _ in 0..30 { + match triton_recv_mmsg(&recv_sock, &mut fill_buffers, &mut packets) { + Ok(n) => { + recv_count = n; + if n > 0 { + break; + } + } + Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} + Err(e) => panic!("triton_recv_mmsg failed: {e}"), + } + thread::sleep(Duration::from_millis(5)); + } + + assert_eq!(recv_count, 1); + assert_eq!(packets.len(), 1); + // One buffer used by the received packet, one should be returned. + assert_eq!(fill_buffers.len(), 1); + } +} diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs new file mode 100644 index 00000000..ae90e476 --- /dev/null +++ b/proxy/src/triton_forwarder.rs @@ -0,0 +1,750 @@ +use std::{ + collections::VecDeque, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, os::fd::AsRawFd, str::FromStr, sync::{ + Arc, Mutex, atomic::{AtomicBool, Ordering} + }, thread::JoinHandle, time::{Duration, Instant} +}; + +use arc_swap::ArcSwap; +use bytes::Buf; +use crossbeam_channel::{Receiver, Sender}; +use itertools::izip; +use libc; +use log::{debug, error, info, warn}; +use mio::Waker; +use solana_net_utils::SocketConfig; +use solana_perf::deduper::Deduper; +use solana_streamer::{ + sendmmsg::{batch_send, SendPktsError}, + streamer::{StreamerReceiveStats}, +}; + +use crate::{ + forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, prom::{ + inc_packets_deduped, inc_packets_forward_failed, observe_dedup_time, observe_recv_interval, observe_send_duration, observe_send_packet_count + }, recv_mmsg::{PacketRoutingStrategy, TritonPacket}, multicast_config::TritonMulticastConfig +}; + +// values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 +pub const DEDUPER_FALSE_POSITIVE_RATE: f64 = 0.001; +pub const DEDUPER_NUM_BITS: u64 = 637_534_199; // 76MB +pub const DEDUPER_RESET_CYCLE: Duration = Duration::from_secs(5 * 60); +pub const IP_MULTICAST_TTL: u32 = 8; + +#[derive(Debug, Clone, Copy, Default)] +pub enum PktRecvMemSizing { + #[default] + XSmall = 134217728, // 128MiB + Small = 268435456, // 256MiB + Medium = 536870912, // 512MiB + Large = 1073741824, // 1GiB + XLarge = 2147483648, // 2GiB + XXLarge = 4294967296, // 4GiB + XXXLarge = 8589934592, // 8GiB + XXXXLarge = 17179869184, // 16GiB + XXXXXLarge = 34359738368, // 32GiB +} + +#[derive(Debug, thiserror::Error)] +#[error("Invalid ReceiverMemoryCapacity: {0}")] +pub struct ReceiverMemoryCapacityFromStrErr(String); + +impl FromStr for PktRecvMemSizing { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "xsmall" | "xs" => Ok(PktRecvMemSizing::XSmall), + "small" | "s" => Ok(PktRecvMemSizing::Small), + "medium" | "m" => Ok(PktRecvMemSizing::Medium), + "large" | "l" => Ok(PktRecvMemSizing::Large), + "xlarge" | "xl" => Ok(PktRecvMemSizing::XLarge), + "xxlarge" | "xxl" | "2xl" => Ok(PktRecvMemSizing::XXLarge), + "xxxlarge" | "xxxl" | "3xl" => Ok(PktRecvMemSizing::XXXLarge), + "xxxxlarge" | "xxxxl" | "4xl" => Ok(PktRecvMemSizing::XXXXLarge), + "xxxxxlarge" | "xxxxxl" | "5xl" => Ok(PktRecvMemSizing::XXXXXLarge), + _ => Err(s.to_string()), + } + } +} + +#[derive(Clone, Debug)] +pub struct PktRecvTileMemConfig { + pub frame_size: usize, + pub memory_size: PktRecvMemSizing, + pub hugepage: bool, +} + +impl Default for PktRecvTileMemConfig { + fn default() -> Self { + Self { + frame_size: 2048, + memory_size: PktRecvMemSizing::default(), + hugepage: false, + } + } +} + +fn packet_recv_tile( + pkt_recv_idx: usize, + pkt_recv_socket_vec: Vec, + exit: Arc, + forwarder_stats: Arc, + mut fill_rx: Rx, + packet_tx_vec: Vec>, + wake_slot: Arc>>>, + packet_router: R, + tile_drop_sig: TileClosedSignal, +) -> std::io::Result> +where + R: PacketRoutingStrategy + Send + 'static, +{ + std::thread::Builder::new() + .name(format!("ssListen{pkt_recv_idx}")) + .spawn(move || { + crate::recv_mmsg::recv_loop( + pkt_recv_socket_vec, + &exit, + &forwarder_stats, + &mut fill_rx, + &packet_tx_vec, + wake_slot, + packet_router, + ) + .expect("recv_loop"); + drop(tile_drop_sig); + }) +} + +#[derive(Clone, Debug)] +#[repr(C)] +pub struct SharedMemInfo { + pub start_ptr: *const u8, + pub len: usize, // always a power of 2 +} + +unsafe impl Send for SharedMemInfo {} +unsafe impl Sync for SharedMemInfo {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TileKind { + PktRecv, + PktFwd, +} + +struct TileClosedSignal { + kind: TileKind, + idx: usize, + tx: Option>, +} + +struct TileWaitGroup { + rx: Receiver<(TileKind, usize)>, + tx: Sender<(TileKind, usize)>, +} + +impl TileWaitGroup { + fn new() -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self { rx, tx } + } + + fn get_tile_closed_signal(&self, kind: TileKind, idx: usize) -> TileClosedSignal { + TileClosedSignal { + kind, + idx, + tx: Some(self.tx.clone()), + } + } + + fn wait_first(self) -> (TileKind, usize) { + drop(self.tx); + self.rx.recv().expect("TileWaitGroup::wait_first") + } +} + +impl Drop for TileClosedSignal { + fn drop(&mut self) { + if let Some(tx) = &self.tx { + let _ = tx.send((self.kind, self.idx)); + } + } +} + +fn packet_fwd_tile( + packet_fwd_idx: usize, + hot_dest_vec: Arc>>, + send_socket: UdpSocket, + mut packet_rx: Rx, + fill_tx_vec: Vec>, + shmem_info_vec: Vec, + stats: Arc, + exit: Arc, + tile_drop_sig: TileClosedSignal, +) -> std::io::Result> { + std::thread::Builder::new() + .name(format!("ssPxyTx_{packet_fwd_idx}")) + .spawn(move || { + let mut deduper = Deduper::<2, [u8]>::new(&mut rand::thread_rng(), DEDUPER_NUM_BITS); + const UIO_MAXIOV: usize = libc::UIO_MAXIOV as usize; + // We allocate double size to account for possible overflow if destinations array is really big + let mut next_batch_send: Vec<(FrameBuf, SocketAddr)> = Vec::with_capacity(UIO_MAXIOV); + let mut queued: VecDeque = VecDeque::with_capacity(UIO_MAXIOV); + + let mut last_batch_to_send = Instant::now(); + + assert_eq!(fill_tx_vec.len(), shmem_info_vec.len()); + + let mut next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + let mut recycled_frames: Vec = Vec::new(); + + while !exit.load(Ordering::Relaxed) { + if next_deduper_reset_attempt.elapsed() > Duration::ZERO { + deduper.maybe_reset( + &mut rand::thread_rng(), + DEDUPER_FALSE_POSITIVE_RATE, + DEDUPER_RESET_CYCLE, + ); + next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + log::debug!( + "send_batch_count: {}, duplicate: {}, total-pkt-sent: {}, queue-len: {}, to-recycle: {}", + stats.send_batch_count.load(Ordering::Relaxed), + stats.duplicate.load(Ordering::Relaxed), + stats.send_batch_size_sum.load(Ordering::Relaxed), + queued.len(), + recycled_frames.len(), + ); + } + + // Drain packet_rx as fast as possible until queued is full or no packet is available. + while queued.len() < UIO_MAXIOV { + let Some(packet) = packet_rx.try_recv() else { + break; + }; + + let data_size = packet.meta.size; + let data_slice = &packet.buffer.chunk()[..data_size]; + let t = Instant::now(); + if deduper.dedup(data_slice) { + let desc = packet.buffer.into_inner(); + recycled_frames.push(desc); + stats.duplicate.fetch_add(1, Ordering::Relaxed); + inc_packets_deduped(1); + } else { + queued.push_back(packet); + } + let dedup_duration = t.elapsed(); + observe_dedup_time(dedup_duration.as_micros() as f64); + } + + let dests = hot_dest_vec.load(); + let dests_len = dests.len(); + assert!( + dests_len <= UIO_MAXIOV, + "number of destinations ({}) cannot be greater than UIO_MAXIOV ({})", + dests_len, + UIO_MAXIOV + ); + + // Send as much as possible from queued. + while !queued.is_empty() { + next_batch_send.clear(); + + while next_batch_send.len() < UIO_MAXIOV && !queued.is_empty() { + let remaining = UIO_MAXIOV - next_batch_send.len(); + if dests_len > remaining { + break; + } + + let Some(packet) = queued.pop_front() else { + break; + }; + let buf = packet.buffer; + let desc = unsafe { buf.detach_desc() }; + recycled_frames.push(desc); + + for dest in dests.iter() { + let buf_clone = + unsafe { buf.unsafe_subslice_clone(0, packet.meta.size) }; + next_batch_send.push((buf_clone, *dest)); + } + } + + if next_batch_send.is_empty() { + break; + } + + let batch_send_ts = Instant::now(); + let e = last_batch_to_send.elapsed(); + last_batch_to_send = Instant::now(); + + observe_recv_interval(e.as_micros() as f64); + match batch_send(&send_socket, &next_batch_send) { + Ok(_) => { + let send_duration = batch_send_ts.elapsed(); + stats + .batch_send_time_spent + .fetch_add(send_duration.as_micros() as u64, Ordering::Relaxed); + stats.send_batch_count.fetch_add(1, Ordering::Relaxed); + stats + .send_batch_size_sum + .fetch_add(next_batch_send.len() as u64, Ordering::Relaxed); + observe_send_duration(send_duration.as_micros() as f64); + observe_send_packet_count(next_batch_send.len() as f64); + } + Err(SendPktsError::IoError(err, num_failed)) => { + error!( + "Failed to send batch of size {}. {num_failed} packets failed. Error: {err}", + next_batch_send.len() + ); + inc_packets_forward_failed(num_failed as u64); + } + } + } + + // Recycle all used frames. + while let Some(desc) = recycled_frames.pop() { + fill_tx_vec[desc.shmem_idx] + .send(desc) + .expect("frame recycling"); + } + + if queued.is_empty() && next_batch_send.is_empty() && recycled_frames.is_empty() { + std::thread::yield_now(); + } + } + log::info!("Exiting pkt_fwd_tile {}", packet_fwd_idx); + drop(tile_drop_sig); + }) +} + +#[allow(clippy::too_many_arguments)] +pub fn run_proxy_system( + pkt_recv_tile_mem_config: PktRecvTileMemConfig, + dest_addr_vec: Arc>>, + multticast_config: Option, + src_ip: IpAddr, + src_port: u16, + num_pkt_recv_tiles: usize, + num_pkt_fwd_tiles: usize, + pkt_router: R, + exit: Arc, + pk_recv_stats: Arc, + pk_fwd_stats: Arc, + doublezero_sk_vec: Vec, +) where + R: PacketRoutingStrategy + Send + Sync + 'static, +{ + assert!(num_pkt_recv_tiles > 0, "num_pkt_recv_tiles must be > 0"); + assert!(num_pkt_fwd_tiles > 0, "num_pkt_fwd_tiles must be > 0"); + let mut tile_thread_vec: Vec> = Vec::new(); + // Build pkt_recv sockets + let pkt_recv_multicast_sk_vec = if let Some(multicast_config) = multticast_config { + log::info!("Using Triton multicast configuration for pkt_recv tiles"); + let vec = crate::multicast_config::create_multicast_sockets_triton( + &multicast_config, + ).expect("multicast-config"); + Some(vec![vec]) + } else { + None + }; + + assert!(doublezero_sk_vec.len() <= num_pkt_recv_tiles, "doublezero_v4_sk_vec.len() ({}) > num_pkt_recv_tiles ({})", doublezero_sk_vec.len(), num_pkt_recv_tiles); + + let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( + src_ip, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_pkt_recv_tiles, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") + }); + assert!(pkt_recv_sk_vec.len() == num_pkt_recv_tiles, "pkt_recv_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", pkt_recv_sk_vec.len(), num_pkt_recv_tiles); + + if let Some(multicast_sk_vec) = &pkt_recv_multicast_sk_vec { + assert!(multicast_sk_vec.len() == num_pkt_recv_tiles, "multicast_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", multicast_sk_vec.len(), num_pkt_recv_tiles); + } + + // Make sure socket are set to nonblocking + for sk in &pkt_recv_sk_vec { + sk.set_nonblocking(true).expect("pkt_recv_sk nonblocking"); + } + + + let mut pkt_recv_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + for sk in &pkt_recv_sk_vec { + pkt_recv_sk_raw_fd_vec.push(sk.as_raw_fd()); + } + + let num_frames = + pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size; + let frame_size = pkt_recv_tile_mem_config.frame_size; + + let tile_wait_group = TileWaitGroup::new(); + let mut shmem_info_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + let mut fill_tx_vec: Vec> = Vec::with_capacity(num_pkt_recv_tiles); + let mut fill_rx_vec: Vec> = Vec::with_capacity(num_pkt_recv_tiles); + let mut shmem_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + + let mut pkt_fwd_sk_vec: Vec = Vec::with_capacity(num_pkt_fwd_tiles); + let mut pkt_fwd_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_fwd_tiles); + + let mut packet_rx_vec: Vec> = Vec::with_capacity(num_pkt_fwd_tiles); + let mut packet_tx_vec: Vec> = Vec::with_capacity(num_pkt_fwd_tiles); + let mut recv_wake_slots: Vec>>>> = + Vec::with_capacity(num_pkt_recv_tiles); + + // Create the shared memory regions for recv tiles + for shmem_idx in 0..num_pkt_recv_tiles { + assert!( + num_frames.is_power_of_two(), + "num_frames must be a power of 2" + ); + assert!( + frame_size.is_power_of_two(), + "frame_size must be a power of 2" + ); + let shmem = SharedMem::new(frame_size, num_frames, pkt_recv_tile_mem_config.hugepage) + .expect("SharedMem::new"); + log::info!( + "Created shared memory region with frame_size={} num_frames={} total_size={} hugepage={}", + frame_size, + num_frames, + shmem.len(), + pkt_recv_tile_mem_config.hugepage, + ); + + let shmem_info = SharedMemInfo { + start_ptr: shmem.ptr, + len: shmem.len(), + }; + shmem_info_vec.push(shmem_info); + + let (fill_tx, fill_rx) = crate::mem::message_ring(num_frames).expect("frame ring"); + // Fill the fill ring with all frames + for i in 0..num_frames { + let frame_desc = FrameDesc { + ptr: unsafe { shmem.ptr.add(i * frame_size) }, + frame_size: frame_size, + shmem_idx, + }; + fill_tx + .send(frame_desc) + .expect("initial frame ring population"); + } + shmem_vec.push(shmem); + fill_tx_vec.push(fill_tx); + fill_rx_vec.push(fill_rx); + log::info!("Initialized frame ring with {} frames", num_frames); + } + + // Create socket for sending packets + for _ in 0..num_pkt_fwd_tiles { + let send_socket = { + let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); + match try_create_ipv6_socket(ipv6_addr) { + Ok(socket) => { + info!("Successfully bound send socket to IPv6 dual-stack address."); + socket + .set_multicast_loop_v6(false) + .expect("Failed to disable IPv6 multicast loopback"); + socket + } + Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { + // This error (code 97 on Linux) means IPv6 is not supported. + warn!("IPv6 not available. Falling back to IPv4-only for sending."); + let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let socket = UdpSocket::bind(ipv4_addr) + .expect("Failed to bind to IPv4 socket after IPv6 failed"); + socket + .set_multicast_ttl_v4(IP_MULTICAST_TTL) + .expect("IP_MULTICAST_TTL_V4"); + socket + .set_multicast_loop_v4(false) + .expect("Failed to disable IPv4 multicast loopback"); + socket + } + Err(e) => { + // For any other error (e.g., port in use), panic. + panic!("Failed to bind send socket with an unexpected error: {e}"); + } + } + }; + log::info!( + "Packet forwarder sending socket bound to {}", + send_socket.local_addr().unwrap() + ); + pkt_fwd_sk_raw_fd_vec.push(send_socket.as_raw_fd()); + pkt_fwd_sk_vec.push(send_socket); + } + + + let pkt_fwd_tile_ring_capacity = num_frames * num_pkt_recv_tiles; + log::info!( + "Setting pkt_fwd tile's message ring capacity to {} (num_frames {} * num_pkt_recv_tiles {})", + pkt_fwd_tile_ring_capacity, + num_frames, + num_pkt_recv_tiles + ); + + // Create pkt_fwd message rings + // One ring per pkt_fwd tile + for _ in 0..num_pkt_fwd_tiles { + // Worst case scenario all frames from all pkt_recv tiles are sent to this pkt_fwd tile + // We set the ring capacity to that + let (packet_tx, packet_rx) = crate::mem::message_ring(pkt_fwd_tile_ring_capacity).expect("pkt_fwd ring"); + packet_tx_vec.push(packet_tx); + packet_rx_vec.push(packet_rx); + } + + // Spawn pkt_fwd tiles + for (pkt_fwd_idx, pkt_fwd_sk, packet_rx) in izip!( + 0..num_pkt_fwd_tiles, + pkt_fwd_sk_vec.into_iter(), + packet_rx_vec.into_iter() + ) { + let hot_dest_vec = Arc::clone(&dest_addr_vec); + let fill_tx_vec = fill_tx_vec.clone(); + let shmem_info_vec = shmem_info_vec.clone(); + let exit = Arc::clone(&exit); + let th = packet_fwd_tile( + pkt_fwd_idx, + hot_dest_vec, + pkt_fwd_sk, + packet_rx, + fill_tx_vec, + shmem_info_vec, + Arc::clone(&pk_fwd_stats), + exit, + tile_wait_group.get_tile_closed_signal(TileKind::PktFwd, pkt_fwd_idx), + ) + .expect("packet_fwd_tile"); + tile_thread_vec.push(th); + log::info!("Spawned pkt_fwd tile {}", pkt_fwd_idx); + } + + // Spawn pkt_recv tiles + for (pkt_recv_idx, pkt_recv_sk, fill_rx) in izip!( + 0..num_pkt_recv_tiles, + pkt_recv_sk_vec.into_iter(), + fill_rx_vec.into_iter() + ) { + + let mut recv_pkt_vec = vec![ + pkt_recv_sk + ]; + + if let Some(multicast_sk_vec) = &pkt_recv_multicast_sk_vec { + recv_pkt_vec.push(multicast_sk_vec[pkt_recv_idx].try_clone().expect("multicast sk clone")); + } + + if let Some(doublezero_v4_sk) = doublezero_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_v4_sk.try_clone().expect("doublezero v4 sk clone")); + } + + let exit = Arc::clone(&exit); + let forwarder_stats = Arc::clone(&pk_recv_stats); + let packet_tx_vec_clone = packet_tx_vec.clone(); + let wake_slot: Arc>>> = Arc::new(Mutex::new(None)); + recv_wake_slots.push(Arc::clone(&wake_slot)); + let pkt_router_clone = pkt_router.clone(); + let jh = packet_recv_tile( + pkt_recv_idx, + recv_pkt_vec, + exit, + forwarder_stats, + fill_rx, + packet_tx_vec_clone, + wake_slot, + pkt_router_clone, + tile_wait_group.get_tile_closed_signal(TileKind::PktRecv, pkt_recv_idx), + ) + .expect("packet_recv_tile"); + tile_thread_vec.push(jh); + log::info!("Spawned pkt_recv tile {}", pkt_recv_idx); + } + + + let (kind, idx) = tile_wait_group.wait_first(); + warn!("Tile of kind {kind:?} with idx {idx} has exited. Shutting down proxy system"); + + exit.store(true, Ordering::Release); + for wake_slot in &recv_wake_slots { + if let Some(waker) = wake_slot + .lock() + .expect("recv wake slot lock poisoned") + .as_ref() + .cloned() + { + let _ = waker.wake(); + } + } + drop(fill_tx_vec); + drop(packet_tx_vec); + log::info!("Waiting for {} tile threads to exit", tile_thread_vec.len()); + + for th in tile_thread_vec { + let result = th.join(); + if let Err(e) = result { + error!("Tile thread join error: {:?}", e); + } + } +} + +/// Reset dedup + send metrics to influx +pub fn start_forwarder_accessory_thread( + metrics: Arc, + metrics_update_interval_ms: u64, + shutdown_receiver: Receiver<()>, + exit: Arc, +) -> JoinHandle<()> { + std::thread::Builder::new() + .name("ssPxyAccessory".to_string()) + .spawn(move || { + let metrics_tick = + crossbeam_channel::tick(Duration::from_millis(metrics_update_interval_ms)); + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! { + // send metrics to influx + recv(metrics_tick) -> _ => { + metrics.report(); + metrics.reset(); + } + + // handle SIGINT shutdown + recv(shutdown_receiver) -> _ => { + break; + } + } + } + }) + .unwrap() +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BufMut; + use solana_sdk::packet::PACKET_DATA_SIZE; + use std::net::UdpSocket; + + #[test] + fn test_pkt_recv_mem_sizing_from_str_aliases() { + assert!(matches!( + PktRecvMemSizing::from_str("xs"), + Ok(PktRecvMemSizing::XSmall) + )); + assert!(matches!( + PktRecvMemSizing::from_str("medium"), + Ok(PktRecvMemSizing::Medium) + )); + assert!(matches!( + PktRecvMemSizing::from_str("2xl"), + Ok(PktRecvMemSizing::XXLarge) + )); + assert!(matches!( + PktRecvMemSizing::from_str("5XL"), + Ok(PktRecvMemSizing::XXXXXLarge) + )); + } + + #[test] + fn test_pkt_recv_mem_sizing_from_str_invalid() { + let invalid = "huge"; + let err = PktRecvMemSizing::from_str(invalid).unwrap_err(); + assert_eq!(err, invalid); + } + + #[test] + fn test_pkt_recv_tile_mem_config_default() { + let cfg = PktRecvTileMemConfig::default(); + assert_eq!(cfg.frame_size, 2048); + assert!(matches!(cfg.memory_size, PktRecvMemSizing::XSmall)); + assert!(!cfg.hugepage); + } + + #[test] + fn test_tile_wait_group_wait_first_reports_drop() { + let wait_group = TileWaitGroup::new(); + let sig = wait_group.get_tile_closed_signal(TileKind::PktFwd, 7); + drop(sig); + let (kind, idx) = wait_group.wait_first(); + assert_eq!(kind, TileKind::PktFwd); + assert_eq!(idx, 7); + } + + #[test] + fn test_packet_fwd_tile_sends_and_recycles_frame() { + let frame_size = 2048usize; + let listener = UdpSocket::bind("127.0.0.1:0").expect("listener bind"); + listener + .set_read_timeout(Some(Duration::from_millis(1000))) + .expect("listener set_read_timeout"); + let listener_addr = listener.local_addr().expect("listener local_addr"); + + let send_socket = UdpSocket::bind("0.0.0.0:0").expect("send bind"); + let hot_dest_vec = Arc::new(ArcSwap::from_pointee(vec![listener_addr])); + + let shmem = SharedMem::new(frame_size, 1, false).expect("shmem"); + let frame_desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + shmem_idx: 0, + }; + + let mut frame_bufmut = frame_desc.as_mut_buf(); + let payload = b"hello-forwarder"; + frame_bufmut.put_slice(payload); + let frame_buf: FrameBuf = frame_bufmut.into(); + + let mut packet = TritonPacket::new(frame_buf); + packet.meta_mut().size = payload.len(); + // Use a non-local origin so the destination filter doesn't skip forwarding. + packet + .meta_mut() + .set_socket_addr(&SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 12345)); + + let (fill_tx, mut fill_rx) = crate::mem::message_ring::(8).expect("fill ring"); + let (packet_tx, packet_rx) = + crate::mem::message_ring::(8).expect("packet ring"); + + let shmem_info_vec = vec![SharedMemInfo { + start_ptr: shmem.ptr, + len: shmem.len(), + }]; + let fill_tx_vec = vec![fill_tx]; + + let wait_group = TileWaitGroup::new(); + let exit = Arc::new(AtomicBool::new(false)); + let stats = Arc::new(ShredMetrics::default()); + let jh = packet_fwd_tile( + 0, + hot_dest_vec, + send_socket, + packet_rx, + fill_tx_vec, + shmem_info_vec, + stats, + Arc::clone(&exit), + wait_group.get_tile_closed_signal(TileKind::PktFwd, 0), + ) + .expect("spawn packet_fwd_tile"); + + packet_tx.send(packet).expect("send packet to fwd tile"); + + let mut recv_buf = [0u8; PACKET_DATA_SIZE]; + let (n, _) = listener.recv_from(&mut recv_buf).expect("recv forwarded packet"); + assert_eq!(&recv_buf[..n], payload); + + let recycled = fill_rx + .recv_timeout(Duration::from_millis(1000)) + .expect("recycled frame"); + assert_eq!(recycled.ptr, shmem.ptr); + assert_eq!(recycled.frame_size, frame_size); + + exit.store(true, Ordering::Release); + drop(packet_tx); + let _ = wait_group.wait_first(); + jh.join().expect("join packet_fwd_tile"); + } +} diff --git a/run-triton-proxy.sh b/run-triton-proxy.sh new file mode 100644 index 00000000..9187c0c4 --- /dev/null +++ b/run-triton-proxy.sh @@ -0,0 +1,13 @@ + +# send random traffic +sudo nping --udp -p 8002 -c 0 --rate 1 --data-length 1200 127.0.0.1 + +while true; do + # Generate 1200 random bytes from /dev/urandom and send them + head -c 1200 /dev/urandom | socat - UDP:127.0.0.1:8002 + sleep 0.01 # Add a small delay between packets +done + + +# run triton-proxy +cargo run --bin triton-proxy -- forward-only --src-bind-addr 127.0.0.1 --src-bind-port 8002 --prometheus-bind-addr 127.0.0.1:9999 --dest-ip-ports 127.0.0.1:8989 \ No newline at end of file diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh new file mode 100755 index 00000000..f92d5f7d --- /dev/null +++ b/scripts/build-dist.sh @@ -0,0 +1,10 @@ +#!/bin/bash + + +mkdir -p dist +rm -rf dist/* + +cargo build --release -p jito-shredstream-proxy + +cp target/release/triton-shredproxy dist/triton-shredproxy-ubuntu-22.04 +cp target/release/jito-shredstream-proxy dist/jito-shredstream-proxy-ubuntu-22.04 \ No newline at end of file diff --git a/setup_net.sh b/setup_net.sh new file mode 100755 index 00000000..da732aec --- /dev/null +++ b/setup_net.sh @@ -0,0 +1,26 @@ +set -e +# 1. Create the namespace +sudo ip netns add ns1 + +# 2. Create the virtual cable (veth0 <-> veth1) +sudo ip link add veth0 type veth peer name veth1 + +# 3. Move veth1 end into the namespace +sudo ip link set veth1 netns ns1 + +# 4. Assign IP to Host side (veth0) - The Driver +sudo ip addr add 172.31.0.1/24 dev veth0 +sudo ip link set veth0 up + +# 5. Assign IP to Namespace side (veth1) - The Server +sudo ip netns exec ns1 ip addr add 172.31.0.2/24 dev veth1 +sudo ip netns exec ns1 ip link set veth1 up + +# 6. Turn off offloading (Crucial for AF_XDP) +#sudo ethtool -K veth0 gro off +#sudo ip netns exec ns1 ethtool -K veth1 gro off + +#nping --udp -p 1234 --data-length 500 -c 10 + +# 7. Verify connectivity +ping -c 2 172.31.0.2 \ No newline at end of file