@@ -448,12 +448,20 @@ def __call__(
448448 return video
449449
450450
@partial(jax.jit, static_argnames=("do_classifier_free_guidance",))
def transformer_forward_pass(
    graphdef,
    sharded_state,
    rest_of_state,
    latents,
    timestep,
    prompt_embeds,
    is_uncond,
    slg_mask,
    do_classifier_free_guidance,
    guidance_scale,
):
    """Run one jitted transformer step and optionally apply classifier-free guidance.

    The transformer is rebuilt from ``graphdef``, ``sharded_state`` and
    ``rest_of_state`` with ``nnx.merge`` and called on ``latents``.

    When ``do_classifier_free_guidance`` is true, the prediction is split in
    half along axis 0: the first half is used as the conditional prediction
    and the second half as the unconditional one. They are combined as
    ``uncond + guidance_scale * (cond - uncond)``, and ``latents`` is cut
    back to its first half so it matches the combined prediction.

    Args:
        graphdef: Graph definition passed to ``nnx.merge``.
        sharded_state: First state passed to ``nnx.merge``.
        rest_of_state: Second state passed to ``nnx.merge``.
        latents: Transformer input, forwarded as ``hidden_states``. With
            guidance enabled its leading axis is halved on return.
        timestep: Forwarded to the transformer as ``timestep``.
        prompt_embeds: Forwarded as ``encoder_hidden_states``.
        is_uncond: Forwarded to the transformer unchanged.
        slg_mask: Forwarded to the transformer unchanged.
        do_classifier_free_guidance: Python bool selecting the guidance
            branch. It is a static jit argument because it drives a Python
            ``if``; each value gets its own compiled program.
        guidance_scale: Guidance weight. It is a traced argument, so changing
            it does not trigger recompilation and array values are accepted.

    Returns:
        A ``(noise_pred, latents)`` tuple. Without guidance both are returned
        at the input batch size; with guidance both have half of it.
    """
    wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
    noise_pred = wan_transformer(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        is_uncond=is_uncond,
        slg_mask=slg_mask,
    )

    if do_classifier_free_guidance:
        # Shapes are static under jit, so this is a plain Python int.
        bsz = latents.shape[0] // 2
        noise_cond, noise_uncond = noise_pred[:bsz], noise_pred[bsz:]
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
        # Drop the duplicated half so latents lines up with noise_pred.
        latents = latents[:bsz]

    return noise_pred, latents
457465
458466
459467def run_inference (
@@ -480,13 +488,11 @@ def run_inference(
480488 if slg_layers and int (slg_start * num_inference_steps ) <= step < int (slg_end * num_inference_steps ):
481489 slg_mask = slg_mask .at [jnp .array (slg_layers )].set (True )
482490 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
483- # get original batch size before concat in case of cfg.
484- bsz = latents .shape [0 ]
485491 if do_classifier_free_guidance :
486492 latents = jnp .concatenate ([latents ] * 2 )
487493 timestep = jnp .broadcast_to (t , latents .shape [0 ])
488494
489- noise_pred = transformer_forward_pass (
495+ noise_pred , latents = transformer_forward_pass (
490496 graphdef ,
491497 sharded_state ,
492498 rest_of_state ,
@@ -495,12 +501,9 @@ def run_inference(
495501 prompt_embeds ,
496502 is_uncond = jnp .array (True , dtype = jnp .bool_ ),
497503 slg_mask = slg_mask ,
504+ do_classifier_free_guidance = do_classifier_free_guidance ,
505+ guidance_scale = guidance_scale
498506 )
499507
500- if do_classifier_free_guidance :
501- noise_uncond = noise_pred [bsz :]
502- noise_pred = noise_pred [:bsz ]
503- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond )
504- latents = latents [:bsz ]
505508 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
506509 return latents
0 commit comments