@@ -41,7 +41,9 @@ def get_optimizer(name, net_params, optim_params):
4141def get_lr_scheduler (optimizer , sched_params ):
4242 name = sched_params ["lr_scheduler" ]
4343 val_it = sched_params ["iter_valid" ]
44- epoch_last = sched_params ["last_iter" ] / val_it
44+ epoch_last = sched_params ["last_iter" ]
45+ if (epoch_last > 0 ):
46+ epoch_last = int (epoch_last / val_it )
4547 if (name is None ):
4648 return None
4749 if (keyword_match (name , "ReduceLROnPlateau" )):
@@ -56,6 +58,11 @@ def get_lr_scheduler(optimizer, sched_params):
5658 lr_gamma = sched_params ["lr_gamma" ]
5759 scheduler = lr_scheduler .MultiStepLR (optimizer ,
5860 lr_milestones , lr_gamma , epoch_last )
61+ elif (keyword_match (name , "StepLR" )):
62+ lr_step = sched_params ["lr_step" ] / val_it
63+ lr_gamma = sched_params ["lr_gamma" ]
64+ scheduler = lr_scheduler .StepLR (optimizer ,
65+ lr_step , lr_gamma , epoch_last )
5966 elif (keyword_match (name , "CosineAnnealingLR" )):
6067 epoch_max = sched_params ["iter_max" ] / val_it
6168 lr_min = sched_params .get ("lr_min" , 0 )
0 commit comments