
Add torch_compile flag for training networks #28

Open
wenxin0319 wants to merge 1 commit into NVlabs:main from wenxin0319:main

Conversation

@wenxin0319

FastGen currently relies on diffusers-based model execution, which leaves performance on the table during training.

This PR adds an opt-in torch_compile flag that wraps training networks with torch.compile, enabling PyTorch's compiler optimizations (operator fusion, memory planning, kernel autotuning) for significant speedups on common models.

Benchmark (QwenImage, 20.43B params, NVIDIA H100, bfloat16, 512x512):

Setting │ Time/iter │ Std
Baseline (no compile) │ 0.694s │ 0.094s
torch.compile (max-autotune) │ 0.447s │ 0.014s

This is a 1.55x speedup (about 55% higher throughput, or roughly 36% less time per iteration).

Compiled iterations also show much lower variance (0.014s vs 0.094s), meaning more consistent training throughput. The one-time compilation overhead (~5-10 min with max-autotune) is amortized over the full training run.
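
For reference, the numbers above can be reproduced in spirit with a small timing helper. This is a minimal sketch and not the contents of bench_compile.py: the names `time_iterations` and `speedup` are hypothetical, and the warm-up count is an assumption chosen so that one-time compilation cost stays out of the per-iteration statistics.

```python
import statistics
import time
from typing import Callable, Tuple


def time_iterations(step: Callable[[], None], warmup: int = 3, iters: int = 20) -> Tuple[float, float]:
    """Return (mean, std) wall-clock seconds per call of `step`.

    Warm-up calls are excluded so one-time costs (such as compilation)
    do not skew the per-iteration numbers. Requires iters >= 2.
    """
    for _ in range(warmup):
        step()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        step()
        # For GPU workloads, synchronize the device inside `step`
        # (e.g. torch.cuda.synchronize()) so the clock sees finished work.
        samples.append(time.perf_counter() - start)
    return statistics.mean(samples), statistics.stdev(samples)


def speedup(baseline_s: float, compiled_s: float) -> float:
    """Ratio of baseline time to compiled time (higher is better)."""
    return baseline_s / compiled_s


# Figures taken from the benchmark table in this PR.
print(f"{speedup(0.694, 0.447):.2f}x")  # prints 1.55x
```

Reporting the standard deviation alongside the mean is what makes the variance comparison (0.014s vs 0.094s) visible.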

Changes:

  • Add torch_compile: bool = False config option in BaseModelConfig
  • Add _apply_torch_compile() in FastGenModel that compiles the main network (self.net)
  • Override _apply_torch_compile() in DMD2Model to also compile teacher and fake_score networks
  • Add comprehensive tests covering compile on/off for both SFT and DMD2 models, including training step validation
  • Add bench_compile.py benchmark script for measuring compile speedup

Usage:
Set torch_compile=True in model config to enable.
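
The shape of the change can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in this PR: only `BaseModelConfig`, `torch_compile`, `FastGenModel`, `DMD2Model`, `_apply_torch_compile`, `self.net`, and the teacher and fake_score networks come from the description above. The constructor signatures, the attribute names `teacher` and `fake_score`, the injectable `compile_fn` argument, and calling the hook from `__init__` are all hypothetical. The real benchmark used max-autotune, which `torch.compile` accepts as `mode="max-autotune"`; the default mode is used here.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class BaseModelConfig:
    torch_compile: bool = False  # opt-in; compilation stays off by default


def _default_compile(module: Any) -> Any:
    # Imported lazily so this sketch loads even without PyTorch installed.
    import torch
    return torch.compile(module)


class FastGenModel:
    # Hypothetical constructor; the real class builds its own networks.
    def __init__(self, config: BaseModelConfig, net: Any,
                 compile_fn: Callable[[Any], Any] = _default_compile):
        self.config = config
        self.net = net
        self._compile_fn = compile_fn
        if config.torch_compile:
            self._apply_torch_compile()

    def _apply_torch_compile(self) -> None:
        # Base behavior: compile only the main network.
        self.net = self._compile_fn(self.net)


class DMD2Model(FastGenModel):
    def __init__(self, config: BaseModelConfig, net: Any, teacher: Any, fake_score: Any,
                 compile_fn: Callable[[Any], Any] = _default_compile):
        # Set the extra networks first, because the base __init__ may compile them.
        self.teacher = teacher
        self.fake_score = fake_score
        super().__init__(config, net, compile_fn)

    def _apply_torch_compile(self) -> None:
        # Compile the main network, then the teacher and fake_score networks.
        super()._apply_torch_compile()
        self.teacher = self._compile_fn(self.teacher)
        self.fake_score = self._compile_fn(self.fake_score)


# Usage with a stand-in compile function so the example runs without a GPU.
def tag(module: Any) -> Any:
    return ("compiled", module)


model = DMD2Model(BaseModelConfig(torch_compile=True), "net", "teacher", "fake_score", compile_fn=tag)
print(model.net, model.teacher, model.fake_score)
```

Keeping the flag off by default means existing configs behave exactly as before, and subclasses with extra networks only need to override the one hook.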

@wenxin0319
Author

@juliusberner Could you please take a look at my PR? Thanks!

