Skip to content

Commit c292e3d

Browse files
committed
support parallel reward function
1 parent 1a91891 commit c292e3d

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

trlx/trainer/accelerate_ppo_trainer.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -289,22 +289,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
289289
device = samples.device
290290

291291
prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
292-
padded_samples = self.accelerator.pad_across_processes(
293-
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
294-
)
295-
padded_prompts = self.accelerator.pad_across_processes(
296-
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
297-
)
298-
299292
metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}
293+
300294
if self.config.train.reward_only_in_main_process:
295+
padded_samples = self.accelerator.pad_across_processes(
296+
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
297+
)
298+
padded_prompts = self.accelerator.pad_across_processes(
299+
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
300+
)
301301
gathered_samples = self.accelerator.gather(padded_samples)
302302
gathered_prompts = self.accelerator.gather(padded_prompts)
303303
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
304304
metadata = gather_dict(metadata)
305305
else:
306-
gathered_samples = padded_samples
307-
gathered_prompts = padded_prompts
306+
gathered_samples = samples
307+
gathered_prompts = prompt_tensors
308308
gathered_prompt_sizes = prompt_sizes
309309

310310
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:

0 commit comments

Comments (0)