Shakeri-Lab/graph-rl

Closing the Sim-to-Real Gap in Network Spreading Processes via Stratified Observation and Distributional Control

Reference implementation for the ICML 2026 paper.

Abstract

Controlling spreading processes on networks (epidemics, information cascades, product adoption) requires policies that perform well on realistic stochastic dynamics, not just on tractable approximations. Yet policies trained on standard simplifications (mean-field ODEs, Markovian dynamics) degrade severely at deployment. We trace this sim-to-real gap to three theoretical pathologies:

  1. Optimism Bias: deterministic approximations systematically underestimate variance (via Jensen's inequality).
  2. Hub Blindness: global state aggregation obscures the super-spreaders that drive spreading on scale-free networks.
  3. Valley of Death: mean-value critics cannot navigate the bimodal (extinction vs. viral) distribution of cascade outcomes.
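A toy sketch of the first and third pathologies, using a made-up bimodal return distribution rather than any data from the paper: the mean return lands in a region almost no episode visits, a 25-quantile summary keeps both modes, and the Jensen gap E[X^2] - E[X]^2 is exactly the variance that a deterministic (zero-variance) model discards.

```python
import numpy as np

# Illustrative bimodal cascade returns (all numbers are made up):
# with probability 0.3 the cascade goes viral, otherwise it dies out.
rng = np.random.default_rng(0)
n = 100_000
viral = rng.random(n) < 0.3
returns = np.where(viral,
                   rng.normal(100.0, 10.0, n),   # viral mode
                   rng.normal(-10.0, 2.0, n))    # extinction mode

# Valley of Death: the mean sits between the modes, where almost no outcomes fall.
mean = returns.mean()
near_mean = np.mean(np.abs(returns - mean) < 5.0)

# A distributional critic tracks quantiles instead of a single mean.
taus = (np.arange(25) + 0.5) / 25
quantiles = np.quantile(returns, taus)

# Optimism Bias via Jensen: for convex f(x) = x**2, E[f(X)] - f(E[X]) equals Var(X),
# which a deterministic model drops entirely.
jensen_gap = np.mean(returns**2) - mean**2

print(f"mean return: {mean:.1f}")
print(f"share of outcomes within 5 of the mean: {near_mean:.5f}")
print(f"lowest / highest quantile: {quantiles[0]:.1f} / {quantiles[-1]:.1f}")
print(f"Jensen gap {jensen_gap:.1f} vs variance {returns.var():.1f}")
```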

We resolve these through two synergistic contributions: the Stratified Mean-Field Observer (fixed-dimensional, $\mathcal{O}(N)$, hub-preserving) and Distributional RL via Truncated Quantile Critics (TQC) for risk-aware control of bimodal cascades. Trained on FlashSpread, a GPU-accelerated simulator with non-Markovian renewal dynamics, our approach achieves a 59× improvement over Markovian baselines and zero-shot transfer to real-world social networks (Facebook, Twitter, YouTube) with 100% hub retention.
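One way a fixed-dimensional, $\mathcal{O}(N)$, hub-preserving observation can be built is sketched below. This is an illustration, not the repository's StratifiedObserver: it assumes log-spaced degree tiers and per-tier compartment fractions, and the 8 tiers × 4 states = 32 split is chosen only to match the 32-dim size quoted for marketing_env.py. The function name and compartment labels are hypothetical.

```python
import numpy as np

def stratified_observation(degrees, states, num_tiers=8, num_states=4):
    """Per-tier compartment fractions, flattened to num_tiers * num_states values.

    Hypothetical sketch: tiers are log-spaced degree bins, so the few highest-degree
    hubs land in a top tier of their own instead of being averaged into a global mean.
    Cost is linear in N (one digitize + one bincount); output size does not depend on N.
    """
    log_deg = np.log1p(np.asarray(degrees, dtype=float))
    states = np.asarray(states, dtype=np.int64)
    edges = np.linspace(log_deg.min(), log_deg.max(), num_tiers + 1)
    # digitize returns 1..num_tiers+1; shift and clip so tiers are 0..num_tiers-1
    tier = np.clip(np.digitize(log_deg, edges) - 1, 0, num_tiers - 1)
    flat = np.bincount(tier * num_states + states, minlength=num_tiers * num_states)
    counts = flat.reshape(num_tiers, num_states).astype(float)
    sizes = counts.sum(axis=1, keepdims=True)
    # Empty tiers report all-zero fractions instead of dividing by zero
    fractions = np.divide(counts, sizes, out=np.zeros_like(counts), where=sizes > 0)
    return fractions.ravel()

rng = np.random.default_rng(0)
N = 100_000
degrees = rng.zipf(2.5, size=N)        # heavy-tailed stand-in for a scale-free degree sequence
states = rng.integers(0, 4, size=N)    # hypothetical compartment labels 0..3
obs = stratified_observation(degrees, states)
print(obs.shape)  # (32,)
```

The observation length stays at 32 whether the graph has $10^3$ or $10^6$ nodes, which is what lets a policy trained at $N=10^5$ be applied to graphs of other sizes.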

Headline Results

Sim-to-Real Transfer Gap

All policies (TQC + Stratified Observer) evaluated on the non-Markovian renewal environment ($N=10^5$, 50 episodes):

| Training Environment | Profit | Std | Gap |
|---|---|---|---|
| MF-ODE | $-238{,}731$ | ± 612 | -651% |
| Stochastic Markovian | $730$ | ± 4 | -98.3% |
| Stochastic Renewal (Ours) | $43{,}340$ | ± 498 | n/a |

Training on simplified dynamics fails catastrophically at deployment. Mean-field training actively destroys value (negative profit), and Markovian training captures only 1.7% of the profit achieved by the renewal-trained policy.
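The Markovian figures can be checked directly from the table. The sketch assumes Gap is the relative difference from the renewal-trained profit; that definition reproduces the Markovian row but is not stated explicitly in the README.

```python
# Profits copied from the sim-to-real table (renewal evaluation, N = 10^5).
renewal_profit = 43_340    # trained on stochastic renewal dynamics (reference)
markov_profit = 730        # trained on stochastic Markovian dynamics

share = markov_profit / renewal_profit                     # fraction of reference profit kept
gap = (markov_profit - renewal_profit) / renewal_profit    # assumed definition of "Gap"
ratio = renewal_profit / markov_profit                     # the "59x improvement"

print(f"share = {share:.1%}, gap = {gap:.1%}, ratio = {ratio:.0f}x")
# share = 1.7%, gap = -98.3%, ratio = 59x
```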

Data: logs/sim_to_real_transfer.csv · Figure: docs/report/figures/figure_sim_to_real.pdf

Zero-Shot Transfer to SNAP Networks

Policy trained on synthetic Barabási–Albert graphs ($N=10^5$), deployed without retraining:

| Network | Nodes | Profit | Hub Retention |
|---|---|---|---|
| Facebook | 4,039 | $6{,}700$ | 1.00 |
| Twitter | 81,306 | $133{,}100$ | 1.00 |
| YouTube | 1,134,890 | $1.76$M | 1.00 |

Our policy is the only approach producing positive returns across all three networks; baseline MF-ODE / Markov / Random policies all yield catastrophic negative profit on real-world topologies.

Figure: docs/report/figures/figure_snap_transfer.pdf

FlashSpread (External Dependency)

This repository depends on FlashSpread, a dual-engine GPU simulator for network spreading processes. FlashSpread is the algorithmic enabler for this work: it sustains $>2 \times 10^7$ events/s for Markovian dynamics and supports non-Markovian renewal processes that would take weeks to simulate on CPU.

| Component | Source | Usage in this repo |
|---|---|---|
| GraphCSR | flashspread.core | Compressed sparse row graph representation |
| FlashNeighbor | flashspread.core | Triton kernel for neighbor influence computation |
| RenewalEngine | flashspread.engines.renewal | Non-Markovian simulation with age tracking |
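The table names GraphCSR but not its constructor, so the sketch below builds the same compressed-sparse-row layout (row pointers plus column indices) in plain NumPy from a SNAP-style edge list. `load_snap_edges`, `relabel` and `edge_list_to_csr` are hypothetical helpers, not FlashSpread functions; duplicate edges and self-loops are not removed.

```python
import numpy as np

def load_snap_edges(source):
    """Read a whitespace-separated SNAP edge list, skipping '#' comment lines."""
    return np.loadtxt(source, dtype=np.int64, comments="#", ndmin=2)

def relabel(raw_edges):
    """Map arbitrary node IDs (SNAP IDs are often sparse) to 0..N-1."""
    ids, inverse = np.unique(raw_edges, return_inverse=True)
    return inverse.reshape(raw_edges.shape), ids

def edge_list_to_csr(edges, num_nodes=None, undirected=True):
    """Return CSR arrays (indptr, indices); neighbors of v are indices[indptr[v]:indptr[v+1]]."""
    edges = np.asarray(edges, dtype=np.int64)
    if undirected:
        # Store both directions so each node sees all of its neighbors
        edges = np.concatenate([edges, edges[:, ::-1]], axis=0)
    if num_nodes is None:
        num_nodes = int(edges.max()) + 1
    order = np.lexsort((edges[:, 1], edges[:, 0]))  # sort by source, then target
    src, dst = edges[order, 0], edges[order, 1]
    indptr = np.zeros(num_nodes + 1, dtype=np.int64)
    np.cumsum(np.bincount(src, minlength=num_nodes), out=indptr[1:])
    return indptr, dst

edges = np.array([[0, 1], [1, 2], [2, 0], [2, 3]])
indptr, indices = edge_list_to_csr(edges)
print(indptr.tolist())   # [0, 2, 4, 7, 8]
print(indices.tolist())  # [1, 2, 0, 2, 0, 1, 3, 2]
```

For a real file, `relabel(load_snap_edges(path))` gives contiguous IDs to pass to `edge_list_to_csr`; pass `undirected=False` to keep edge direction.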

This repo's GEMF/ directory holds the CDM-specific solvers built on top of FlashSpread primitives (sparse, CUDA-graph, renewal, and ODE variants); the RL stack (train.py, td3_agent.py, tqc_agent.py, actors.py, marketing_env.py) lives at the repository root.

Install

System requirements:

  • CUDA-capable NVIDIA GPU (Volta/SM70+; A100 recommended for $N \ge 10^5$).
  • PyTorch with a matching CUDA build: install per the official selector.
  • Triton (pulled in transitively for the CUDAGraph kernel).

# 1. Install PyTorch with CUDA (use the exact command the official selector gives for your CUDA version)
pip install torch

# 2. Install this repo's deps (pulls FlashSpread from GitHub)
pip install -r requirements.txt

For development on FlashSpread itself, clone it side-by-side and use editable installs:

git clone https://github.com/Shakeri-Lab/FlashSpread.git ../FlashSpread
pip install -e ../FlashSpread

CPU-only fallback works for --solver sparse or --solver ode, but throughput is significantly lower.
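Before training, it can help to confirm that PyTorch sees a GPU meeting the Volta/SM70+ requirement. This sketch uses only standard `torch.cuda` calls; the helper names are illustrative and not part of the repo.

```python
def meets_gpu_requirement(capability, minimum=(7, 0)):
    """True if a (major, minor) CUDA compute capability is Volta (SM70) or newer."""
    return tuple(capability) >= tuple(minimum)


def describe_device():
    """Summarize whether the first visible GPU satisfies the SM70+ requirement."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        return "No CUDA device visible; only --solver sparse or --solver ode will run (on CPU)"
    major, minor = torch.cuda.get_device_capability(0)
    name = torch.cuda.get_device_name(0)
    status = "OK" if meets_gpu_requirement((major, minor)) else "below SM70"
    return f"{name} (sm_{major}{minor}): {status}; torch CUDA build {torch.version.cuda}"


if __name__ == "__main__":
    print(describe_device())
```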

Quickstart

Train on non-Markovian (renewal) dynamics with the age-augmented stratified observer:

python train.py --solver renewal --observer-age \
  --total-steps 40000 --num-nodes 10000 --device cuda \
  --renewal-mu-eo 1.5 --renewal-sigma-eo 0.6 \
  --renewal-mu-oe 2.5 --renewal-sigma-oe 0.8
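The renewal flags describe log-normal dwell times. Assuming `--renewal-mu-*` and `--renewal-sigma-*` are the mean and standard deviation of log(T) (NumPy's `lognormal` convention; the README does not say), the sketch below computes the implied mean dwell time and shows that the exit hazard depends on how long a node has already waited, unlike a Markovian exponential clock.

```python
import math
import numpy as np

# Assumption: log(T) ~ Normal(mu, sigma), as in numpy's Generator.lognormal(mean, sigma).
MU, SIGMA = 1.5, 0.6   # values from --renewal-mu-eo / --renewal-sigma-eo above


def lognormal_mean(mu, sigma):
    """Expected dwell time E[T] = exp(mu + sigma^2 / 2)."""
    return math.exp(mu + sigma**2 / 2)


def lognormal_hazard(t, mu, sigma):
    """Exit rate h(t) = f(t) / S(t) for a node that has already waited t."""
    z = (math.log(t) - mu) / sigma
    pdf = math.exp(-z**2 / 2) / (t * sigma * math.sqrt(2 * math.pi))
    survival = 0.5 * math.erfc(z / math.sqrt(2))
    return pdf / survival


rng = np.random.default_rng(0)
samples = rng.lognormal(mean=MU, sigma=SIGMA, size=200_000)
print(f"analytic mean {lognormal_mean(MU, SIGMA):.2f}, sample mean {samples.mean():.2f}")

# An exponential (Markovian) clock has a constant hazard; this one varies with age,
# which is the motivation for tracking node ages in renewal simulation.
for t in (1.0, 4.0, 16.0):
    print(f"h({t:g}) = {lognormal_hazard(t, MU, SIGMA):.3f}")
```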

Train TQC + tier-aware actor (the configuration that produces the headline result):

python train.py --agent tqc --actor tier --solver renewal --observer-age \
  --num-tiers 8 --total-steps 40000 --device cuda
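The real TierAwareActor lives in actors.py. As a rough NumPy illustration of what a 1D convolution over tiers means, this sketch slides one shared kernel along the tier axis of a 32-dim observation reshaped to 8 tiers × 4 features. The per-tier sigmoid output (one action per tier), the function names, and all parameter shapes are assumptions made for the example.

```python
import numpy as np

def conv1d_over_tiers(x, weight, bias):
    """'Same'-padded 1D convolution along the tier axis.

    x:      (num_tiers, in_features) per-tier features
    weight: (kernel, in_features, out_features), shared by every tier
    bias:   (out_features,)
    """
    k = weight.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (0, 0)))  # zero-pad beyond the lowest/highest tier
    out = np.stack([np.einsum("ki,kio->o", xp[t:t + k], weight)
                    for t in range(x.shape[0])])
    return out + bias

def tier_actor(obs, params, num_tiers=8):
    """Hypothetical tier-aware policy: conv, ReLU, conv, then a sigmoid action per tier."""
    x = obs.reshape(num_tiers, -1)
    h = np.maximum(conv1d_over_tiers(x, params["w1"], params["b1"]), 0.0)
    logits = conv1d_over_tiers(h, params["w2"], params["b2"])[:, 0]
    return 1.0 / (1.0 + np.exp(-logits))

rng = np.random.default_rng(0)
params = {
    "w1": 0.1 * rng.normal(size=(3, 4, 16)), "b1": np.zeros(16),
    "w2": 0.1 * rng.normal(size=(3, 16, 1)), "b2": np.zeros(1),
}
actions = tier_actor(rng.random(32), params)
print(actions.shape)  # (8,)
```

Because the kernel is shared across tiers, the parameter count does not grow with the number of tiers, and neighboring degree tiers are processed with the same weights.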

Evaluate a checkpoint on renewal (deployment-physics) dynamics:

python experiments/eval_checkpoint.py \
  --checkpoint checkpoints/<run>/checkpoint_best.pt \
  --eval-solver renewal --observer-age \
  --episodes 50 --seeds 5

Repository Layout

graph-rl/
  GEMF/                           # CDM-specific solvers (build on FlashSpread)
    gemf_cdm.py                   # sparse stochastic CDM solver
    gemf_cdm_cudagraph.py         # CUDA-graph CDM solver
    gemf_cdm_renewal.py           # non-Markovian CDM model
    gemf_cdm_renewal_solver.py    # renewal solver (wraps FlashSpread RenewalEngine)
    gemf_ode.py                   # mean-field ODE solver
  actors.py                       # MLP + TierAware actors (32/64-dim observations)
  marketing_env.py                # Gymnasium env + stratified/age observers
  replay_buffer.py                # buffer with profit/churn for hindsight relabeling
  td3_agent.py                    # TD3 implementation
  tqc_agent.py                    # TQC (distributional RL) implementation
  train.py                        # training entrypoint
  utils/                          # plotting + diagnostics
  experiments/                    # evaluation, sweeps, transfer, stress tests
    slurm/                        # reproducible Slurm job scripts
  requirements.txt

Core Components

Physics engines (GEMF/)

  • gemf_cdm.py: exact stochastic CDM solver (sparse influence + active-set updates).
  • gemf_cdm_cudagraph.py: CUDA-graph CDM solver; uses FlashSpread's Triton FlashNeighbor kernel.
  • gemf_cdm_renewal.py + gemf_cdm_renewal_solver.py: non-Markovian CDM with log-normal dwell times; wraps FlashSpread RenewalEngine.
  • gemf_ode.py: deterministic node-level mean-field ODE (Euler integration on GPU).
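For readers unfamiliar with node-level mean-field models, here is a minimal Euler-integrated example. It uses an SIS model as a stand-in because the README does not spell out the CDM equations, and a dense CPU adjacency matrix instead of the GPU implementation in gemf_ode.py.

```python
import numpy as np

def sis_mean_field_euler(adj, x0, beta, delta, dt=0.01, steps=1000):
    """Euler integration of node-level mean-field SIS:
        dx_i/dt = beta * (1 - x_i) * sum_j A_ij x_j - delta * x_i
    where x_i is the probability that node i is infected.
    """
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(steps):
        dx = beta * (1.0 - x) * (adj @ x) - delta * x
        x = np.clip(x + dt * dx, 0.0, 1.0)  # keep probabilities in [0, 1]
    return x

# Complete graph on 10 nodes: each node has 9 neighbors.
n = 10
adj = np.ones((n, n)) - np.eye(n)
x_hi = sis_mean_field_euler(adj, np.full(n, 0.01), beta=0.5, delta=1.0, steps=3000)
x_lo = sis_mean_field_euler(adj, np.full(n, 0.01), beta=0.05, delta=1.0, steps=3000)
print(x_hi.round(3))  # endemic: 1 - delta / (9 * beta) = 0.778 for every node
print(x_lo.max())     # below threshold (9 * beta < delta): dies out
```

Every trajectory here is deterministic: there is no extinction branch above threshold, which is the kind of variance-free prediction the Optimism Bias pathology refers to.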

RL pipeline

  • train.py: main training entrypoint. Key flags:
    • --solver {sparse, cudagraph, ode, renewal} selects the physics engine.
    • --agent {td3, tqc} selects the algorithm.
    • --actor {mlp, tier} selects the policy head (tier applies a 1D convolution over tiers).
    • --observer-age enables age-augmented observations (64-dim) for renewal dynamics.
  • marketing_env.py: Gymnasium environment with StratifiedObserver (32-dim) and StratifiedAgeObserver (64-dim).
  • actors.py: MLPActor and TierAwareActor (Conv1D over tiers).
  • tqc_agent.py: Truncated Quantile Critics (25 quantiles, 2 dropped, 5 critics).
  • replay_buffer.py: stores raw true_profit / churn_count so rewards can be relabeled in hindsight at sampling time.
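To illustrate what the TQC settings imply, this NumPy sketch pools the 5 critics' 25 next-state quantiles (125 atoms), drops the largest, forms the Bellman target, and scores one critic's predictions with a quantile Huber loss. It assumes "2 dropped" means 2 atoms per critic (10 in total), as in the original TQC formulation, and it omits the SAC entropy term; it is not the code in tqc_agent.py.

```python
import numpy as np

def tqc_target(next_quantiles, reward, done, gamma=0.99, drop_per_critic=2):
    """Truncated distributional target for a single transition.

    next_quantiles: (num_critics, num_quantiles) estimates of Z(s', a').
    """
    num_critics = next_quantiles.shape[0]
    pooled = np.sort(next_quantiles.ravel())                      # pool all critics' atoms
    kept = pooled[: pooled.size - num_critics * drop_per_critic]  # drop the largest
    return reward + gamma * (1.0 - done) * kept

def quantile_huber_loss(pred, target, kappa=1.0):
    """pred: (num_quantiles,) from one critic; target: (num_atoms,)."""
    n = pred.shape[0]
    taus = (np.arange(n) + 0.5) / n                     # quantile midpoints
    u = target[None, :] - pred[:, None]                 # pairwise TD errors
    abs_u = np.abs(u)
    huber = np.where(abs_u <= kappa, 0.5 * u**2, kappa * (abs_u - 0.5 * kappa))
    weight = np.abs(taus[:, None] - (u < 0).astype(float))  # asymmetric quantile weight
    return float((weight * huber).mean())

rng = np.random.default_rng(0)
next_q = rng.normal(size=(5, 25))                       # 5 critics x 25 quantiles
target = tqc_target(next_q, reward=1.0, done=0.0)
print(target.shape)  # (115,)
print(quantile_huber_loss(rng.normal(size=25), target))
```

Dropping the top atoms counteracts the overestimation that comes from maximizing over noisy critics, while the surviving atoms still describe both the extinction and viral tails.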

Evaluation Protocol

To avoid misleading comparisons, all agents are evaluated on the realistic non-Markovian (renewal) simulator, regardless of how they were trained:

| Training regime | Solver | Evaluation solver |
|---|---|---|
| Mean-Field | ode | renewal |
| Stochastic Markovian | cudagraph | renewal |
| Stochastic Renewal | renewal | renewal |

This explicitly separates training dynamics from deployment dynamics and prevents the common pitfall where ODE-trained policies appear artificially strong when evaluated on the same deterministic solver.
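A small helper can generate the evaluation matrix from the flags shown in the Quickstart (experiments/eval_checkpoint.py with --eval-solver renewal). The checkpoint paths are placeholders, and whether a given checkpoint needs --observer-age depends on the observer it was trained with; check that per run.

```python
import shlex

# Placeholder checkpoint paths: replace with your own run directories.
CHECKPOINTS = {
    "Mean-Field (ode)": "checkpoints/<ode_run>/checkpoint_best.pt",
    "Stochastic Markovian (cudagraph)": "checkpoints/<cudagraph_run>/checkpoint_best.pt",
    "Stochastic Renewal (renewal)": "checkpoints/<renewal_run>/checkpoint_best.pt",
}

def eval_command(checkpoint, observer_age=True, episodes=50, seeds=5):
    """Build the eval_checkpoint.py argv; the evaluation solver is always renewal."""
    cmd = ["python", "experiments/eval_checkpoint.py",
           "--checkpoint", checkpoint,
           "--eval-solver", "renewal"]
    if observer_age:
        cmd.append("--observer-age")
    cmd += ["--episodes", str(episodes), "--seeds", str(seeds)]
    return cmd

for regime, ckpt in CHECKPOINTS.items():
    print(f"# {regime}")
    print(shlex.join(eval_command(ckpt)))
```

Each argv list can be passed to `subprocess.run(cmd, check=True)` once the placeholders are filled in.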

Reference array job: experiments/slurm/run_eval_metastable_matrix.sbatch (5 checkpoints × renewal eval).

Reproducing Paper Figures and Tables

LaTeX tables (regenerate from CSVs in logs/ to docs/report/tables/):

import pathlib

import pandas as pd

pairs = [
    ("logs/table1_control_efficacy_summary_v2.csv",      "docs/report/tables/table1_control_efficacy_summary_v2.tex"),
    ("logs/table2_critic_comparison_summary_v2.csv",     "docs/report/tables/table2_critic_comparison_summary_v2.tex"),
    ("logs/table4_ablation_summary_v2.csv",              "docs/report/tables/table4_ablation_summary_v2.tex"),
    ("logs/table5_zero_shot_v2.csv",                     "docs/report/tables/table5_zero_shot_v2.tex"),
    ("logs/snap_transfer_results.csv",                   "docs/report/tables/snap_transfer_results.tex"),
]
for src, dst in pairs:
    # Create docs/report/tables/ on a fresh checkout
    pathlib.Path(dst).parent.mkdir(parents=True, exist_ok=True)
    # pandas >= 2.0 renders to_latex via Styler, which requires jinja2
    pd.read_csv(src).to_latex(dst, index=False, escape=True)

Key figure regeneration scripts:

  • Sim-to-Real bar chart: python utils/plot_sim_to_real.py --input-csv logs/sim_to_real_combined.csv --out docs/report/figures/figure_sim_to_real.pdf
  • Mean-Field Illusion (Fig. 1): python utils/plot_mean_field_illusion.py
  • Hub-Sniper Strategy (Fig. 3): python utils/plot_hub_sniper_strategy.py
  • TQC Quantiles (Fig. 2): python utils/plot_tqc_quantiles.py
  • Stratified Cascade Heatmap: python utils/plot_stratified_heatmap.py
  • Commitment Window: python experiments/commitment_window_trace.py --out-csv logs/commitment_window_trace.csv --out-plot figures/figure_commitment_window.png

All experiments use a shared visualization style (utils/style_config.py: LaTeX rendering, colorblind-friendly palette, ICML-style typography).

Sensitivity analysis (main paper / appendix):

python utils/plot_sensitivity_analysis.py --main-only        # 3-panel figure
python utils/plot_sensitivity_analysis.py --supplement-only  # full breakdown

Slurm scripts for full retraining and evaluation pipelines live in experiments/slurm/.

Citation

If you use this code or build on it, please cite both the ICML paper and the underlying FlashSpread engine:

@inproceedings{simtoreal2026,
  title     = {Closing the Sim-to-Real Gap in Network Spreading Processes
               via Stratified Observation and Distributional Control},
  author    = {Anonymous},
  booktitle = {Proceedings of the 43rd International Conference on Machine Learning (ICML)},
  year      = {2026}
}

@article{shakeri2026flashspread,
  title   = {FlashSpread: IO-Aware GPU Simulation of Non-Markovian Epidemic
             Dynamics via Kernel Fusion},
  author  = {Shakeri, Heman and Moradi-Jamei, Behnaz and Vajdi, Aram and Ardjmand, Ehsan},
  journal = {arXiv preprint arXiv:2604.22092},
  year    = {2026}
}

Acknowledgments

This project depends on FlashSpread; see that repository for license and engine-level documentation.
