Closing the Sim-to-Real Gap in Network Spreading Processes via Stratified Observation and Distributional Control
Reference implementation for the ICML 2026 paper.
- Paper: ICML 2026 (accepted)
- Simulation engine (external dependency): FlashSpread (arXiv:2604.22092)
Controlling spreading processes on networks (epidemics, information cascades, product adoption) requires policies that perform well on realistic stochastic dynamics, not just on tractable approximations. Yet policies trained on standard simplifications (mean-field ODEs, Markovian dynamics) suffer severe performance degradation at deployment. We trace this sim-to-real gap to three theoretical pathologies:
- Optimism Bias: deterministic approximations systematically underestimate variance, which, via Jensen's inequality, biases their predictions toward optimistic outcomes.
- Hub Blindness: global state aggregation obscures the super-spreaders that drive spreading on scale-free networks.
- Valley of Death: mean-value critics cannot navigate the bimodal nature (extinction vs. viral) of cascade outcomes.
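Optimism Bias and the Valley of Death can both be reproduced in a toy model. The sketch below is not the paper's CDM: it assumes a Galton-Watson branching process with Poisson(R0) offspring as a stand-in for a cascade, and a hypothetical saturating payoff `min(size, 50)`. Because that payoff is concave, Jensen's inequality guarantees that evaluating it at the mean-field size overstates the average payoff, while the stochastic outcomes split into a large extinction mass at zero and a spread-out surviving branch, so the mean describes neither.

```python
import numpy as np


def cascade_outcomes(r0=1.5, generations=10, n_runs=20_000, seed=0):
    """Galton-Watson cascades: each active node spawns Poisson(r0) new active nodes.

    Returns the number of active nodes after `generations` steps for every run.
    """
    rng = np.random.default_rng(seed)
    active = np.ones(n_runs, dtype=np.int64)
    for _ in range(generations):
        # A sum of k i.i.d. Poisson(r0) draws is Poisson(r0 * k); k = 0 stays extinct.
        active = rng.poisson(r0 * active)
    return active


def payoff(x, cap=50.0):
    """Hypothetical concave payoff: reward grows with cascade size until `cap`."""
    return np.minimum(x, cap)


r0, generations = 1.5, 10
outcomes = cascade_outcomes(r0, generations)
mean_field = r0 ** generations                   # deterministic (mean-field) size
extinct = np.mean(outcomes == 0)                 # weight of the extinction mode
det_payoff = float(payoff(mean_field))           # payoff evaluated at the mean
stoch_payoff = float(payoff(outcomes).mean())    # mean of the realized payoffs

print(f"mean-field size      {mean_field:8.1f}")
print(f"stochastic mean size {outcomes.mean():8.1f}")
print(f"extinct runs         {extinct:8.1%}")
print(f"payoff at the mean   {det_payoff:8.1f}")
print(f"mean payoff          {stoch_payoff:8.1f}")
```

With R0 = 1.5, roughly 40% of runs die out even though the deterministic trajectory never approaches zero, and the average size sits between the two modes.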
We resolve these through two synergistic contributions: the Stratified Mean-Field Observer (fixed-dimensional,
All policies (TQC + Stratified Observer) evaluated on the non-Markovian renewal environment (
| Training Environment | Profit | Std | Gap |
|---|---|---|---|
| MF-ODE | | ± 612 | |
| Stochastic Markovian | | ± 4 | |
| Stochastic Renewal (Ours) | | ± 498 | n/a |
Training on simplified dynamics fails catastrophically at deployment. Mean-field training actively destroys value (negative profit); Markovian training captures only 1.7% of optimal performance.
Data: logs/sim_to_real_transfer.csv · Figure: docs/report/figures/figure_sim_to_real.pdf
Policy trained on synthetic Barabási–Albert graphs (
| Network | Nodes | Profit | Hub Retention |
|---|---|---|---|
| Facebook | 4,039 | | 1.00 |
| Twitter | 81,306 | | 1.00 |
| YouTube | 1,134,890 | $1.76$M | 1.00 |
Our policy is the only approach producing positive returns across all three networks; baseline MF-ODE / Markov / Random policies all yield catastrophic negative profit on real-world topologies.
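Zero-shot transfer from 4,039 to 1,134,890 nodes requires an observation whose size does not depend on the number of nodes. As an illustration only (the StratifiedObserver in marketing_env.py may bin and normalize differently), the sketch below groups nodes into log2-degree tiers, so hubs land in their own high tiers instead of being averaged away, and reports the fraction of each tier in each compartment. The tier count of 8 echoes the --num-tiers 8 training flag; the four compartments, the log2 binning, and the function name are assumptions.

```python
import numpy as np


def stratified_observation(degrees, states, num_tiers=8, num_states=4):
    """Per-tier compartment fractions: a fixed-size vector for any graph size.

    degrees: (N,) node degrees. states: (N,) integer compartment labels in
    [0, num_states). Tier = floor(log2(degree)), with every node above the
    last boundary pooled into the top (hub) tier. Hypothetical binning rule.
    """
    degrees = np.asarray(degrees)
    states = np.asarray(states)
    tiers = np.floor(np.log2(np.maximum(degrees, 1))).astype(int)
    tiers = np.minimum(tiers, num_tiers - 1)
    counts = np.zeros((num_tiers, num_states))
    np.add.at(counts, (tiers, states), 1.0)      # count nodes per (tier, state)
    totals = counts.sum(axis=1, keepdims=True)
    # Empty tiers become all-zero rows instead of dividing by zero.
    fractions = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    return fractions.ravel()


rng = np.random.default_rng(0)
for n in (4_039, 81_306):
    deg = rng.zipf(2.5, size=n)            # heavy-tailed synthetic degrees
    st = rng.integers(0, 4, size=n)        # random compartment labels
    print(n, stratified_observation(deg, st).shape)   # same shape for both sizes
```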
Figure: docs/report/figures/figure_snap_transfer.pdf
This repository depends on FlashSpread, a dual-engine GPU simulator for network spreading processes. FlashSpread is the algorithmic enabler for this work, sustaining
| Component | Source | Usage in this repo |
|---|---|---|
| GraphCSR | flashspread.core | Compressed sparse row graph representation |
| FlashNeighbor | flashspread.core | Triton kernel for neighbor influence computation |
| RenewalEngine | flashspread.engines.renewal | Non-Markovian simulation with age tracking |
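The NumPy sketch below illustrates only the CSR layout and the per-node neighbor-influence sum that a kernel such as FlashNeighbor computes; it does not reproduce FlashSpread's kernel or API. It assumes an SI-style influence of beta times the number of infected neighbors, and build_csr / neighbor_influence are hypothetical helpers, not GraphCSR methods.

```python
import numpy as np


def build_csr(edges, num_nodes):
    """Undirected CSR adjacency: neighbors of node i are indices[indptr[i]:indptr[i + 1]]."""
    edges = np.asarray(edges, dtype=np.int64)
    src = np.concatenate([edges[:, 0], edges[:, 1]])  # store each edge in both directions
    dst = np.concatenate([edges[:, 1], edges[:, 0]])
    order = np.argsort(src, kind="stable")            # group entries by source node
    indices = dst[order]
    counts = np.bincount(src, minlength=num_nodes)    # degree of each node
    indptr = np.zeros(num_nodes + 1, dtype=np.int64)
    np.cumsum(counts, out=indptr[1:])                 # row offsets
    return indptr, indices


def neighbor_influence(indptr, indices, infected, beta):
    """Per-node influence beta * (number of infected neighbors), via a segment sum."""
    num_nodes = len(indptr) - 1
    contrib = np.asarray(infected, dtype=np.float64)[indices]
    row_ids = np.repeat(np.arange(num_nodes), np.diff(indptr))  # owner of each CSR slot
    return beta * np.bincount(row_ids, weights=contrib, minlength=num_nodes)


# Path graph 0 - 1 - 2 with only node 0 infected: only node 1 feels any influence.
indptr, indices = build_csr([[0, 1], [1, 2]], num_nodes=3)
print(neighbor_influence(indptr, indices, infected=[1, 0, 0], beta=0.5))
```

A GPU kernel would typically fuse the gather and the segment sum so each node's neighbor slice is read once; the two-step NumPy version trades that efficiency for readability.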
This repo's GEMF/ directory holds the CDM-specific solvers built on top of FlashSpread primitives (sparse, CUDA-graph, renewal, and ODE flavors); the RL stack (train.py, td3_agent.py, tqc_agent.py, actors.py, marketing_env.py) lives at the repository root.
System requirements:
- CUDA-capable NVIDIA GPU (Volta/SM70+; A100 recommended for $N \ge 10^5$).
- PyTorch with a matching CUDA build; install it per the official selector.
- Triton (pulled in transitively for the CUDA-graph kernel).
```bash
# 1. Install PyTorch with CUDA from the official selector
pip install torch
# 2. Install this repo's deps (pulls FlashSpread from GitHub)
pip install -r requirements.txt
```

For development on FlashSpread itself, clone it side-by-side and use editable installs:

```bash
git clone https://github.com/Shakeri-Lab/FlashSpread.git ../FlashSpread
pip install -e ../FlashSpread
```

A CPU-only fallback works for --solver sparse or --solver ode, but throughput is significantly lower.
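Before launching a run, it can help to confirm which of these pieces are importable. The helper below is a hypothetical convenience, not part of the repo: it only uses importlib.util.find_spec and, when PyTorch is present, torch.cuda.is_available(), and it narrows the suggested solvers to sparse and ode when no GPU is visible, following the CPU-fallback note.

```python
import importlib.util


def check_environment():
    """Report which optional packages import and which --solver values suit this machine."""
    report = {
        name: importlib.util.find_spec(name) is not None
        for name in ("torch", "triton", "flashspread")
    }
    report["cuda"] = False
    if report["torch"]:
        import torch  # imported lazily so the check still runs without PyTorch

        report["cuda"] = torch.cuda.is_available()
    # Per the README, CPU-only runs are limited to the sparse and ODE solvers.
    if report["cuda"]:
        report["solvers"] = ["sparse", "cudagraph", "ode", "renewal"]
    else:
        report["solvers"] = ["sparse", "ode"]
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key:12s} {value}")
```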
Train on non-Markovian (renewal) dynamics with the age-augmented stratified observer:
```bash
python train.py --solver renewal --observer-age \
    --total-steps 40000 --num-nodes 10000 --device cuda \
    --renewal-mu-eo 1.5 --renewal-sigma-eo 0.6 \
    --renewal-mu-oe 2.5 --renewal-sigma-oe 0.8
```

Train TQC + tier-aware actor (the configuration that produces the headline result):

```bash
python train.py --agent tqc --actor tier --solver renewal --observer-age \
    --num-tiers 8 --total-steps 40000 --device cuda
```

Evaluate a checkpoint on renewal (deployment-physics) dynamics:

```bash
python experiments/eval_checkpoint.py \
    --checkpoint checkpoints/<run>/checkpoint_best.pt \
    --eval-solver renewal --observer-age \
    --episodes 50 --seeds 5
```

Repository layout:

```text
graph-rl/
  GEMF/                          # CDM-specific solvers (built on FlashSpread)
    gemf_cdm.py                  # sparse stochastic CDM solver
    gemf_cdm_cudagraph.py        # CUDA-graph CDM solver
    gemf_cdm_renewal.py          # non-Markovian CDM model
    gemf_cdm_renewal_solver.py   # renewal solver (wraps FlashSpread RenewalEngine)
    gemf_ode.py                  # mean-field ODE solver
  actors.py                      # MLP + TierAware actors (32/64-dim observations)
  marketing_env.py               # Gymnasium env + stratified/age observers
  replay_buffer.py               # buffer with profit/churn for hindsight relabeling
  td3_agent.py                   # TD3 implementation
  tqc_agent.py                   # TQC (distributional RL) implementation
  train.py                       # training entrypoint
  utils/                         # plotting + diagnostics
  experiments/                   # evaluation, sweeps, transfer, stress tests
    slurm/                       # reproducible Slurm job scripts
  requirements.txt
```
- gemf_cdm.py: exact stochastic CDM solver (sparse influence + active-set updates).
- gemf_cdm_cudagraph.py: CUDA-graph CDM solver; uses FlashSpread's Triton FlashNeighbor kernel.
- gemf_cdm_renewal.py + gemf_cdm_renewal_solver.py: non-Markovian CDM with log-normal dwell times; wraps FlashSpread's RenewalEngine.
- gemf_ode.py: deterministic node-level mean-field ODE (Euler integration on GPU).
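The CDM equations are not spelled out in this README, so the sketch below substitutes a standard node-level SIS mean-field model (NIMFA-style) to show what a node-level mean-field ODE with explicit Euler integration looks like. The rates beta and delta, the dense adjacency matrix, the clipping step, and the function name are illustrative assumptions; gemf_ode.py integrates its own model on the GPU.

```python
import numpy as np


def mean_field_euler(adj, x0, beta, delta, dt=0.01, steps=1_000):
    """Node-level SIS mean-field ODE, integrated with explicit Euler steps.

    dx_i/dt = beta * (1 - x_i) * sum_j A_ij x_j - delta * x_i,
    where x_i is the probability that node i is active.
    """
    x = np.asarray(x0, dtype=np.float64).copy()
    for _ in range(steps):
        pressure = adj @ x                        # expected number of active neighbors
        dx = beta * (1.0 - x) * pressure - delta * x
        x = np.clip(x + dt * dx, 0.0, 1.0)       # keep probabilities in [0, 1]
    return x


adj = np.ones((3, 3)) - np.eye(3)   # triangle graph: every node has 2 neighbors
x = mean_field_euler(adj, [0.1, 0.1, 0.1], beta=1.0, delta=1.0, dt=0.01, steps=5_000)
print(x)  # each entry approaches the endemic level 1 - delta / (2 * beta) = 0.5
```

Because every x_i is a deterministic probability, this solver has no extinction mode at all, which is the source of the optimism the stochastic solvers avoid.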
- train.py: main training entrypoint. Key flags:
  - --solver {sparse, cudagraph, ode, renewal} selects the physics engine.
  - --agent {td3, tqc} selects the algorithm.
  - --actor {mlp, tier} selects the policy head (1D convolution over tiers for tier).
  - --observer-age enables age-augmented observations (64-dim) for renewal dynamics.
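Age matters under renewal dynamics because a log-normal dwell time has an age-dependent hazard, whereas the exponential dwell time of a Markovian model has a constant one. The sketch below computes that hazard from the standard log-normal density and survival function. Treating --renewal-mu-eo 1.5 and --renewal-sigma-eo 0.6 as the log-mean and log-standard-deviation of one transition's dwell time is an assumption.

```python
import math


def lognormal_hazard(age, mu, sigma):
    """Hazard h(a) = f(a) / S(a) of a log-normal dwell time (log-mean mu, log-std sigma)."""
    z = (math.log(age) - mu) / sigma
    pdf = math.exp(-0.5 * z * z) / (age * sigma * math.sqrt(2.0 * math.pi))
    survival = 0.5 * math.erfc(z / math.sqrt(2.0))  # P(dwell time > age)
    return pdf / survival


# Assumption: the training flags --renewal-mu-eo / --renewal-sigma-eo map to (mu, sigma).
mu, sigma = 1.5, 0.6
for age in (0.5, 2.0, 5.0, 20.0):
    print(f"age {age:5.1f}  hazard {lognormal_hazard(age, mu, sigma):.3f}")
print(f"mean dwell time {math.exp(mu + sigma ** 2 / 2):.2f}")
```

The hazard rises and then falls with age, so two nodes in the same compartment can have very different transition rates; that is the kind of information an age-augmented observation can expose to the policy.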
- marketing_env.py: Gymnasium environment with StratifiedObserver (32-dim) and StratifiedAgeObserver (64-dim).
- actors.py: MLPActor and TierAwareActor (Conv1D over tiers).
- tqc_agent.py: Truncated Quantile Critics (25 quantiles, 2 dropped, 5 critics).
- replay_buffer.py: stores raw true_profit/churn_count for hindsight reward relabeling during sampling.
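A quantile critic can represent separate extinction and viral modes where a single mean cannot. The sketch below shows the truncation step of the TQC method with the settings listed above, assuming "2 dropped" means 2 atoms per critic: atoms from all 5 critics are pooled and sorted, the largest 2 × 5 are discarded to curb overestimation, and the rest form the Bellman target. The function names are hypothetical, and the entropy term of the original SAC-based TQC is omitted.

```python
import numpy as np


def quantile_midpoints(n_quantiles):
    """Quantile fractions tau_i = (2i - 1) / (2M) that each critic head regresses to."""
    return (2 * np.arange(1, n_quantiles + 1) - 1) / (2 * n_quantiles)


def tqc_target_atoms(next_quantiles, rewards, dones, gamma=0.99, drop_per_critic=2):
    """Truncated mixture target in the style of TQC (entropy term omitted).

    next_quantiles: (batch, n_critics, n_quantiles) atoms of Z(s', a') per critic.
    The largest drop_per_critic * n_critics pooled atoms are discarded before the backup.
    """
    batch, n_critics, n_quantiles = next_quantiles.shape
    pooled = np.sort(next_quantiles.reshape(batch, n_critics * n_quantiles), axis=1)
    kept = n_critics * (n_quantiles - drop_per_critic)
    truncated = pooled[:, :kept]                      # drop the most optimistic atoms
    return rewards[:, None] + gamma * (1.0 - dones[:, None]) * truncated


rng = np.random.default_rng(0)
next_q = rng.normal(size=(4, 5, 25))  # batch of 4, 5 critics x 25 quantiles
targets = tqc_target_atoms(next_q, rewards=np.ones(4), dones=np.zeros(4))
print(targets.shape)  # (4, 115): 125 pooled atoms minus 2 * 5 dropped
```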
To avoid misleading comparisons, all agents are evaluated on the realistic non-Markovian (renewal) simulator, regardless of how they were trained:
| Training regime | Training solver | Evaluation solver |
|---|---|---|
| Mean-Field | ode | renewal |
| Stochastic Markovian | cudagraph | renewal |
| Stochastic Renewal | renewal | renewal |
This explicitly separates training dynamics from deployment dynamics and prevents the common pitfall where ODE-trained policies appear artificially strong when evaluated on the same deterministic solver.
Reference array job: experiments/slurm/run_eval_metastable_matrix.sbatch (5 checkpoints × renewal eval).
LaTeX tables (regenerate from CSVs in logs/ to docs/report/tables/):
```python
import pathlib

import pandas as pd

pairs = [
    ("logs/table1_control_efficacy_summary_v2.csv", "docs/report/tables/table1_control_efficacy_summary_v2.tex"),
    ("logs/table2_critic_comparison_summary_v2.csv", "docs/report/tables/table2_critic_comparison_summary_v2.tex"),
    ("logs/table4_ablation_summary_v2.csv", "docs/report/tables/table4_ablation_summary_v2.tex"),
    ("logs/table5_zero_shot_v2.csv", "docs/report/tables/table5_zero_shot_v2.tex"),
    ("logs/snap_transfer_results.csv", "docs/report/tables/snap_transfer_results.tex"),
]
for src, dst in pairs:
    # Create docs/report/tables/ if it does not exist yet, then write the LaTeX table.
    pathlib.Path(dst).parent.mkdir(parents=True, exist_ok=True)
    pd.read_csv(src).to_latex(dst, index=False, escape=True)
```

Key figure regeneration scripts:
- Sim-to-Real bar chart: `python utils/plot_sim_to_real.py --input-csv logs/sim_to_real_combined.csv --out docs/report/figures/figure_sim_to_real.pdf`
- Mean-Field Illusion (Fig. 1): `python utils/plot_mean_field_illusion.py`
- Hub-Sniper Strategy (Fig. 3): `python utils/plot_hub_sniper_strategy.py`
- TQC Quantiles (Fig. 2): `python utils/plot_tqc_quantiles.py`
- Stratified Cascade Heatmap: `python utils/plot_stratified_heatmap.py`
- Commitment Window: `python experiments/commitment_window_trace.py --out-csv logs/commitment_window_trace.csv --out-plot figures/figure_commitment_window.png`
All experiments use a shared visualization style (utils/style_config.py: LaTeX rendering, colorblind-friendly palette, ICML-style typography).
Sensitivity analysis (main paper / appendix):
```bash
python utils/plot_sensitivity_analysis.py --main-only        # 3-panel figure
python utils/plot_sensitivity_analysis.py --supplement-only  # full breakdown
```

Slurm scripts for full retraining and evaluation pipelines live in experiments/slurm/.
If you use this code or build on it, please cite both the ICML paper and the underlying FlashSpread engine:
```bibtex
@inproceedings{simtoreal2026,
  title     = {Closing the Sim-to-Real Gap in Network Spreading Processes
               via Stratified Observation and Distributional Control},
  author    = {Anonymous},
  booktitle = {Proceedings of the 43rd International Conference on Machine Learning (ICML)},
  year      = {2026}
}

@article{shakeri2026flashspread,
  title   = {FlashSpread: IO-Aware GPU Simulation of Non-Markovian Epidemic
             Dynamics via Kernel Fusion},
  author  = {Shakeri, Heman and Moradi-Jamei, Behnaz and Vajdi, Aram and Ardjmand, Ehsan},
  journal = {arXiv preprint arXiv:2604.22092},
  year    = {2026}
}
```

This project depends on FlashSpread; see that repository for license and engine-level documentation.