1+ import os
12import json
23import torch
34import jax
@@ -136,15 +137,22 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict,
136137 else :
137138 return load_base_wan_transformer (pretrained_model_name_or_path , eval_shapes , device , hf_download )
138139
139-
140140def load_base_wan_transformer (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
141141 device = jax .devices (device )[0 ]
142- with jax .default_device (device ):
143- if hf_download :
144- # download the index file for sharded models.
145- index_file_path = hf_hub_download (
146- pretrained_model_name_or_path , subfolder = "transformer" , filename = "diffusion_pytorch_model.safetensors.index.json"
147- )
142+ subfolder = "transformer"
143+ filename = "diffusion_pytorch_model.safetensors.index.json"
144+ local_files = False
145+ if os .path .isdir (pretrained_model_name_or_path ):
146+ index_file_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
147+ if not os .path .isfile (index_file_path ):
148+ raise FileNotFoundError (f"File { index_file_path } not found for local directory." )
149+ local_files = True
150+ elif hf_download :
151+ # download the index file for sharded models.
152+ index_file_path = hf_hub_download (
153+ pretrained_model_name_or_path , subfolder , filename ,
154+ )
155+ with jax .default_device (device ):
148156 # open the index file.
149157 with open (index_file_path , "r" ) as f :
150158 index_dict = json .load (f )
@@ -155,7 +163,10 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
155163 model_files = list (model_files )
156164 tensors = {}
157165 for model_file in model_files :
158- ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = "transformer" , filename = model_file )
166+ if local_files :
167+ ckpt_shard_path = os .path .join (pretrained_model_name_or_path , subfolder , model_file )
168+ else :
169+ ckpt_shard_path = hf_hub_download (pretrained_model_name_or_path , subfolder = "transformer" , filename = model_file )
159170 # now get all the filenames for the model that need downloading
160171 max_logging .log (f"Load and port Wan 2.1 transformer on { device } " )
161172
@@ -195,13 +206,18 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
195206
196207def load_wan_vae (pretrained_model_name_or_path : str , eval_shapes : dict , device : str , hf_download : bool = True ):
197208 device = jax .devices (device )[0 ]
209+ subfolder = "vae"
210+ filename = "diffusion_pytorch_model.safetensors"
211+ if os .path .isdir (pretrained_model_name_or_path ):
212+ ckpt_path = os .path .join (pretrained_model_name_or_path , subfolder , filename )
213+ if not os .path .isfile (ckpt_path ):
214+ raise FileNotFoundError (f"File { ckpt_path } not found for local directory." )
215+ elif hf_download :
216+ ckpt_path = hf_hub_download (
217+ pretrained_model_name_or_path , subfolder , filename
218+ )
219+ max_logging .log (f"Load and port Wan 2.1 VAE on { device } " )
198220 with jax .default_device (device ):
199- if hf_download :
200- ckpt_path = hf_hub_download (
201- pretrained_model_name_or_path , subfolder = "vae" , filename = "diffusion_pytorch_model.safetensors"
202- )
203- max_logging .log (f"Load and port Wan 2.1 VAE on { device } " )
204-
205221 if ckpt_path is not None :
206222 tensors = {}
207223 with safe_open (ckpt_path , framework = "pt" ) as f :
0 commit comments