Skip to content

Commit 8c15798

Browse files
committed
update ssl and wsl files
update ssl and wsl, print average foreground dice for each epoch
1 parent 49e1c64 commit 8c15798

18 files changed

Lines changed: 82 additions & 56 deletions

pymic/net/net2d/unet2d.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Decoder(nn.Module):
131131
:param class_num: (int) The class number for segmentation task.
132132
:param bilinear: (bool) Using bilinear for up-sampling or not.
133133
If False, deconvolution will be used for up-sampling.
134+
:param multiscale_pred: (bool) Get multi-scale prediction.
134135
"""
135136
def __init__(self, params):
136137
super(Decoder, self).__init__()
@@ -139,7 +140,8 @@ def __init__(self, params):
139140
self.ft_chns = self.params['feature_chns']
140141
self.dropout = self.params['dropout']
141142
self.n_class = self.params['class_num']
142-
self.bilinear = self.params['bilinear']
143+
self.bilinear = self.params.get('bilinear', True)
144+
self.mul_pred = self.params.get('multiscale_pred', False)
143145

144146
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
145147

@@ -149,6 +151,10 @@ def __init__(self, params):
149151
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
150152
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
151153
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
154+
if(self.mul_pred):
155+
self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1)
156+
self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1)
157+
self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1)
152158

153159
def forward(self, x):
154160
if(len(self.ft_chns) == 5):
@@ -163,6 +169,11 @@ def forward(self, x):
163169
x_d1 = self.up3(x_d2, x1)
164170
x_d0 = self.up4(x_d1, x0)
165171
output = self.out_conv(x_d0)
172+
if(self.mul_pred):
173+
output1 = self.out_conv1(x_d1)
174+
output2 = self.out_conv2(x_d2)
175+
output3 = self.out_conv3(x_d3)
176+
output = [output, output1, output2, output3]
166177
return output
167178

168179
class UNet2D(nn.Module):
@@ -187,7 +198,7 @@ class UNet2D(nn.Module):
187198
:param class_num: (int) The class number for segmentation task.
188199
:param bilinear: (bool) Using bilinear for up-sampling or not.
189200
If False, deconvolution will be used for up-sampling.
190-
:param deep_supervise: (bool) Using deep supervision for training or not.
201+
:param multiscale_pred: (bool) Get multiscale prediction.
191202
"""
192203
def __init__(self, params):
193204
super(UNet2D, self).__init__()
@@ -197,7 +208,7 @@ def __init__(self, params):
197208
self.dropout = self.params['dropout']
198209
self.n_class = self.params['class_num']
199210
self.bilinear = self.params['bilinear']
200-
self.deep_sup = self.params['deep_supervise']
211+
self.mul_pred = self.params['multiscale_pred']
201212

202213
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
203214

@@ -213,7 +224,7 @@ def __init__(self, params):
213224
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
214225

215226
self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
216-
if(self.deep_sup):
227+
if(self.mul_pred):
217228
self.out_conv1 = nn.Conv2d(self.ft_chns[1], self.n_class, kernel_size = 1)
218229
self.out_conv2 = nn.Conv2d(self.ft_chns[2], self.n_class, kernel_size = 1)
219230
self.out_conv3 = nn.Conv2d(self.ft_chns[3], self.n_class, kernel_size = 1)
@@ -239,7 +250,7 @@ def forward(self, x):
239250
x_d1 = self.up3(x_d2, x1)
240251
x_d0 = self.up4(x_d1, x0)
241252
output = self.out_conv(x_d0)
242-
if(self.deep_sup):
253+
if(self.mul_pred):
243254
output1 = self.out_conv1(x_d1)
244255
output2 = self.out_conv2(x_d2)
245256
output3 = self.out_conv3(x_d3)
@@ -261,7 +272,8 @@ def forward(self, x):
261272
'feature_chns':[2, 8, 32, 48, 64],
262273
'dropout': [0, 0, 0.3, 0.4, 0.5],
263274
'class_num': 2,
264-
'bilinear': True}
275+
'bilinear': True,
276+
'multiscale_pred': False}
265277
Net = UNet2D(params)
266278
Net = Net.double()
267279

pymic/net/net3d/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from __future__ import absolute_import
2+
from . import *

