
Commit f30daac

ninatu and martinarroyo committed
Wan training: fix timestep sampling by switching to continuous sampling, introduce disable_training_weights, and add max_grad_norm and max_abs_grad logging.
- Switched timestep sampling from discrete to continuous.
- Added max_grad_norm and max_abs_grad calculation and logging.
- Introduced `config.disable_training_weights` to optionally disable mid-point loss weighting.

Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent efbc91d commit f30daac
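For orientation before the per-file diffs: the headline change swaps the trainer's discrete timestep draw for the scheduler's new continuous, time-shifted draw. A minimal standalone sketch of the two behaviours in plain JAX; shift and num_train_timesteps are stand-ins for the scheduler config values that appear in the diffs below, not part of the commit:

import jax

num_train_timesteps = 1000  # stand-in for scheduler.config.num_train_timesteps
shift = 3.0                 # stand-in for scheduler.config.shift
rng = jax.random.PRNGKey(0)
bsz = 4

# Old behaviour: integer timesteps drawn uniformly from {0, ..., num_train_timesteps - 1}.
discrete_timesteps = jax.random.randint(rng, (bsz,), 0, num_train_timesteps)

# New behaviour: continuous t in [0, 1], warped by the shift mapping
# t -> shift * t / (1 + (shift - 1) * t), then scaled to [0, num_train_timesteps].
t = jax.random.uniform(rng, (bsz,))
t_shifted = (t * shift) / (1 + (shift - 1) * t)
continuous_timesteps = t_shifted * num_train_timesteps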

4 files changed

Lines changed: 91 additions & 24 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -282,6 +282,7 @@ output_dir: 'sdxl-model-finetuned'
 per_device_batch_size: 1.0
 # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
 global_batch_size: 0
+disable_training_weights: False # if True, disables the use of mid-point loss weighting
 
 # For creating tfrecords from dataset
 tfrecords_dir: ''
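The new flag gates the bell-shaped mid-point weighting produced by _calculate_training_weights in the scheduler diff further down. A small illustrative sketch of that weighting and of how the flag switches it off in the loss (the function names and the 5-D latent layout here are assumptions for illustration, not part of the commit):

import jax.numpy as jnp

def midpoint_weights(timesteps, num_train_timesteps):
  # Gaussian bump centred at num_train_timesteps / 2, shifted so its minimum is 0
  # and rescaled so the weights sum to num_train_timesteps.
  y = jnp.exp(-2 * ((timesteps - num_train_timesteps / 2) / num_train_timesteps) ** 2)
  y = y - jnp.min(y)
  return y * (num_train_timesteps / jnp.sum(y))

def weighted_mse(model_pred, training_target, training_weight, disable_training_weights):
  loss = (training_target - model_pred) ** 2
  if not disable_training_weights:
    # Broadcast the per-sample weight over the (C, T, H, W) latent axes,
    # mirroring the wan_trainer.py change below.
    loss = loss * jnp.expand_dims(training_weight, axis=(1, 2, 3, 4))
  return jnp.mean(loss)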

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ output_dir: 'sdxl-model-finetuned'
 per_device_batch_size: 1.0
 # If global_batch_size % jax.device_count is not 0, use FSDP sharding.
 global_batch_size: 0
+disable_training_weights: False # if True, disables the use of mid-point loss weighting
 
 # For creating tfrecords from dataset
 tfrecords_dir: ''

src/maxdiffusion/schedulers/scheduling_flow_match_flax.py

Lines changed: 53 additions & 5 deletions
@@ -150,11 +150,9 @@ def set_timesteps(
 
     linear_timesteps_weights = None
     if training:
-      x = timesteps
-      y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
-      y_shifted = y - jnp.min(y)
-      bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted))
-      linear_timesteps_weights = bsmntw_weighing
+      linear_timesteps_weights = self._calculate_training_weights(
+          timesteps, num_inference_steps
+      )
 
     return state.replace(
         sigmas=sigmas,
@@ -164,6 +162,56 @@ def set_timesteps(
         num_inference_steps=num_inference_steps,
     )
 
+  def _calculate_training_weights(
+      self, timesteps: jnp.ndarray, num_inference_steps: int
+  ) -> jnp.ndarray:
+    """Calculates the training weight for a given timestep."""
+    x = timesteps
+    y = jnp.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
+    y_shifted = y - jnp.min(y)
+    bsmntw_weighing = y_shifted * (num_inference_steps / jnp.sum(y_shifted))
+    linear_timesteps_weights = bsmntw_weighing
+    return linear_timesteps_weights
+
+  def sample_timesteps(self, timestep_rng, batch_size):
+    # 1. Sample continuous timesteps t in [0, 1]
+    t = jax.random.uniform(timestep_rng, (batch_size,))
+
+    # 2. Apply the "Shift" weighting (Time shifting)
+    t_shifted = (t * self.config.shift) / (1 + (self.config.shift - 1) * t)
+
+    # 3. Scale t to [0, self.config.num_train_timesteps]
+    timesteps = t_shifted.squeeze() * self.config.num_train_timesteps
+
+    return timesteps
+
+  def apply_flow_match(
+      self,
+      noise: jnp.ndarray,
+      batch_images: jnp.ndarray,
+      timesteps: jnp.ndarray,
+  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    """Apply flow match to the batch of images.
+
+    Replaces: scheduler.add_noise + scheduler.training_target +
+    scheduler.training_weight
+    """
+
+    t = timesteps.astype(jnp.float32) / self.config.num_train_timesteps
+    broadcast_shape = (-1,) + (1,) * (batch_images.ndim - 1)
+    t = t.reshape(broadcast_shape)
+
+    sigma = (1 - t) * self.config.sigma_min + t * self.config.sigma_max
+
+    noisy_latents = (1 - sigma) * batch_images + sigma * noise
+    target = noise - batch_images
+
+    training_weights = self._calculate_training_weights(
+        timesteps, self.config.num_train_timesteps
+    )
+
+    return noisy_latents, target, training_weights
+
   def _find_timestep_id(self, state: FlowMatchSchedulerState, timestep: jnp.ndarray) -> jnp.ndarray:
     """Finds the index of the closest timestep in the scheduler's `timesteps` array."""
     timestep = jnp.asarray(timestep, dtype=state.timesteps.dtype)
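To summarise what the new methods compute end to end, here is a rough standalone sketch of the noising path in plain jax.numpy; sigma_min = 0.0 and sigma_max = 1.0 are assumed stand-ins for the scheduler config values (in which case sigma is just the normalised timestep), and the function name is illustrative:

import jax.numpy as jnp

def flow_match_sketch(latents, noise, timesteps, num_train_timesteps=1000,
                      sigma_min=0.0, sigma_max=1.0):
  # Normalise timesteps to [0, 1] and broadcast over the non-batch axes.
  t = timesteps.astype(jnp.float32) / num_train_timesteps
  t = t.reshape((-1,) + (1,) * (latents.ndim - 1))
  sigma = (1 - t) * sigma_min + t * sigma_max
  # Straight-line interpolation between data and noise; the regression target
  # is the constant velocity (noise - latents) along that path.
  noisy_latents = (1 - sigma) * latents + sigma * noise
  target = noise - latents
  return noisy_latents, target

The per-sample training weight returned alongside these values comes from the same _calculate_training_weights bump, now evaluated at the sampled continuous timesteps against num_train_timesteps rather than at the inference-time grid.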

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 36 additions & 19 deletions
@@ -24,6 +24,7 @@
 import tensorflow as tf
 import jax.numpy as jnp
 import jax
+import jaxopt
 from jax.sharding import PartitionSpec as P
 from flax import nnx
 from maxdiffusion.schedulers import FlaxFlowMatchScheduler
@@ -453,38 +454,53 @@ def loss_fn(params):
     model = nnx.merge(state.graphdef, params, state.rest_of_state)
     latents = data["latents"].astype(config.weights_dtype)
     encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+
     bsz = latents.shape[0]
-    timesteps = jax.random.randint(
-        timestep_rng,
-        (bsz,),
-        0,
-        scheduler.config.num_train_timesteps,
+    timesteps = scheduler.sample_timesteps(timestep_rng, bsz)
+    noise = jax.random.normal(
+        key=new_rng, shape=latents.shape, dtype=latents.dtype
+    )
+    noisy_latents, training_target, training_weight = (
+        scheduler.apply_flow_match(noise, latents, timesteps)
     )
-    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
-    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
-
     with jax.named_scope("forward_pass"):
       model_pred = model(
           hidden_states=noisy_latents,
          timestep=timesteps,
          encoder_hidden_states=encoder_hidden_states,
          deterministic=False,
-          rngs=nnx.Rngs(dropout_rng),
+          rngs=nnx.Rngs(dropout=dropout_rng),
      )
 
     with jax.named_scope("loss"):
-      training_target = scheduler.training_target(latents, noise, timesteps)
-      training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
      loss = (training_target - model_pred) ** 2
-      loss = loss * training_weight
+      if not config.disable_training_weights:
+        training_weight = jnp.expand_dims(training_weight, axis=(1, 2, 3, 4))
+        loss = loss * training_weight
      loss = jnp.mean(loss)
 
     return loss
 
   grad_fn = nnx.value_and_grad(loss_fn)
   loss, grads = grad_fn(state.params)
+  max_grad_norm = jaxopt.tree_util.tree_l2_norm(grads)
+
+  max_abs_grad = jax.tree_util.tree_reduce(
+      lambda max_val, arr: jnp.maximum(max_val, jnp.max(jnp.abs(arr))),
+      grads,
+      initializer=-1.0,
+  )
+
+  metrics = {
+      "scalar": {
+          "learning/loss": loss,
+          "learning/max_grad_norm": max_grad_norm,
+          "learning/max_abs_grad": max_abs_grad,
+      },
+      "scalars": {},
+  }
+
   new_state = state.apply_gradients(grads=grads)
-  metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
   return new_state, scheduler_state, metrics, new_rng
 
 
@@ -495,14 +511,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
 
   # The loss function logic is identical to training. We are evaluating the model's
  # ability to perform its core training objective (e.g., denoising).
-  @jax.jit
   def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
     # Reconstruct the model from its definition and parameters
     model = nnx.merge(state.graphdef, params, state.rest_of_state)
 
     noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
-    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
-
+    noisy_latents, training_target, training_weight = (
+        scheduler.apply_flow_match(noise, latents, timesteps)
+    )
     # Get the model's prediction
     model_pred = model(
         hidden_states=noisy_latents,
@@ -512,10 +528,11 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
     )
 
     # Calculate the loss against the target
-    training_target = scheduler.training_target(latents, noise, timesteps)
-    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
     loss = (training_target - model_pred) ** 2
-    loss = loss * training_weight
+    if not config.disable_training_weights:
+      training_weight = jnp.expand_dims(training_weight, axis=(1, 2, 3, 4))
+      loss = loss * training_weight
+
     # Calculate the mean loss per sample across all non-batch dimensions.
     loss = loss.reshape(loss.shape[0], -1).mean(axis=1)
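Regarding the two newly logged quantities: a dependency-free sketch of what they measure over the gradient pytree (this assumes jaxopt.tree_util.tree_l2_norm returns the global L2 norm over all leaves; the helper name below is illustrative):

import jax
import jax.numpy as jnp

def grad_stats(grads):
  leaves = jax.tree_util.tree_leaves(grads)
  # Global L2 norm across every gradient entry (what "learning/max_grad_norm" logs).
  grad_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
  # Largest absolute gradient entry anywhere in the tree ("learning/max_abs_grad").
  max_abs = jax.tree_util.tree_reduce(
      lambda m, g: jnp.maximum(m, jnp.max(jnp.abs(g))), grads, initializer=-1.0
  )
  return grad_norm, max_abs

Note that, under that reading, "max_grad_norm" is the norm of the whole gradient tree rather than a per-tensor maximum, and the -1.0 initializer simply seeds the running maximum for the absolute-value reduction.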
