Skip to content

Commit 49e1c64

Browse files
committed
update main files
lr_scheduler now updates after each validation for the classification task; update main files to fix a logging bug; update __init__ files; fix a bug for alpha in mean teacher
1 parent 99145ff commit 49e1c64

14 files changed

Lines changed: 347 additions & 38 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ BibTeX entry:
1515
author = {Guotai Wang and Xiangde Luo and Ran Gu and Shuojue Yang and Yijie Qu and Shuwei Zhai and Qianfei Zhao and Kang Li and Shaoting Zhang},
1616
title = {{PyMIC: A deep learning toolkit for annotation-efficient medical image segmentation}},
1717
year = {2023},
18-
url = {http://arxiv.org/abs/2208.09350},
18+
url = {https://doi.org/10.1016/j.cmpb.2023.107398},
1919
journal = {Computer Methods and Programs in Biomedicine},
20-
volume = {February},
20+
volume = {231},
2121
pages = {107398},
2222
}
2323

pymic/net_run/agent_cls.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,6 @@ def training(self):
157157
loss = self.get_loss_value(data, outputs, labels)
158158
loss.backward()
159159
self.optimizer.step()
160-
if(self.scheduler is not None and \
161-
not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
162-
self.scheduler.step()
163160

164161
# statistics
165162
sample_num += labels.size(0)
@@ -183,7 +180,7 @@ def validation(self):
183180
inputs = self.convert_tensor_type(data['image'])
184181
labels = self.convert_tensor_type(data['label_prob'])
185182
inputs, labels = inputs.to(self.device), labels.to(self.device)
186-
self.optimizer.zero_grad()
183+
# self.optimizer.zero_grad()
187184
# forward + backward + optimize
188185
outputs = self.net(inputs)
189186
loss = self.get_loss_value(data, outputs, labels)
@@ -196,20 +193,17 @@ def validation(self):
196193
avg_loss = running_loss / sample_num
197194
avg_score= running_score.double() / sample_num
198195
metrics = self.config['training'].get("evaluation_metric", "accuracy")
199-
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
200-
self.scheduler.step(avg_score)
201196
valid_scalers = {'loss': avg_loss, metrics: avg_score}
202197
return valid_scalers
203198

204199
def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
205-
metrics =self.config['training'].get("evaluation_metric", "accuracy")
200+
metrics = self.config['training'].get("evaluation_metric", "accuracy")
206201
loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']}
207202
acc_scalar ={'train':train_scalars[metrics],'valid':valid_scalars[metrics]}
208203
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
209204
self.summ_writer.add_scalars(metrics, acc_scalar, glob_it)
210205
self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
211206

212-
logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it))
213207
logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format(
214208
train_scalars['loss'], metrics, train_scalars[metrics]))
215209
logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format(
@@ -251,7 +245,10 @@ def train_valid(self):
251245
checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
252246
self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
253247
assert(self.checkpoint['iteration'] == iter_start)
254-
self.net.load_state_dict(self.checkpoint['model_state_dict'])
248+
if(len(device_ids) > 1):
249+
self.net.module.load_state_dict(self.checkpoint['model_state_dict'])
250+
else:
251+
self.net.load_state_dict(self.checkpoint['model_state_dict'])
255252
self.max_val_score = self.checkpoint.get('valid_pred', 0)
256253
self.max_val_it = self.checkpoint['iteration']
257254
self.best_model_wts = self.checkpoint['model_state_dict']
@@ -266,15 +263,28 @@ def train_valid(self):
266263
self.glob_it = iter_start
267264
for it in range(iter_start, iter_max, iter_valid):
268265
lr_value = self.optimizer.param_groups[0]['lr']
266+
t0 = time.time()
269267
train_scalars = self.training()
268+
t1 = time.time()
270269
valid_scalars = self.validation()
270+
t2 = time.time()
271+
if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
272+
self.scheduler.step(valid_scalars[metrics])
273+
else:
274+
self.scheduler.step()
275+
271276
self.glob_it = it + iter_valid
277+
logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it))
278+
logging.info('learning rate {0:}'.format(lr_value))
279+
logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1))
272280
self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it)
273-
274281
if(valid_scalars[metrics] > self.max_val_score):
275282
self.max_val_score = valid_scalars[metrics]
276283
self.max_val_it = self.glob_it
277-
self.best_model_wts = copy.deepcopy(self.net.state_dict())
284+
if(len(device_ids) > 1):
285+
self.best_model_wts = copy.deepcopy(self.net.module.state_dict())
286+
else:
287+
self.best_model_wts = copy.deepcopy(self.net.state_dict())
278288

