@@ -316,10 +316,12 @@ def train_once(
316316 flag_file_name = os .path .join (log_dir , f'flags_{ preemption_count } .json' )
317317 logging .info (f'Saving flags to { flag_file_name } .' )
318318 logger_utils .write_json (flag_file_name , flags .FLAGS .flag_values_dict ())
319- metrics_logger = logger_utils .set_up_loggers (log_dir ,
320- flags .FLAGS ,
321- hyperparameters )
322- workload .attach_metrics_logger (metrics_logger )
319+ metrics_logger = None
320+ if RANK == 0 :
321+ metrics_logger = logger_utils .set_up_loggers (log_dir ,
322+ flags .FLAGS ,
323+ hyperparameters )
324+ workload .attach_metrics_logger (metrics_logger )
323325
324326 global_start_time = get_time ()
325327 train_state ['last_step_end_time' ] = global_start_time
@@ -429,7 +431,7 @@ def train_once(
429431
430432 logging_start_time = get_time ()
431433
432- if log_dir is not None :
434+ if log_dir is not None and RANK == 0 :
433435 metrics_logger .append_scalar_metrics (
434436 latest_eval_result ,
435437 global_step = global_step ,
@@ -467,7 +469,7 @@ def train_once(
467469
468470 metrics = {'eval_results' : eval_results , 'global_step' : global_step }
469471
470- if log_dir is not None :
472+ if log_dir is not None and RANK == 0 :
471473 metrics_logger .append_scalar_metrics (
472474 {'score' : train_state ['accumulated_submission_time' ]},
473475 global_step = global_step ,
0 commit comments