pymic/net/net3d/unet3d.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class Decoder(nn.Module):
131131
:param class_num: (int) The class number for segmentation task.
132132
:param trilinear: (bool) Using bilinear for up-sampling or not.
133133
If False, deconvolution will be used for up-sampling.
134+
:param multiscale_pred: (bool) Get multi-scale prediction.
134135
"""
135136
def __init__(self, params):
136137
super(Decoder, self).__init__()
@@ -139,16 +140,21 @@ def __init__(self, params):
139140
self.ft_chns = self.params['feature_chns']
140141
self.dropout = self.params['dropout']
141142
self.n_class = self.params['class_num']
142-
self.trilinear = self.params['trilinear']
143+
self.trilinear = self.params.get('trilinear', True)
144+
self.mul_pred = self.params.get('multiscale_pred', False)
143145

144146
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
145147

146148
if(len(self.ft_chns) == 5):
147-
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
148-
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
149-
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
150-
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
149+
self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.trilinear)
150+
self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.trilinear)
151+
self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.trilinear)
152+
self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.trilinear)
151153
self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
154+
if(self.mul_pred):
155+
self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1)
156+
self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1)
157+
self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1)
152158

153159
def forward(self, x):
154160
if(len(self.ft_chns) == 5):
@@ -163,6 +169,11 @@ def forward(self, x):
163169
x_d1 = self.up3(x_d2, x1)
164170
x_d0 = self.up4(x_d1, x0)
165171
output = self.out_conv(x_d0)
172+
if(self.mul_pred):
173+
output1 = self.out_conv1(x_d1)
174+
output2 = self.out_conv2(x_d2)
175+
output3 = self.out_conv3(x_d3)
176+
output = [output, output1, output2, output3]
166177
return output
167178

168179
class UNet3D(nn.Module):
@@ -187,7 +198,7 @@ class UNet3D(nn.Module):
187198
:param class_num: (int) The class number for segmentation task.
188199
:param trilinear: (bool) Using trilinear for up-sampling or not.
189200
If False, deconvolution will be used for up-sampling.
190-
:param deep_supervise: (bool) Using deep supervision for training or not.
201+
:param multiscale_pred: (bool) Get multi-scale prediction.
191202
"""
192203
def __init__(self, params):
193204
super(UNet3D, self).__init__()
@@ -197,7 +208,7 @@ def __init__(self, params):
197208
self.dropout = self.params['dropout']
198209
self.n_class = self.params['class_num']
199210
self.trilinear = self.params['trilinear']
200-
self.deep_sup = self.params['deep_supervise']
211+
self.mul_pred = self.params['multiscale_pred']
201212
assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
202213

203214
self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
@@ -216,7 +227,7 @@ def __init__(self, params):
216227
dropout_p = self.dropout[0], trilinear=self.trilinear)
217228

218229
self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
219-
if(self.deep_sup):
230+
if(self.mul_pred):
220231
self.out_conv1 = nn.Conv3d(self.ft_chns[1], self.n_class, kernel_size = 1)
221232
self.out_conv2 = nn.Conv3d(self.ft_chns[2], self.n_class, kernel_size = 1)
222233
self.out_conv3 = nn.Conv3d(self.ft_chns[3], self.n_class, kernel_size = 1)
@@ -235,14 +246,10 @@ def forward(self, x):
235246
x_d1 = self.up3(x_d2, x1)
236247
x_d0 = self.up4(x_d1, x0)
237248
output = self.out_conv(x_d0)
238-
if(self.deep_sup):
239-
out_shape = list(output.shape)[2:]
249+
if(self.mul_pred):
240250
output1 = self.out_conv1(x_d1)
241-
output1 = interpolate(output1, out_shape, mode = 'trilinear')
242251
output2 = self.out_conv2(x_d2)
243-
output2 = interpolate(output2, out_shape, mode = 'trilinear')
244252
output3 = self.out_conv3(x_d3)
245-
output3 = interpolate(output3, out_shape, mode = 'trilinear')
246253
output = [output, output1, output2, output3]
247254
return output
248255

@@ -251,7 +258,8 @@ def forward(self, x):
251258
'class_num': 2,
252259
'feature_chns':[2, 8, 32, 64],
253260
'dropout' : [0, 0, 0, 0.5],
254-
'trilinear': True}
261+
'trilinear': True,
262+
'multiscale_pred': False}
255263
Net = UNet3D(params)
256264
Net = Net.double()
257265