279289
stop_now = True if(early_stop_it is not None and \
280290
self.glob_it - self.max_val_it > early_stop_it) else False
@@ -306,7 +316,6 @@ def train_valid(self):
306316
self.max_val_it, metrics, self.max_val_score))
307317
self.summ_writer.close()
308318

309-
310319
def infer(self):
311320
device_ids = self.config['testing']['gpus']
312321
device = torch.device("cuda:{0:}".format(device_ids[0]))

pymic/net_run/agent_seg.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,23 @@ def create_loss_calculator(self):
9696
raise ValueError("Undefined loss function {0:}".format(loss_name))
9797
else:
9898
base_loss = self.loss_dict[loss_name](self.config['training'])
99-
if(self.config['network'].get('deep_supervise', False)):
100-
weight = self.config['network'].get('deep_supervise_weight', None)
101-
params = {'deep_supervise_weight': weight, 'base_loss':base_loss}
99+
if(self.config['training'].get('deep_supervise', False)):
100+
weight = self.config['training'].get('deep_supervise_weight', None)
101+
mode = self.config['training'].get('deep_supervise_mode', 2)
102+
params = {'deep_supervise_weight': weight,
103+
'deep_supervise_mode': mode,
104+
'base_loss':base_loss}
102105
self.loss_calculator = DeepSuperviseLoss(params)
103106
else:
104107
self.loss_calculator = base_loss
105108

106109
def get_loss_value(self, data, pred, gt, param = None):
107110
loss_input_dict = {'prediction':pred, 'ground_truth': gt}
108111
if data.get('pixel_weight', None) is not None:
109-
loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device)
112+
if(isinstance(pred, tuple) or isinstance(pred, list)):
113+
loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred[0].device)
114+
else:
115+
loss_input_dict['pixel_weight'] = data['pixel_weight'].to(pred.device)
110116
loss_value = self.loss_calculator(loss_input_dict)
111117
return loss_value
112118

@@ -122,7 +128,7 @@ def set_postprocessor(self, postprocessor):
122128
def training(self):
123129
class_num = self.config['network']['class_num']
124130
iter_valid = self.config['training']['iter_valid']
125-
mixup_prob = self.config['training'].get('mixup_probability', 0.5)
131+
mixup_prob = self.config['training'].get('mixup_probability', 0.0)
126132
train_loss = 0
127133
train_dice_list = []
128134
self.net.train()
@@ -135,7 +141,7 @@ def training(self):
135141
# get the inputs
136142
inputs = self.convert_tensor_type(data['image'])
137143
labels_prob = self.convert_tensor_type(data['label_prob'])
138-
if(random() < mixup_prob):
144+
if(mixup_prob > 0 and random() < mixup_prob):
139145
inputs, labels_prob = mixup(inputs, labels_prob)
140146

141147
# # for debug
@@ -246,7 +252,10 @@ def train_valid(self):
246252
else:
247253
self.device = torch.device("cuda:{0:}".format(device_ids[0]))
248254
self.net.to(self.device)
255+
249256
ckpt_dir = self.config['training']['ckpt_save_dir']
257+
if(ckpt_dir[-1] == "/"):
258+
ckpt_dir = ckpt_dir[:-1]
250259
ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
251260
if(ckpt_prefix is None):
252261
ckpt_prefix = ckpt_dir.split('/')[-1]

