Skip to content

Commit 99145ff

Browse files
committed
update lr_scheduler
add StepLR
1 parent ab07d4f commit 99145ff

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

pymic/net_run/get_optimizer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def get_optimizer(name, net_params, optim_params):
4141
def get_lr_scheduler(optimizer, sched_params):
4242
name = sched_params["lr_scheduler"]
4343
val_it = sched_params["iter_valid"]
44-
epoch_last = sched_params["last_iter"] / val_it
44+
epoch_last = sched_params["last_iter"]
45+
if(epoch_last > 0):
46+
epoch_last = int(epoch_last / val_it)
4547
if(name is None):
4648
return None
4749
if(keyword_match(name, "ReduceLROnPlateau")):
@@ -56,6 +58,11 @@ def get_lr_scheduler(optimizer, sched_params):
5658
lr_gamma = sched_params["lr_gamma"]
5759
scheduler = lr_scheduler.MultiStepLR(optimizer,
5860
lr_milestones, lr_gamma, epoch_last)
61+
elif(keyword_match(name, "StepLR")):
62+
lr_step = sched_params["lr_step"] / val_it
63+
lr_gamma = sched_params["lr_gamma"]
64+
scheduler = lr_scheduler.StepLR(optimizer,
65+
lr_step, lr_gamma, epoch_last)
5966
elif(keyword_match(name, "CosineAnnealingLR")):
6067
epoch_max = sched_params["iter_max"] / val_it
6168
lr_min = sched_params.get("lr_min", 0)

0 commit comments

Comments
 (0)