pymic/net_run_ssl/ssl_abstract.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
8383
'valid':valid_scalars['loss']}
8484
loss_sup_scalar = {'train':train_scalars['loss_sup']}
8585
loss_upsup_scalar = {'train':train_scalars['loss_reg']}
86-
dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
86+
dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']}
8787
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
8888
self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
8989
self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
@@ -95,11 +95,11 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
9595
cls_dice_scalar = {'train':train_scalars['class_dice'][c], \
9696
'valid':valid_scalars['class_dice'][c]}
9797
self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
98-
logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
99-
train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
98+
logging.info('train loss {0:.4f}, avg foreground dice {1:.4f} '.format(
99+
train_scalars['loss'], train_scalars['avg_fg_dice']) + "[" + \
100100
' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
101-
logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
102-
valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
101+
logging.info('valid loss {0:.4f}, avg foreground dice {1:.4f} '.format(
102+
valid_scalars['loss'], valid_scalars['avg_fg_dice']) + "[" + \
103103
' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")
104104

105105
def train_valid(self):

pymic/net_run_ssl/ssl_cct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,9 @@ def training(self):
154154
train_avg_loss_sup = train_loss_sup / iter_valid
155155
train_avg_loss_reg = train_loss_reg / iter_valid
156156
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
157-
train_avg_dice = train_cls_dice.mean()
157+
train_avg_dice = train_cls_dice[1:].mean()
158158

159159
train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
160160
'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
161-
'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
161+
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
162162
return train_scalers

pymic/net_run_ssl/ssl_cps.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,12 @@ def training(self):
137137
train_avg_loss_pse_sup1 = train_loss_pseudo_sup1 / iter_valid
138138
train_avg_loss_pse_sup2 = train_loss_pseudo_sup2 / iter_valid
139139
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
140-
train_avg_dice = train_cls_dice.mean()
140+
train_avg_dice = train_cls_dice[1:].mean()
141141

142142
train_scalers = {'loss': train_avg_loss,
143143
'loss_sup1':train_avg_loss_sup1, 'loss_sup2': train_avg_loss_sup2,
144144
'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2,
145-
'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
145+
'regular_w':regular_w, 'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
146146
return train_scalers
147147

148148
def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
@@ -152,7 +152,7 @@ def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
152152
'net2':train_scalars['loss_sup2']}
153153
loss_pse_sup_scalar = {'net1':train_scalars['loss_pse_sup1'],
154154
'net2':train_scalars['loss_pse_sup2']}
155-
dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
155+
dice_scalar ={'train':train_scalars['avg_fg_dice'], 'valid':valid_scalars['avg_fg_dice']}
156156
self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
157157
self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
158158
self.summ_writer.add_scalars('loss_pseudo_sup', loss_pse_sup_scalar, glob_it)

pymic/net_run_ssl/ssl_em.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,9 @@ def training(self):
9898
train_avg_loss_sup = train_loss_sup / iter_valid
9999
train_avg_loss_reg = train_loss_reg / iter_valid
100100
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
101-
train_avg_dice = train_cls_dice.mean()
101+
train_avg_dice = train_cls_dice[1:].mean()
102102

103103
train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
104104
'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
105-
'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
105+
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
106106
return train_scalers

pymic/net_run_ssl/ssl_mt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def training(self):
123123
train_avg_loss_sup = train_loss_sup / iter_valid
124124
train_avg_loss_reg = train_loss_reg / iter_valid
125125
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
126-
train_avg_dice = train_cls_dice.mean()
126+
train_avg_dice = train_cls_dice[1:].mean()
127127

128128
train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
129129
'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
130-
'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
130+
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
131131
return train_scalers

pymic/net_run_ssl/ssl_uamt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ def training(self):
125125
train_avg_loss_sup = train_loss_sup / iter_valid
126126
train_avg_loss_reg = train_loss_reg / iter_valid
127127
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
128-
train_avg_dice = train_cls_dice.mean()
128+
train_avg_dice = train_cls_dice[1:].mean()
129129

130130
train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
131131
'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
132-
'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
132+
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
133133
return train_scalers

pymic/net_run_ssl/ssl_urpc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ def training(self):
111111
train_avg_loss_sup = train_loss_sup / iter_valid
112112
train_avg_loss_reg = train_loss_reg / iter_valid
113113
train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
114-
train_avg_dice = train_cls_dice.mean()
114+
train_avg_dice = train_cls_dice[1:].mean()
115115

116116
train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
117117
'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
118-
'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
118+
'avg_fg_dice':train_avg_dice, 'class_dice': train_cls_dice}
119119
return train_scalers

0 commit comments

Comments
 (0)