pymic/net_run/get_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def get_lr_scheduler(optimizer, sched_params):
5959
scheduler = lr_scheduler.MultiStepLR(optimizer,
6060
lr_milestones, lr_gamma, epoch_last)
6161
elif(keyword_match(name, "StepLR")):
62-
lr_step = sched_params["lr_step"] / val_it
62+
lr_step = sched_params["lr_step"] / val_it
6363
lr_gamma = sched_params["lr_gamma"]
6464
scheduler = lr_scheduler.StepLR(optimizer,
6565
lr_step, lr_gamma, epoch_last)

pymic/net_run_ssl/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,8 @@
11
from __future__ import absolute_import
2-
from . import *
2+
from pymic.net_run_ssl.ssl_abstract import *
3+
from pymic.net_run_ssl.ssl_cct import *
4+
from pymic.net_run_ssl.ssl_cps import *
5+
from pymic.net_run_ssl.ssl_em import *
6+
from pymic.net_run_ssl.ssl_mt import *
7+
from pymic.net_run_ssl.ssl_uamt import *
8+
from pymic.net_run_ssl.ssl_urpc import *

pymic/net_run_ssl/ssl_main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ def main():
3535
log_dir = config['training']['ckpt_save_dir']
3636
if(not os.path.exists(log_dir)):
3737
os.mkdir(log_dir)
38-
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
39-
format='%(message)s')
38+
if sys.version.startswith("3.9"):
39+
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
40+
format='%(message)s', force=True) # for python 3.9
41+
else:
42+
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
43+
format='%(message)s') # for python 3.6
4044
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
4145
logging_config(config)
4246
ssl_method = config['semi_supervised_learning']['ssl_method']

pymic/net_run_ssl/ssl_mt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def training(self):
104104

105105
# update EMA
106106
alpha = ssl_cfg.get('ema_decay', 0.99)
107-
alpha = min(1 - 1 / (iter_max + 1), alpha)
107+
alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha)
108108
for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()):
109109
ema_param.data.mul_(alpha).add_(1 - alpha, param.data)
110110

pymic/net_run_ssl/ssl_uamt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def training(self):
106106

107107
# update EMA
108108
alpha = ssl_cfg.get('ema_decay', 0.99)
109-
alpha = min(1 - 1 / (iter_max + 1), alpha)
109+
alpha = min(1 - 1 / (self.glob_it / iter_valid + 1), alpha)
110110
for ema_param, param in zip(self.net_ema.parameters(), self.net.parameters()):
111111
ema_param.data.mul_(alpha).add_(1 - alpha, param.data)
112112

pymic/net_run_wsl/wsl_main.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@ def main():
3434
log_dir = config['training']['ckpt_save_dir']
3535
if(not os.path.exists(log_dir)):
3636
os.mkdir(log_dir)
37-
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
38-
format='%(message)s')
37+
if sys.version.startswith("3.9"):
38+
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
39+
format='%(message)s', force=True) # for python 3.9
40+
else:
41+
logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
42+
format='%(message)s') # for python 3.6
3943
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
4044
logging_config(config)
4145
wsl_method = config['weakly_supervised_learning']['wsl_method']

pymic/transform/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,13 @@
1+
# -*- coding: utf-8 -*-
12
from __future__ import absolute_import
2-
from . import *
3+
from pymic.transform.intensity import *
4+
from pymic.transform.flip import *
5+
from pymic.transform.pad import *
6+
from pymic.transform.rotate import *
7+
from pymic.transform.rescale import *
8+
from pymic.transform.transpose import *
9+
from pymic.transform.threshold import *
10+
from pymic.transform.normalize import *
11+
from pymic.transform.crop import *
12+
from pymic.transform.label_convert import *
13+
from pymic.transform.trans_dict import TransformDict

0 commit comments

Comments (0)