Commit b219048

Replace device_put with device_put_replicated for multi-host.
1 parent: 858e168

1 file changed: src/maxdiffusion/pipelines/wan/wan_pipeline.py (33 additions & 19 deletions)

@@ -25,7 +25,7 @@
 from ...pyconfig import HyperParameters
 from ... import max_logging
 from ... import max_utils
-from ...max_utils import get_flash_block_sizes, get_precision
+from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated
 from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae
 from ...models.wan.transformers.transformer_wan import WanModel
 from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache
@@ -99,7 +99,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
     for path, val in flax.traverse_util.flatten_dict(params).items():
       sharding = logical_state_sharding[path].value
-      state[path].value = jax.device_put(val, sharding)
+      state[path].value = device_put_replicated(val, sharding)
     state = nnx.from_flat_state(state)

     wan_transformer = nnx.merge(graphdef, state, rest_of_state)
@@ -183,27 +183,41 @@ def load_tokenizer(cls, config: HyperParameters):

   @classmethod
   def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
-    wan_vae = AutoencoderKLWan.from_config(
-        config.pretrained_model_name_or_path,
-        subfolder="vae",
-        rngs=rngs,
-        mesh=mesh,
-        dtype=config.activations_dtype,
-        weights_dtype=config.weights_dtype,
-    )
-    vae_cache = AutoencoderKLWanCache(wan_vae)
-
+
+    def create_model(rngs: nnx.Rngs, config: HyperParameters):
+      wan_vae = AutoencoderKLWan.from_config(
+          config.pretrained_model_name_or_path,
+          subfolder="vae",
+          rngs=rngs,
+          mesh=mesh,
+          dtype=config.activations_dtype,
+          weights_dtype=config.weights_dtype,
+      )
+      return wan_vae
+
+    # 1. eval shape
+    p_model_factory = partial(create_model, config=config)
+    wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs)
     graphdef, state = nnx.split(wan_vae, nnx.Param)
+
+    # 2. retrieve the state shardings, mapping logical names to mesh axis names.
+    logical_state_spec = nnx.get_partition_spec(state)
+    logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
+    logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
     params = state.to_pure_dict()
-    # This replaces random params with the model.
+    state = dict(nnx.to_flat_state(state))
+
+    # 3. Load pretrained weights and move them to device using the state shardings from (2) above.
+    # This loads sharded weights directly onto the accelerators without first copying them
+    # all to one device and then distributing them, keeping HBM usage low.
     params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-    params = jax.device_put(params, NamedSharding(mesh, P()))
-    wan_vae = nnx.merge(graphdef, params)
-    p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
-    # Shard
-    with mesh:
-      wan_vae = p_create_sharded_logical_model(model=wan_vae)
+    for path, val in flax.traverse_util.flatten_dict(params).items():
+      sharding = logical_state_sharding[path].value
+      state[path].value = device_put_replicated(val, sharding)
+    state = nnx.from_flat_state(state)
+
+    wan_vae = nnx.merge(graphdef, state)
+    vae_cache = AutoencoderKLWanCache(wan_vae)
     return wan_vae, vae_cache

   @classmethod
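Why the swap matters: with a mesh that spans multiple processes, each process can only address its own devices, so a plain jax.device_put of host-local numpy data generally cannot produce an array laid out across the whole mesh. A common way to implement a replicated-placement helper is jax.make_array_from_callback, which assembles a global jax.Array while materializing only the shards owned by the calling process. The following is a minimal sketch under the assumption that every host holds a full copy of each weight after the CPU-side load; the real device_put_replicated in max_utils may be implemented differently.

import jax
import numpy as np
from jax.sharding import NamedSharding

def device_put_replicated(x: np.ndarray, sharding: NamedSharding) -> jax.Array:
  # Illustrative stand-in, not the max_utils implementation. Every process
  # holds the full array x, so any global shard can be sliced out locally;
  # make_array_from_callback materializes only the shards that belong to
  # this process's devices, so no cross-host transfer is required.
  return jax.make_array_from_callback(x.shape, sharding, lambda idx: x[idx])

The old load_vae path, jax.device_put(params, NamedSharding(mesh, P())), additionally replicated every unsharded weight across all devices before the later resharding step, paying the full HBM cost up front; the new per-leaf loop places each tensor with its final sharding exactly once.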

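The rewritten load_vae also adopts the transformer's loading recipe: build the module abstractly, derive mesh shardings from logical partition specs, and only then materialize weights leaf by leaf. A condensed, self-contained sketch of that pattern follows; the Toy module, the rules tuple, and the zero-filled weights are hypothetical stand-ins for illustration, not maxdiffusion code.

import flax.linen as nn
import jax
import numpy as np
from flax import nnx
from jax.sharding import Mesh

class Toy(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    # Logical axis names on the kernel, as the WAN modules annotate theirs.
    init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp"))
    self.kernel = nnx.Param(init(rngs.params(), (64, 64)))

mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
rules = (("embed", None), ("mlp", "model"))  # stand-in for config.logical_axis_rules

# 1. eval shape: build the module abstractly, allocating no real weights.
toy = nnx.eval_shape(Toy, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(toy, nnx.Param)

# 2. map logical axis names to concrete mesh shardings.
logical_state_spec = nnx.get_partition_spec(state)
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, rules)
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
state = dict(nnx.to_flat_state(state))

# 3. place each tensor with its target sharding; the real code loads
# checkpoint weights here, and on multi-host would use a
# device_put_replicated-style helper instead of jax.device_put.
for path, abstract in state.items():
  loaded = np.zeros(abstract.value.shape, np.float32)  # stand-in weights
  state[path].value = jax.device_put(loaded, logical_state_sharding[path].value)

toy = nnx.merge(graphdef, nnx.from_flat_state(state))

The payoff of starting from nnx.eval_shape is that no throwaway randomly initialized weights are ever allocated on device: the first real allocation for each parameter is already the sharded, checkpoint-loaded tensor.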