@@ -157,9 +157,6 @@ def training(self):
157157 loss = self .get_loss_value (data , outputs , labels )
158158 loss .backward ()
159159 self .optimizer .step ()
160- if (self .scheduler is not None and \
161- not isinstance (self .scheduler , lr_scheduler .ReduceLROnPlateau )):
162- self .scheduler .step ()
163160
164161 # statistics
165162 sample_num += labels .size (0 )
@@ -183,7 +180,7 @@ def validation(self):
183180 inputs = self .convert_tensor_type (data ['image' ])
184181 labels = self .convert_tensor_type (data ['label_prob' ])
185182 inputs , labels = inputs .to (self .device ), labels .to (self .device )
186- self .optimizer .zero_grad ()
183+ # self.optimizer.zero_grad()
187184 # forward + backward + optimize
188185 outputs = self .net (inputs )
189186 loss = self .get_loss_value (data , outputs , labels )
@@ -196,20 +193,17 @@ def validation(self):
196193 avg_loss = running_loss / sample_num
197194 avg_score = running_score .double () / sample_num
198195 metrics = self .config ['training' ].get ("evaluation_metric" , "accuracy" )
199- if (isinstance (self .scheduler , lr_scheduler .ReduceLROnPlateau )):
200- self .scheduler .step (avg_score )
201196 valid_scalers = {'loss' : avg_loss , metrics : avg_score }
202197 return valid_scalers
203198
204199 def write_scalars (self , train_scalars , valid_scalars , lr_value , glob_it ):
205- metrics = self .config ['training' ].get ("evaluation_metric" , "accuracy" )
200+ metrics = self .config ['training' ].get ("evaluation_metric" , "accuracy" )
206201 loss_scalar = {'train' :train_scalars ['loss' ], 'valid' :valid_scalars ['loss' ]}
207202 acc_scalar = {'train' :train_scalars [metrics ],'valid' :valid_scalars [metrics ]}
208203 self .summ_writer .add_scalars ('loss' , loss_scalar , glob_it )
209204 self .summ_writer .add_scalars (metrics , acc_scalar , glob_it )
210205 self .summ_writer .add_scalars ('lr' , {"lr" : lr_value }, glob_it )
211206
212- logging .info ("{0:} it {1:}" .format (str (datetime .now ())[:- 7 ], glob_it ))
213207 logging .info ('train loss {0:.4f}, avg {1:} {2:.4f}' .format (
214208 train_scalars ['loss' ], metrics , train_scalars [metrics ]))
215209 logging .info ('valid loss {0:.4f}, avg {1:} {2:.4f}' .format (
@@ -251,7 +245,10 @@ def train_valid(self):
251245 checkpoint_file = "{0:}/{1:}_{2:}.pt" .format (ckpt_dir , ckpt_prefix , iter_start )
252246 self .checkpoint = torch .load (checkpoint_file , map_location = self .device )
253247 assert (self .checkpoint ['iteration' ] == iter_start )
254- self .net .load_state_dict (self .checkpoint ['model_state_dict' ])
248+ if (len (device_ids ) > 1 ):
249+ self .net .module .load_state_dict (self .checkpoint ['model_state_dict' ])
250+ else :
251+ self .net .load_state_dict (self .checkpoint ['model_state_dict' ])
255252 self .max_val_score = self .checkpoint .get ('valid_pred' , 0 )
256253 self .max_val_it = self .checkpoint ['iteration' ]
257254 self .best_model_wts = self .checkpoint ['model_state_dict' ]
@@ -266,15 +263,28 @@ def train_valid(self):
266263 self .glob_it = iter_start
267264 for it in range (iter_start , iter_max , iter_valid ):
268265 lr_value = self .optimizer .param_groups [0 ]['lr' ]
266+ t0 = time .time ()
269267 train_scalars = self .training ()
268+ t1 = time .time ()
270269 valid_scalars = self .validation ()
270+ t2 = time .time ()
271+ if (isinstance (self .scheduler , lr_scheduler .ReduceLROnPlateau )):
272+ self .scheduler .step (valid_scalars [metrics ])
273+ else :
274+ self .scheduler .step ()
275+
271276 self .glob_it = it + iter_valid
277+ logging .info ("\n {0:} it {1:}" .format (str (datetime .now ())[:- 7 ], self .glob_it ))
278+ logging .info ('learning rate {0:}' .format (lr_value ))
279+ logging .info ("training/validation time: {0:.2f}s/{1:.2f}s" .format (t1 - t0 , t2 - t1 ))
272280 self .write_scalars (train_scalars , valid_scalars , lr_value , self .glob_it )
273-
274281 if (valid_scalars [metrics ] > self .max_val_score ):
275282 self .max_val_score = valid_scalars [metrics ]
276283 self .max_val_it = self .glob_it
277- self .best_model_wts = copy .deepcopy (self .net .state_dict ())
284+ if (len (device_ids ) > 1 ):
285+ self .best_model_wts = copy .deepcopy (self .net .module .state_dict ())
286+ else :
287+ self .best_model_wts = copy .deepcopy (self .net .state_dict ())
278288
279289 stop_now = True if (early_stop_it is not None and \
280290 self .glob_it - self .max_val_it > early_stop_it ) else False
@@ -306,7 +316,6 @@ def train_valid(self):
306316 self .max_val_it , metrics , self .max_val_score ))
307317 self .summ_writer .close ()
308318
309-
310319 def infer (self ):
311320 device_ids = self .config ['testing' ]['gpus' ]
312321 device = torch .device ("cuda:{0:}" .format (device_ids [0 ]))
0 commit comments