25 | 25 | from ...pyconfig import HyperParameters |
26 | 26 | from ... import max_logging |
27 | 27 | from ... import max_utils |
28 | | -from ...max_utils import get_flash_block_sizes, get_precision |
| 28 | +from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated |
29 | 29 | from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae |
30 | 30 | from ...models.wan.transformers.transformer_wan import WanModel |
31 | 31 | from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache |
@@ -99,7 +99,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): |
99 | 99 | params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) |
100 | 100 | for path, val in flax.traverse_util.flatten_dict(params).items(): |
101 | 101 | sharding = logical_state_sharding[path].value |
102 | | - state[path].value = jax.device_put(val, sharding) |
| 102 | + state[path].value = device_put_replicated(val, sharding) |
103 | 103 | state = nnx.from_flat_state(state) |
104 | 104 | |
105 | 105 | wan_transformer = nnx.merge(graphdef, state, rest_of_state) |
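Note on this hunk: the only behavioural change is swapping jax.device_put for max_utils.device_put_replicated when placing each parameter leaf; the helper's implementation is not part of this diff, only its call signature device_put_replicated(val, sharding). A minimal sketch of what such a helper could look like, assuming it simply forwards to jax.device_put with the leaf's resolved NamedSharding (the real max_utils helper may do more, e.g. handle multi-host placement):

    import jax
    from jax.sharding import NamedSharding

    def device_put_replicated(val, sharding: NamedSharding):
      # Hypothetical sketch only: shard the host array along the axes named in the
      # sharding's PartitionSpec and replicate it along the unnamed ones.
      return jax.device_put(val, sharding)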
@@ -183,27 +183,41 @@ def load_tokenizer(cls, config: HyperParameters): |
183 | 183 | |
184 | 184 | @classmethod |
185 | 185 | def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): |
186 | | - wan_vae = AutoencoderKLWan.from_config( |
187 | | - config.pretrained_model_name_or_path, |
188 | | - subfolder="vae", |
189 | | - rngs=rngs, |
190 | | - mesh=mesh, |
191 | | - dtype=config.activations_dtype, |
192 | | - weights_dtype=config.weights_dtype, |
193 | | - ) |
194 | | - vae_cache = AutoencoderKLWanCache(wan_vae) |
195 | | - |
| 186 | + |
| 187 | + def create_model(rngs: nnx.Rngs, config: HyperParameters): |
| 188 | + wan_vae = AutoencoderKLWan.from_config( |
| 189 | + config.pretrained_model_name_or_path, |
| 190 | + subfolder="vae", |
| 191 | + rngs=rngs, |
| 192 | + mesh=mesh, |
| 193 | + dtype=config.activations_dtype, |
| 194 | + weights_dtype=config.weights_dtype, |
| 195 | + ) |
| 196 | + return wan_vae |
| 197 | + # 1. eval_shape: create the model abstractly, without allocating weights on device. |
| 198 | + p_model_factory = partial(create_model, config=config) |
| 199 | + wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) |
196 | 200 | graphdef, state = nnx.split(wan_vae, nnx.Param) |
| 201 | + |
| 202 | + # 2. Retrieve the state shardings, mapping logical names to mesh axis names. |
| 203 | + logical_state_spec = nnx.get_partition_spec(state) |
| 204 | + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) |
| 205 | + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) |
197 | 206 | params = state.to_pure_dict() |
198 | | - # This replaces random params with the model. |
| 207 | + state = dict(nnx.to_flat_state(state)) |
| 208 | + |
| 209 | + # 3. Load pretrained weights and move them to device using the state shardings from (2) above. |
| 210 | + # This helps load sharded weights directly onto the accelerators without first copying them |
| 211 | + # all to one device and then distributing them, thus keeping HBM usage low. |
199 | 212 | params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") |
200 | 213 | params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) |
201 | | - params = jax.device_put(params, NamedSharding(mesh, P())) |
202 | | - wan_vae = nnx.merge(graphdef, params) |
203 | | - p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) |
204 | | - # Shard |
205 | | - with mesh: |
206 | | - wan_vae = p_create_sharded_logical_model(model=wan_vae) |
| 214 | + for path, val in flax.traverse_util.flatten_dict(params).items(): |
| 215 | + sharding = logical_state_sharding[path].value |
| 216 | + state[path].value = device_put_replicated(val, sharding) |
| 217 | + state = nnx.from_flat_state(state) |
| 218 | + |
| 219 | + wan_vae = nnx.merge(graphdef, state) |
| 220 | + vae_cache = AutoencoderKLWanCache(wan_vae) |
207 | 221 | return wan_vae, vae_cache |
208 | 222 | |
209 | 223 | @classmethod |
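The load_vae rewrite above applies the same abstract-init / sharded-load recipe already used for the transformer: build the module with nnx.eval_shape so no weights are allocated, resolve a NamedSharding per parameter leaf from the logical axis rules, then place each pretrained leaf directly with its target sharding. A self-contained sketch of that recipe on a toy module (the module, axis names and shapes are made up for illustration, and jax.device_put stands in for max_utils.device_put_replicated):

    import jax
    import numpy as np
    from flax import linen as nn
    from flax import nnx
    from flax import traverse_util
    from jax.sharding import Mesh

    mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), ("data", "fsdp"))
    logical_axis_rules = (("embed", "fsdp"), ("mlp", None))

    class Toy(nnx.Module):
      def __init__(self, rngs: nnx.Rngs):
        self.w = nnx.Param(
            nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp"))(
                rngs.params(), (1024, 256)))

    # 1. eval_shape: create the model abstractly, no weights are allocated on device.
    toy = nnx.eval_shape(lambda rngs: Toy(rngs), rngs=nnx.Rngs(0))
    graphdef, state = nnx.split(toy, nnx.Param)

    # 2. Map logical axis names to mesh axes to get a NamedSharding per leaf.
    logical_state_spec = nnx.get_partition_spec(state)
    logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, logical_axis_rules)
    logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
    state = dict(nnx.to_flat_state(state))

    # 3. Load "pretrained" weights on host (random stand-ins here) and place each leaf
    #    directly with its target sharding, so no device ever holds the full checkpoint.
    params = {"w": np.random.normal(size=(1024, 256)).astype(np.float32)}
    for path, val in traverse_util.flatten_dict(params).items():
      state[path].value = jax.device_put(val, logical_state_sharding[path].value)

    toy = nnx.merge(graphdef, nnx.from_flat_state(state))
    print(toy.w.value.sharding)  # a NamedSharding over the mesh, not a single-device placement

The per-leaf placement loop is what lets the checkpoint stream from host memory straight into its final sharded layout, which is the HBM saving the step-3 comment in the diff refers to.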