
Commit efbc91d

ninatu and martinarroyo committed
Wan 2.1 training: Resolve checkpoint loading issues with larger TPU slices and different topologies
Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent 6101386 commit efbc91d

4 files changed

Lines changed: 95 additions & 21 deletions


src/maxdiffusion/checkpointing/wan_checkpointer_2_1.py

Lines changed: 28 additions & 8 deletions
@@ -15,14 +15,15 @@
 """

 import json
-import jax
-import numpy as np
 from typing import Optional, Tuple
-from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1
-from .. import max_logging
-import orbax.checkpoint as ocp
 from etils import epath
+import jax
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+import numpy as np
+import orbax.checkpoint as ocp
+from .. import max_logging
+from ..pipelines.wan.wan_pipeline_2_1 import WanPipeline2_1


 class WanCheckpointer2_1(WanCheckpointer):
@@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
       max_logging.log("No WAN checkpoint found.")
       return None, None
     max_logging.log(f"Loading WAN checkpoint from step {step}")
+
+    cpu_devices = np.array(jax.devices(backend="cpu"))
+    mesh = Mesh(cpu_devices, axis_names=("data",))
+    replicated_sharding = NamedSharding(mesh, P())
+
     metadatas = self.checkpoint_manager.item_metadata(step)
-    transformer_metadata = metadatas.wan_state
-    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
+    state = metadatas.wan_state
+
+    def add_sharding_to_struct(leaf_struct, sharding):
+      return jax.ShapeDtypeStruct(
+          shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
+      )
+
+    target_shardings = jax.tree_util.tree_map(
+        lambda x: replicated_sharding, state
+    )
+
+    with mesh:
+      abstract_train_state_with_sharding = jax.tree_util.tree_map(
+          add_sharding_to_struct, state, target_shardings
+      )
+
     params_restore = ocp.args.PyTreeRestore(
         restore_args=jax.tree.map(
             lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
-            abstract_tree_structure_params,
+            abstract_train_state_with_sharding,
        )
    )
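The key idea in this change is to describe every checkpoint leaf as a fully replicated array on a host-CPU mesh, so the restore target no longer depends on the device sharding or TPU topology that the checkpoint was written with. A minimal, self-contained sketch of that pattern (the example tree below is only a stand-in for `metadatas.wan_state`; it is not code from this commit):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-axis mesh over the host CPU device(s); an empty PartitionSpec means
# "fully replicated", i.e. no assumption about the original TPU topology.
cpu_devices = np.array(jax.devices(backend="cpu"))
mesh = Mesh(cpu_devices, axis_names=("data",))
replicated_sharding = NamedSharding(mesh, P())

def add_sharding_to_struct(leaf_struct, sharding):
  # Rebuild the abstract leaf (shape + dtype) with an explicit target sharding.
  return jax.ShapeDtypeStruct(shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding)

# Stand-in for the checkpoint metadata tree (metadatas.wan_state in the real code).
abstract_state = {"proj": jax.ShapeDtypeStruct((1024, 1024), np.float32)}
abstract_with_sharding = jax.tree_util.tree_map(
    lambda leaf: add_sharding_to_struct(leaf, replicated_sharding), abstract_state
)
```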

src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p1.py

Lines changed: 28 additions & 8 deletions
@@ -15,14 +15,15 @@
 """

 import json
-import jax
-import numpy as np
 from typing import Optional, Tuple
-from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1
-from .. import max_logging
-import orbax.checkpoint as ocp
 from etils import epath
+import jax
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer
+import numpy as np
+import orbax.checkpoint as ocp
+from .. import max_logging
+from ..pipelines.wan.wan_pipeline_i2v_2p1 import WanPipelineI2V_2_1


 class WanCheckpointerI2V_2_1(WanCheckpointer):
@@ -35,13 +36,32 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
       max_logging.log("No WAN checkpoint found.")
       return None, None
     max_logging.log(f"Loading WAN checkpoint from step {step}")
+
+    cpu_devices = np.array(jax.devices(backend="cpu"))
+    mesh = Mesh(cpu_devices, axis_names=("data",))
+    replicated_sharding = NamedSharding(mesh, P())
+
     metadatas = self.checkpoint_manager.item_metadata(step)
-    transformer_metadata = metadatas.wan_state
-    abstract_tree_structure_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, transformer_metadata)
+    state = metadatas.wan_state
+
+    def add_sharding_to_struct(leaf_struct, sharding):
+      return jax.ShapeDtypeStruct(
+          shape=leaf_struct.shape, dtype=leaf_struct.dtype, sharding=sharding
+      )
+
+    target_shardings = jax.tree_util.tree_map(
+        lambda x: replicated_sharding, state
+    )
+
+    with mesh:
+      abstract_train_state_with_sharding = jax.tree_util.tree_map(
+          add_sharding_to_struct, state, target_shardings
+      )
+
     params_restore = ocp.args.PyTreeRestore(
        restore_args=jax.tree.map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
-            abstract_tree_structure_params,
+            abstract_train_state_with_sharding,
        )
    )
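The image-to-video checkpointer receives the identical fix. For context, a hedged sketch of how the resulting `PyTreeRestore` typically feeds into an Orbax `CheckpointManager` restore; the `Composite` item name "wan_state" matches the metadata attribute above, but the exact restore call used elsewhere in this repo is an assumption, since the diff only shows how `params_restore` is built:

```python
import jax
import numpy as np
import orbax.checkpoint as ocp

def restore_wan_state(checkpoint_manager, step, abstract_with_sharding):
  # Mirror the abstract tree with RestoreArgs so Orbax returns plain np.ndarrays;
  # the sharding-annotated structs only define the tree structure being restored.
  params_restore = ocp.args.PyTreeRestore(
      restore_args=jax.tree.map(lambda _: ocp.RestoreArgs(restore_type=np.ndarray), abstract_with_sharding)
  )
  # Assumed item name "wan_state", matching metadatas.wan_state in the checkpointer.
  return checkpoint_manager.restore(step, args=ocp.args.Composite(wan_state=params_restore))
```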

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 20 additions & 3 deletions
@@ -168,9 +168,26 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     )
     for path, val in flax.traverse_util.flatten_dict(params).items():
       if restored_checkpoint:
-        path = path[:-1]
+        if path[-1] == "value":
+          path = path[:-1]  # remove 'value'
+
+        try:
+          # Convert block indices to integers, as they might have been loaded as strings from the checkpoint.
+          path = path[:1] + (int(path[1]),) + path[2:]
+        except Exception:
+          pass
+
       sharding = logical_state_sharding[path].value
-      state[path].value = device_put_replicated(val, sharding)
+      try:
+        state[path].value = device_put_replicated(val, sharding)
+      except Exception as e:
+        max_logging.log(f"Failed to device_put_replicated {path}: {e}")
+        max_logging.log(f"Trying to use process_allgather for {path}")
+        val_on_host = jax.experimental.multihost_utils.process_allgather(
+            val, tiled=True
+        )
+        state[path].value = device_put_replicated(val_on_host, sharding)
+        del val_on_host
     state = nnx.from_flat_state(state)

     wan_transformer = nnx.merge(graphdef, state, rest_of_state)
@@ -470,7 +487,6 @@ def encode_prompt(
       negative_prompt_embeds: jax.Array = None,
   ):
     prompt = [prompt] if isinstance(prompt, str) else prompt
-    batch_size = len(prompt)
     if prompt_embeds is None:
       prompt_embeds = self._get_t5_prompt_embeds(
           prompt=prompt,
@@ -480,6 +496,7 @@ def encode_prompt(
     prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)

     if negative_prompt_embeds is None:
+      batch_size = len(prompt_embeds)
       negative_prompt = negative_prompt or ""
       negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
       negative_prompt_embeds = self._get_t5_prompt_embeds(
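On multi-host slices, placing a host-local `val` can fail when each process only holds its shard of the array, so the new code gathers the full value onto every host with `process_allgather(..., tiled=True)` and retries. A sketch of that fallback pattern, assuming the pipeline's `device_put_replicated` helper behaves like `jax.device_put` onto a replicated sharding (the helper's real definition is outside this diff):

```python
import jax
from jax.experimental import multihost_utils

def put_replicated_with_fallback(val, sharding):
  try:
    # Fast path: the full array is already addressable from this process.
    return jax.device_put(val, sharding)
  except Exception:
    # Multi-host fallback: concatenate the per-process shards into one full
    # host array (tiled=True), then replicate that array onto the devices.
    val_on_host = multihost_utils.process_allgather(val, tiled=True)
    return jax.device_put(val_on_host, sharding)
```

The `encode_prompt` hunks are unrelated to sharding: `batch_size` is now derived from `prompt_embeds` inside the negative-prompt branch, so the method still works when callers pass precomputed embeddings instead of a raw prompt list.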

src/maxdiffusion/pipelines/wan/wan_vace_pipeline_2_1.py

Lines changed: 19 additions & 2 deletions
@@ -119,9 +119,26 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     )
     for path, val in flax.traverse_util.flatten_dict(params).items():
       if restored_checkpoint:
-        path = path[:-1]
+        if path[-1] == "value":
+          path = path[:-1]  # remove 'value'
+
+        try:
+          # Convert block indices to integers, as they might have been loaded as strings from the checkpoint.
+          path = path[:1] + (int(path[1]),) + path[2:]
+        except Exception:
+          pass
+
       sharding = logical_state_sharding[path].value
-      state[path].value = device_put_replicated(val, sharding)
+      try:
+        state[path].value = device_put_replicated(val, sharding)
+      except Exception as e:
+        max_logging.log(f"Failed to device_put_replicated {path}: {e}")
+        max_logging.log(f"Trying to use process_allgather for {path}")
+        val_on_host = jax.experimental.multihost_utils.process_allgather(
+            val, tiled=True
+        )
+        state[path].value = device_put_replicated(val_on_host, sharding)
+        del val_on_host
     state = nnx.from_flat_state(state)

     wan_transformer = nnx.merge(graphdef, state, rest_of_state)
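The VACE pipeline gets the same path handling as wan_pipeline.py: instead of always dropping the last path element, the code now only strips a trailing "value" leaf and converts the block index back to an int, since flattened checkpoint paths can round-trip integer keys as strings. A tiny illustration of that normalization (the example path names are invented, not taken from the model):

```python
def normalize_path(path):
  """Mimics the path fix above on a flattened-state key tuple."""
  if path[-1] == "value":
    path = path[:-1]  # drop the trailing nnx Param attribute name
  try:
    # "3" -> 3 so the path matches logical_state_sharding, which keys blocks by int.
    path = path[:1] + (int(path[1]),) + path[2:]
  except Exception:
    pass  # paths without a numeric second element are left untouched
  return path

assert normalize_path(("blocks", "3", "attn", "value")) == ("blocks", 3, "attn")
assert normalize_path(("patch_embedding", "kernel", "value")) == ("patch_embedding", "kernel")
```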
