Skip to content

Commit 3d83278

Browse files
jouwhwchen2017raviguptaamd
authored
fix init weights issue for critic/reward model (#983)
* Add file extension (#980) Signed-off-by: Hongwei Chen <hongweichen@microsoft.com> Signed-off-by: jouw <jouw@foxmail.com> * fix init weights issue for critic/reward model Signed-off-by: jouw <jouw@foxmail.com> * Update submodule link to reflect https style (#981) Signed-off-by: raviguptaamd <ravi.gupta@amd.com> Signed-off-by: jouw <jouw@foxmail.com> * fix formatting issue Signed-off-by: jouw <jouw@foxmail.com> --------- Signed-off-by: Hongwei Chen <hongweichen@microsoft.com> Signed-off-by: jouw <jouw@foxmail.com> Signed-off-by: raviguptaamd <ravi.gupta@amd.com> Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com> Co-authored-by: raviguptaamd <ravi.gupta@amd.com>
1 parent 4579df3 commit 3d83278

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

applications/DeepSpeed-Chat/dschat/utils/model/model_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from huggingface_hub import snapshot_download
1313
from transformers.integrations.deepspeed import HfDeepSpeedConfig
14+
from transformers.modeling_utils import no_init_weights
1415

1516
from dschat.utils.model.reward_model import RewardModel
1617
from dschat.utils.utils import load_state_dict_into_model, print_rank_0
@@ -99,7 +100,8 @@ def create_hf_model(model_class,
99100
dschf = None
100101
if rlhf_training:
101102
# the weight loading is handled by create critic model
102-
model = model_class.from_config(model_config)
103+
with no_init_weights():
104+
model = model_class.from_config(model_config)
103105
else:
104106
model = model_class.from_pretrained(
105107
model_name_or_path,

0 commit comments

Comments
 (0)