@@ -289,22 +289,22 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
             device = samples.device

             prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
-            padded_samples = self.accelerator.pad_across_processes(
-                samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
-            )
-            padded_prompts = self.accelerator.pad_across_processes(
-                prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
-            )
-
             metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}
+
             if self.config.train.reward_only_in_main_process:
+                padded_samples = self.accelerator.pad_across_processes(
+                    samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
+                )
+                padded_prompts = self.accelerator.pad_across_processes(
+                    prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
+                )
                 gathered_samples = self.accelerator.gather(padded_samples)
                 gathered_prompts = self.accelerator.gather(padded_prompts)
                 gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
                 metadata = gather_dict(metadata)
             else:
-                gathered_samples = padded_samples
-                gathered_prompts = padded_prompts
+                gathered_samples = samples
+                gathered_prompts = prompt_tensors
                 gathered_prompt_sizes = prompt_sizes

             if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
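For context on why the padding now lives inside the `reward_only_in_main_process` branch: `pad_across_processes` only exists to equalize sequence lengths across ranks before `gather` can stack them, so when every process scores its own rewards locally the padding is wasted work. Below is a minimal, self-contained sketch of that pad-then-gather pattern with Hugging Face Accelerate; the toy tensor shapes and `pad_index=0` (standing in for the tokenizer's `eos_token_id`) are illustrative assumptions, not the trainer's actual state.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Toy stand-in for `samples`: each process generates sequences of a
# different length, so a plain gather() would fail on a shape mismatch.
seq_len = 8 + accelerator.process_index
samples = torch.randint(0, 100, (4, seq_len), device=accelerator.device)

# Right-pad every process's tensor to the global max length, then gather.
# pad_index=0 stands in for self.tokenizer.eos_token_id in the real code.
padded = accelerator.pad_across_processes(samples, dim=1, pad_index=0, pad_first=False)
gathered = accelerator.gather(padded)  # (4 * num_processes, max_seq_len)

if accelerator.is_main_process:
    print(gathered.shape)
```

Run with `accelerate launch`: each rank contributes a differently sized batch, and after padding the main process sees the stacked result, mirroring what the `reward_only_in_main_process` path needs.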