Commit 6e6fb76

better block sizes.
1 parent 3ef352f commit 6e6fb76

3 files changed

Lines changed: 19 additions & 16 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 4 deletions
@@ -57,13 +57,13 @@ attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te
 
 #flash_block_sizes: {}
 flash_block_sizes: {
-  "block_q" : 2048,
-  "block_kv_compute" : 2048,
+  "block_q" : 3024,
+  "block_kv_compute" : 1024,
   "block_kv" : 2048,
-  "block_q_dkv" : 2048,
+  "block_q_dkv" : 3024,
   "block_kv_dkv" : 2048,
   "block_kv_dkv_compute" : 2048,
-  "block_q_dq" : 2048,
+  "block_q_dq" : 3024,
   "block_kv_dq" : 2048
 }
 # GroupNorm groups
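For context, these keys line up with the block-size fields of the Pallas splash-attention kernel that the TPU 'flash' attention path uses. Below is a minimal, illustrative sketch of turning a flash_block_sizes dict like the one above into a BlockSizes object; the helper name block_sizes_from_config is hypothetical and this is not the exact maxdiffusion code path.

# Sketch: map a flash_block_sizes config dict onto Pallas splash-attention
# block sizes. Forward-pass blocks (block_q, block_kv, block_kv_compute) and
# backward-pass blocks (the dkv / dq variants) use the same key names.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

def block_sizes_from_config(flash_block_sizes: dict) -> splash_attention_kernel.BlockSizes:
  return splash_attention_kernel.BlockSizes(
      block_q=flash_block_sizes["block_q"],
      block_kv_compute=flash_block_sizes["block_kv_compute"],
      block_kv=flash_block_sizes["block_kv"],
      block_q_dkv=flash_block_sizes["block_q_dkv"],
      block_kv_dkv=flash_block_sizes["block_kv_dkv"],
      block_kv_dkv_compute=flash_block_sizes["block_kv_dkv_compute"],
      block_q_dq=flash_block_sizes["block_q_dq"],
      block_kv_dq=flash_block_sizes["block_kv_dq"],
  )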

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def run(config, pipeline=None, filename_prefix=""):
       slg_start=slg_start,
       slg_end=slg_end,
   )
-  print("compile time: ", (time.perf_counter() - s0))
+  print("generation time: ", (time.perf_counter() - s0))
 
   s0 = time.perf_counter()
   if config.enable_profiler:
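For reference, a standalone sketch (not the maxdiffusion code) of the usual way to separate JIT compile time from steady-state generation time with time.perf_counter and block_until_ready:

import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  return jnp.sin(x) * 2.0

x = jnp.ones((1024, 1024))

s0 = time.perf_counter()
step(x).block_until_ready()   # first call: traces, compiles, then runs
print("first call (includes compile): ", time.perf_counter() - s0)

s0 = time.perf_counter()
step(x).block_until_ready()   # subsequent calls: execution only
print("generation time: ", time.perf_counter() - s0)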

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 14 additions & 11 deletions
@@ -448,12 +448,20 @@ def __call__(
     return video
 
 
-@jax.jit
-def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask):
+@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale"))
+def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask, do_classifier_free_guidance, guidance_scale):
   wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
-  return wan_transformer(
+  noise_pred = wan_transformer(
       hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask
   )
+  if do_classifier_free_guidance:
+    bsz = latents.shape[0] // 2
+    noise_uncond = noise_pred[bsz:]
+    noise_pred = noise_pred[:bsz]
+    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+    latents = latents[:bsz]
+
+  return noise_pred, latents
 
 
 def run_inference(
@@ -480,13 +488,11 @@ def run_inference(
     if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps):
       slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True)
     t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
-    # get original batch size before concat in case of cfg.
-    bsz = latents.shape[0]
     if do_classifier_free_guidance:
       latents = jnp.concatenate([latents] * 2)
     timestep = jnp.broadcast_to(t, latents.shape[0])
 
-    noise_pred = transformer_forward_pass(
+    noise_pred, latents = transformer_forward_pass(
         graphdef,
         sharded_state,
         rest_of_state,
@@ -495,12 +501,9 @@ def run_inference(
         prompt_embeds,
         is_uncond=jnp.array(True, dtype=jnp.bool_),
         slg_mask=slg_mask,
+        do_classifier_free_guidance=do_classifier_free_guidance,
+        guidance_scale=guidance_scale
     )
 
-    if do_classifier_free_guidance:
-      noise_uncond = noise_pred[bsz:]
-      noise_pred = noise_pred[:bsz]
-      noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
-      latents = latents[:bsz]
     latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
   return latents
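The net effect of this change is that the classifier-free-guidance combine now happens inside the jitted transformer_forward_pass, with do_classifier_free_guidance and guidance_scale passed as static arguments so the branch is resolved at trace time. A minimal standalone sketch of the same pattern, where toy_model and forward_pass are illustrative names rather than the real WAN transformer:

from functools import partial
import jax
import jax.numpy as jnp

def toy_model(x, t):
  # Stand-in for the transformer forward pass.
  return x * jnp.cos(t)[:, None]

@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale"))
def forward_pass(latents, timestep, do_classifier_free_guidance, guidance_scale):
  noise_pred = toy_model(latents, timestep)
  if do_classifier_free_guidance:
    # First half of the batch is conditional, second half unconditional.
    bsz = latents.shape[0] // 2
    noise_uncond = noise_pred[bsz:]
    noise_pred = noise_pred[:bsz]
    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
    latents = latents[:bsz]
  return noise_pred, latents

# Usage: duplicate latents for cond/uncond, then let the jitted function
# split and combine them.
latents = jnp.ones((2, 4))            # already concatenated: cond + uncond
timestep = jnp.full((2,), 10.0)
noise_pred, latents = forward_pass(latents, timestep,
                                   do_classifier_free_guidance=True,
                                   guidance_scale=5.0)
print(noise_pred.shape, latents.shape)   # (1, 4) (1, 4)

Marking guidance_scale as static means a separate compile per distinct value, which is the usual trade-off when the scale is fixed for a whole run.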
