Skip to content

Commit 6d3aef3

Browse files
committed
Lint fix
1 parent 9518a70 commit 6d3aef3

2 files changed

Lines changed: 15 additions & 11 deletions

File tree

algorithmic_efficiency/checkpoint_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def maybe_restore_checkpoint(framework: str,
119119

120120
else:
121121
checkpoint_state = latest_ckpt
122-
if isinstance(model_params, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
122+
if isinstance(
123+
model_params,
124+
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
123125
model_params = model_params.module
124126
model_params.load_state_dict(checkpoint_state['model_params'])
125127
checkpoint_state['model_params'] = model_params
@@ -196,7 +198,9 @@ def save_checkpoint(framework: str,
196198
opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
197199
model_state = jax.device_get(jax_utils.unreplicate(model_state))
198200
else:
199-
if isinstance(model_params, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
201+
if isinstance(
202+
model_params,
203+
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
200204
model_params = model_params.module
201205
model_params = model_params.state_dict()
202206
optimizer_state_dict = {}

submission_runner.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ def train_once(
319319
metrics_logger = None
320320
if RANK == 0:
321321
metrics_logger = logger_utils.set_up_loggers(log_dir,
322-
flags.FLAGS,
323-
hyperparameters)
322+
flags.FLAGS,
323+
hyperparameters)
324324
workload.attach_metrics_logger(metrics_logger)
325325

326326
global_start_time = get_time()
@@ -470,13 +470,13 @@ def train_once(
470470
metrics = {'eval_results': eval_results, 'global_step': global_step}
471471

472472
if log_dir is not None and RANK == 0:
473-
metrics_logger.append_scalar_metrics(
474-
{'score': train_state['accumulated_submission_time']},
475-
global_step=global_step,
476-
preemption_count=preemption_count)
477-
metrics_logger.finish()
478-
if save_checkpoints:
479-
checkpoint_utils.save_checkpoint(
473+
metrics_logger.append_scalar_metrics(
474+
{'score': train_state['accumulated_submission_time']},
475+
global_step=global_step,
476+
preemption_count=preemption_count)
477+
metrics_logger.finish()
478+
if save_checkpoints:
479+
checkpoint_utils.save_checkpoint(
480480
framework=FLAGS.framework,
481481
optimizer_state=optimizer_state,
482482
model_params=model_params,

0 commit comments

Comments (0)