@@ -131,6 +131,7 @@ class Decoder(nn.Module):
131131 :param class_num: (int) The class number for segmentation task.
132132 :param trilinear: (bool) Using bilinear for up-sampling or not.
133133 If False, deconvolution will be used for up-sampling.
134+ :param multiscale_pred: (bool) Get multi-scale prediction.
134135 """
135136 def __init__ (self , params ):
136137 super (Decoder , self ).__init__ ()
@@ -139,16 +140,21 @@ def __init__(self, params):
139140 self .ft_chns = self .params ['feature_chns' ]
140141 self .dropout = self .params ['dropout' ]
141142 self .n_class = self .params ['class_num' ]
142- self .trilinear = self .params ['trilinear' ]
143+ self .trilinear = self .params .get ('trilinear' , True )
144+ self .mul_pred = self .params .get ('multiscale_pred' , False )
143145
144146 assert (len (self .ft_chns ) == 5 or len (self .ft_chns ) == 4 )
145147
146148 if (len (self .ft_chns ) == 5 ):
147- self .up1 = UpBlock (self .ft_chns [4 ], self .ft_chns [3 ], self .ft_chns [3 ], self .dropout [3 ], self .bilinear )
148- self .up2 = UpBlock (self .ft_chns [3 ], self .ft_chns [2 ], self .ft_chns [2 ], self .dropout [2 ], self .bilinear )
149- self .up3 = UpBlock (self .ft_chns [2 ], self .ft_chns [1 ], self .ft_chns [1 ], self .dropout [1 ], self .bilinear )
150- self .up4 = UpBlock (self .ft_chns [1 ], self .ft_chns [0 ], self .ft_chns [0 ], self .dropout [0 ], self .bilinear )
149+ self .up1 = UpBlock (self .ft_chns [4 ], self .ft_chns [3 ], self .ft_chns [3 ], self .dropout [3 ], self .trilinear )
150+ self .up2 = UpBlock (self .ft_chns [3 ], self .ft_chns [2 ], self .ft_chns [2 ], self .dropout [2 ], self .trilinear )
151+ self .up3 = UpBlock (self .ft_chns [2 ], self .ft_chns [1 ], self .ft_chns [1 ], self .dropout [1 ], self .trilinear )
152+ self .up4 = UpBlock (self .ft_chns [1 ], self .ft_chns [0 ], self .ft_chns [0 ], self .dropout [0 ], self .trilinear )
151153 self .out_conv = nn .Conv3d (self .ft_chns [0 ], self .n_class , kernel_size = 1 )
154+ if (self .mul_pred ):
155+ self .out_conv1 = nn .Conv3d (self .ft_chns [1 ], self .n_class , kernel_size = 1 )
156+ self .out_conv2 = nn .Conv3d (self .ft_chns [2 ], self .n_class , kernel_size = 1 )
157+ self .out_conv3 = nn .Conv3d (self .ft_chns [3 ], self .n_class , kernel_size = 1 )
152158
153159 def forward (self , x ):
154160 if (len (self .ft_chns ) == 5 ):
@@ -163,6 +169,11 @@ def forward(self, x):
163169 x_d1 = self .up3 (x_d2 , x1 )
164170 x_d0 = self .up4 (x_d1 , x0 )
165171 output = self .out_conv (x_d0 )
172+ if (self .mul_pred ):
173+ output1 = self .out_conv1 (x_d1 )
174+ output2 = self .out_conv2 (x_d2 )
175+ output3 = self .out_conv3 (x_d3 )
176+ output = [output , output1 , output2 , output3 ]
166177 return output
167178
168179class UNet3D (nn .Module ):
@@ -187,7 +198,7 @@ class UNet3D(nn.Module):
187198 :param class_num: (int) The class number for segmentation task.
188199 :param trilinear: (bool) Using trilinear for up-sampling or not.
189200 If False, deconvolution will be used for up-sampling.
190- :param deep_supervise : (bool) Using deep supervision for training or not .
201+ :param multiscale_pred : (bool) Get multi-scale prediction .
191202 """
192203 def __init__ (self , params ):
193204 super (UNet3D , self ).__init__ ()
@@ -197,7 +208,7 @@ def __init__(self, params):
197208 self .dropout = self .params ['dropout' ]
198209 self .n_class = self .params ['class_num' ]
199210 self .trilinear = self .params ['trilinear' ]
200- self .deep_sup = self .params ['deep_supervise ' ]
211+ self .mul_pred = self .params ['multiscale_pred ' ]
201212 assert (len (self .ft_chns ) == 5 or len (self .ft_chns ) == 4 )
202213
203214 self .in_conv = ConvBlock (self .in_chns , self .ft_chns [0 ], self .dropout [0 ])
@@ -216,7 +227,7 @@ def __init__(self, params):
216227 dropout_p = self .dropout [0 ], trilinear = self .trilinear )
217228
218229 self .out_conv = nn .Conv3d (self .ft_chns [0 ], self .n_class , kernel_size = 1 )
219- if (self .deep_sup ):
230+ if (self .mul_pred ):
220231 self .out_conv1 = nn .Conv3d (self .ft_chns [1 ], self .n_class , kernel_size = 1 )
221232 self .out_conv2 = nn .Conv3d (self .ft_chns [2 ], self .n_class , kernel_size = 1 )
222233 self .out_conv3 = nn .Conv3d (self .ft_chns [3 ], self .n_class , kernel_size = 1 )
@@ -235,14 +246,10 @@ def forward(self, x):
235246 x_d1 = self .up3 (x_d2 , x1 )
236247 x_d0 = self .up4 (x_d1 , x0 )
237248 output = self .out_conv (x_d0 )
238- if (self .deep_sup ):
239- out_shape = list (output .shape )[2 :]
249+ if (self .mul_pred ):
240250 output1 = self .out_conv1 (x_d1 )
241- output1 = interpolate (output1 , out_shape , mode = 'trilinear' )
242251 output2 = self .out_conv2 (x_d2 )
243- output2 = interpolate (output2 , out_shape , mode = 'trilinear' )
244252 output3 = self .out_conv3 (x_d3 )
245- output3 = interpolate (output3 , out_shape , mode = 'trilinear' )
246253 output = [output , output1 , output2 , output3 ]
247254 return output
248255
@@ -251,7 +258,8 @@ def forward(self, x):
251258 'class_num' : 2 ,
252259 'feature_chns' :[2 , 8 , 32 , 64 ],
253260 'dropout' : [0 , 0 , 0 , 0.5 ],
254- 'trilinear' : True }
261+ 'trilinear' : True ,
262+ 'multiscale_pred' : False }
255263 Net = UNet3D (params )
256264 Net = Net .double ()
257265
0 commit comments