Skip to content

Commit af82d38

Browse files
committed
update transform and image process functions
1 parent e6c461b commit af82d38

10 files changed

Lines changed: 644 additions & 129 deletions

File tree

pymic/io/transform3d.py

Lines changed: 137 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
import json
6+
import math
67
import random
78
import numpy as np
89

@@ -59,7 +60,11 @@ def inverse_transform_for_prediction(self, sample):
5960
different elements in the batch.
6061
6162
origin_shape is a 4D or 3D vector as saved in __call__().'''
62-
origin_shape = json.loads(sample['Rescale_origin_shape'][0])
63+
if(isinstance(sample['Rescale_origin_shape'], list) or \
64+
isinstance(sample['Rescale_origin_shape'], tuple)):
65+
origin_shape = json.loads(sample['Rescale_origin_shape'][0])
66+
else:
67+
origin_shape = json.loads(sample['Rescale_origin_shape'])
6368
origin_dim = len(origin_shape) - 1
6469
predict = sample['predict']
6570
input_shape = predict.shape
@@ -116,7 +121,11 @@ def inverse_transform_for_prediction(self, sample):
116121
different elements in the batch.
117122
118123
flip_axis is a list as saved in __call__().'''
119-
flip_axis = json.loads(sample['RandomFlip_Param'][0])
124+
if(isinstance(sample['RandomFlip_Param'], list) or \
125+
isinstance(sample['RandomFlip_Param'], tuple)):
126+
flip_axis = json.loads(sample['RandomFlip_Param'][0])
127+
else:
128+
flip_axis = json.loads(sample['RandomFlip_Param'])
120129
if(len(flip_axis) > 0):
121130
sample['predict'] = np.flip(sample['predict'] , flip_axis)
122131
return sample
@@ -180,7 +189,11 @@ def inverse_transform_for_prediction(self, sample):
180189
181190
transform_param_list is a list as saved in __call__().'''
182191
# get the parameters for inverse transformation
183-
transform_param_list = json.loads(sample['RandomRotate_Param'][0])
192+
if(isinstance(sample['RandomRotate_Param'], list) or \
193+
isinstance(sample['RandomRotate_Param'], tuple)):
194+
transform_param_list = json.loads(sample['RandomRotate_Param'][0])
195+
else:
196+
transform_param_list = json.loads(sample['RandomRotate_Param'])
184197
transform_param_list.reverse()
185198
for i in range(len(transform_param_list)):
186199
transform_param_list[i][0] = - transform_param_list[i][0]
@@ -196,8 +209,9 @@ class Pad(object):
196209
output_size (tuple/list): the size along each spatial axis.
197210
198211
"""
199-
def __init__(self, output_size, inverse = True):
212+
def __init__(self, output_size, ceil_mode = False, inverse = True):
200213
self.output_size = output_size
214+
self.ceil_mode = ceil_mode
201215
self.inverse = inverse
202216

203217

@@ -206,7 +220,14 @@ def __call__(self, sample):
206220
input_shape = image.shape
207221
input_dim = len(input_shape) - 1
208222
assert(len(self.output_size) == input_dim)
209-
margin = [max(0, self.output_size[i] - input_shape[1+i]) \
223+
if(self.ceil_mode):
224+
multiple = [int(math.ceil(float(input_shape[1+i])/self.output_size[i]))\
225+
for i in range(input_dim)]
226+
output_size = [multiple[i] * self.output_size[i] \
227+
for i in range(input_dim)]
228+
else:
229+
output_size = self.output_size
230+
margin = [max(0, output_size[i] - input_shape[1+i]) \
210231
for i in range(input_dim)]
211232

212233
margin_lower = [int(margin[i] / 2) for i in range(input_dim)]
@@ -233,7 +254,10 @@ def inverse_transform_for_prediction(self, sample):
233254
234255
origin_shape is a 4D or 3D vector as saved in __call__().'''
235256
# raise ValueError("not implemented")
236-
params = json.loads(sample['Pad_Param'][0])
257+
if(isinstance(sample['Pad_Param'], list) or isinstance(sample['Pad_Param'], tuple)):
258+
params = json.loads(sample['Pad_Param'][0])
259+
else:
260+
params = json.loads(sample['Pad_Param'])
237261
margin_lower = params[0]
238262
margin_upper = params[1]
239263
predict = sample['predict']
@@ -310,7 +334,11 @@ def inverse_transform_for_prediction(self, sample):
310334
different elements in the batch.
311335
312336
origin_shape is a 4D or 3D vector as saved in __call__().'''
313-
params = json.loads(sample['CropWithBoundingBox_Param'][0])
337+
if(isinstance(sample['CropWithBoundingBox_Param'], list) or \
338+
isinstance(sample['CropWithBoundingBox_Param'], tuple)):
339+
params = json.loads(sample['CropWithBoundingBox_Param'][0])
340+
else:
341+
params = json.loads(sample['CropWithBoundingBox_Param'])
314342
origin_shape = params[0]
315343
crop_min = params[1]
316344
crop_max = params[2]
@@ -333,10 +361,13 @@ class RandomCrop(object):
333361
the output channel is the same as the input channel.
334362
"""
335363

336-
def __init__(self, output_size, inverse = True):
364+
def __init__(self, output_size, fg_focus = False, fg_ratio = 0.0, mask_label = None, inverse = True):
337365
assert isinstance(output_size, (list, tuple))
338366
self.output_size = output_size
339-
self.inverse = inverse
367+
self.inverse = inverse
368+
self.fg_focus = fg_focus
369+
self.fg_ratio = fg_ratio
370+
self.mask_label = mask_label
340371

341372
def __call__(self, sample):
342373
image = sample['image']
@@ -347,6 +378,19 @@ def __call__(self, sample):
347378
crop_margin = [input_shape[i + 1] - self.output_size[i]\
348379
for i in range(input_dim)]
349380
crop_min = [random.randint(0, item) for item in crop_margin]
381+
if(self.fg_focus and random.random() < self.fg_ratio):
382+
label = sample['label']
383+
mask = np.zeros_like(label)
384+
for temp_lab in self.mask_label:
385+
mask = np.maximum(mask, label == temp_lab)
386+
bb_min, bb_max = get_ND_bounding_box(mask)
387+
bb_min, bb_max = bb_min[1:], bb_max[1:]
388+
crop_min = [random.randint(bb_min[i], bb_max[i]) - int(self.output_size[i]/2) \
389+
for i in range(input_dim)]
390+
crop_min = [max(0, item) for item in crop_min]
391+
crop_min = [min(crop_min[i], input_shape[i+1] - self.output_size[i]) \
392+
for i in range(input_dim)]
393+
350394
crop_max = [crop_min[i] + self.output_size[i] \
351395
for i in range(input_dim)]
352396
crop_min = [0] + crop_min
@@ -368,7 +412,11 @@ def inverse_transform_for_prediction(self, sample):
368412
different elements in the batch.
369413
370414
origin_shape is a 4D or 3D vector as saved in __call__().'''
371-
params = json.loads(sample['RandomCrop_Param'][0])
415+
if(isinstance(sample['RandomCrop_Param'], list) or \
416+
isinstance(sample['RandomCrop_Param'], tuple)):
417+
params = json.loads(sample['RandomCrop_Param'][0])
418+
else:
419+
params = json.loads(sample['RandomCrop_Param'])
372420
origin_shape = params[0]
373421
crop_min = params[1]
374422
crop_max = params[2]
@@ -416,27 +464,40 @@ class ChannelWiseNormalize(object):
416464
mean (None or tuple/list): The mean values along each channel.
417465
std (None or tuple/list): The std values along each channel.
418466
if mean and std are None, calculate them from non-zero region
419-
zero_to_random (bool): If true, replace zero values with random values.
467+
chns (None, or tuple/list): The list of channel indices
468+
zero_to_random (bool, or tuple/list of bool): indicate whether zero values
469+
in each channel are replaced with random values.
420470
"""
421-
def __init__(self, mean, std, zero_to_random = False, inverse = False):
471+
def __init__(self, mean, std, chns = None, zero_to_random = False, inverse = False):
422472
self.mean = mean
423473
self.std = std
474+
self.chns = chns
424475
self.zero_to_random = zero_to_random
425476
self.inverse = inverse
426477

427478
def __call__(self, sample):
428479
image= sample['image']
429480
mask = image[0] > 0
430-
for chn in range(image.shape[0]):
481+
chns = self.chns
482+
if(chns is None):
483+
chns = range(image.shape[0])
484+
zero_to_random = self.zero_to_random
485+
if(isinstance(zero_to_random, bool)):
486+
zero_to_random = [zero_to_random]*len(chns)
487+
if(not(self.mean is None and self.std is None)):
488+
assert(len(self.mean) == len(self.std))
489+
assert(len(self.mean) == len(chns))
490+
for i in range(len(chns)):
491+
chn = chns[i]
431492
if(self.mean is None and self.std is None):
432493
pixels = image[chn][mask > 0]
433494
chn_mean = pixels.mean()
434495
chn_std = pixels.std()
435496
else:
436-
chn_mean = self.mean[chn]
437-
chn_std = self.std[chn]
497+
chn_mean = self.mean[i]
498+
chn_std = self.std[i]
438499
chn_norm = (image[chn] - chn_mean)/chn_std
439-
if(self.zero_to_random):
500+
if(zero_to_random[i]):
440501
chn_random = np.random.normal(0, 1, size = chn_norm.shape)
441502
chn_norm[mask == 0] = chn_random[mask == 0]
442503
image[chn] = chn_norm
@@ -490,6 +551,49 @@ def __call__(self, sample):
490551
def inverse_transform_for_prediction(self, sample):
491552
raise(ValueError("not implemented"))
492553

554+
class LabelToProbability(object):
555+
"""
556+
Convert one-channel label map to multi-channel probability map
557+
Args:
558+
class_num (int): the class number in the label map
559+
"""
560+
def __init__(self, class_num, inverse = False):
561+
self.class_num = class_num
562+
self.inverse = inverse
563+
564+
def __call__(self, sample):
565+
label = sample['label'][0]
566+
label_prob = []
567+
for i in range(self.class_num):
568+
temp_prob = label == i*np.ones_like(label)
569+
label_prob.append(temp_prob)
570+
label_prob = np.asarray(label_prob, np.float32)
571+
572+
sample['label_prob'] = label_prob
573+
return sample
574+
575+
def inverse_transform_for_prediction(self, sample):
576+
raise(ValueError("not implemented"))
577+
578+
class ProbabilityToDistance(object):
579+
"""
580+
get distance transform for each label
581+
"""
582+
def __init__(self, inverse = False):
583+
self.inverse = inverse
584+
585+
586+
def __call__(self, sample):
587+
label_prob = sample['label_prob']
588+
label_distance = []
589+
for i in range(label_prob.shape[0]):
590+
temp_lab = label_prob[i]
591+
temp_dis = get_euclidean_distance(temp_lab, dim = 3, spacing = [1.0, 1.0, 1.0])
592+
label_distance.append(temp_dis)
593+
label_distance = np.asarray(label_distance)
594+
sample['label_distance'] = label_distance
595+
return sample
596+
493597
class RegionSwop(object):
494598
"""
495599
Swop a subregion randomly between two images and their corresponding label
@@ -567,8 +671,9 @@ def get_transform(name, params):
567671

568672
elif(name == "Pad"):
569673
output_size = params["Pad_output_size".lower()]
674+
ceil_mode = params["Pad_ceil_mode".lower()]
570675
inverse = params["Pad_inverse".lower()]
571-
return Pad(output_size, inverse)
676+
return Pad(output_size, ceil_mode, inverse)
572677

573678
elif(name == "ChannelWiseGammaCorrection"):
574679
gamma_min = params['ChannelWiseGammaCorrection_gamma_min'.lower()]
@@ -577,11 +682,12 @@ def get_transform(name, params):
577682
return ChannelWiseGammaCorrection(gamma_min, gamma_max, inverse)
578683

579684
elif (name == 'ChannelWiseNormalize'):
685+
chns = params['ChannelWiseNormalize_channels'.lower()]
580686
mean = params['ChannelWiseNormalize_mean'.lower()]
581687
std = params['ChannelWiseNormalize_std'.lower()]
582688
zero_to_random = params['ChannelWiseNormalize_zero_to_random'.lower()]
583689
inverse = params['ChannelWiseNormalize_inverse'.lower()]
584-
return ChannelWiseNormalize(mean, std, zero_to_random, inverse)
690+
return ChannelWiseNormalize(mean, std, chns, zero_to_random, inverse)
585691

586692
elif(name == 'ChannelWiseThreshold'):
587693
threshold = params['ChannelWiseThreshold_threshold'.lower()]
@@ -594,10 +700,22 @@ def get_transform(name, params):
594700
inverse = params['LabelConvert_inverse'.lower()]
595701
return LabelConvert(source_list, target_list, inverse)
596702

703+
elif(name == 'LabelToProbability'):
704+
class_num = params['LabelToProbability_class_num'.lower()]
705+
inverse = params['LabelToProbability_inverse'.lower()]
706+
return LabelToProbability(class_num, inverse)
707+
708+
elif(name == 'ProbabilityToDistance'):
709+
inverse = params['ProbabilityToDistance_inverse'.lower()]
710+
return ProbabilityToDistance(inverse)
711+
597712
elif(name == 'RandomCrop'):
598713
output_size = params['RandomCrop_output_size'.lower()]
714+
fg_focus = params['RandomCrop_foreground_focus'.lower()]
715+
fg_ratio = params['RandomCrop_foreground_ratio'.lower()]
716+
mask_label = params['RandomCrop_mask_label'.lower()]
599717
inverse = params['RandomCrop_inverse'.lower()]
600-
return RandomCrop(output_size, inverse)
718+
return RandomCrop(output_size, fg_focus, fg_ratio, mask_label, inverse)
601719

602720
elif(name == 'RegionSwop'):
603721
spatial_axes = params['RegionSwop_spatial_axes'.lower()]

pymic/layer/convolution.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,69 @@ def forward(self, x):
3838
if(self.acti_func is not None):
3939
f = self.acti_func(f)
4040
return f
41+
42+
class ConvolutionSepAll3DLayer(nn.Module):
43+
"""
44+
A compose layer with the following components:
45+
convolution -> (batch_norm) -> activation -> (dropout)
46+
batch norm and dropout are optional
47+
"""
48+
def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
49+
stride = 1, padding = 0, dilation =1, groups = 1, bias = True,
50+
batch_norm = True, acti_func = None):
51+
super(ConvolutionSepAll3DLayer, self).__init__()
52+
self.n_in_chns = in_channels
53+
self.n_out_chns = out_channels
54+
self.batch_norm = batch_norm
55+
self.acti_func = acti_func
56+
57+
assert(dim == 3)
58+
chn = min(in_channels, out_channels)
59+
60+
self.conv_intra_plane1 = nn.Conv2d(chn, chn,
61+
kernel_size, stride, padding, dilation, chn, bias)
62+
63+
self.conv_intra_plane2 = nn.Conv2d(chn, chn,
64+
kernel_size, stride, padding, dilation, chn, bias)
65+
66+
self.conv_intra_plane3 = nn.Conv2d(chn, chn,
67+
kernel_size, stride, padding, dilation, chn, bias)
68+
69+
self.conv_space_wise = nn.Conv2d(in_channels, out_channels,
70+
1, stride, 0, dilation, 1, bias)
71+
72+
if(self.batch_norm):
73+
self.bn = nn.modules.BatchNorm3d(out_channels)
74+
75+
def forward(self, x):
76+
in_shape = list(x.shape)
77+
assert(len(in_shape) == 5)
78+
[B, C, D, H, W] = in_shape
79+
f0 = x.permute(0, 2, 1, 3, 4) #[B, D, C, H, W]
80+
f0 = f0.contiguous().view([B*D, C, H, W])
81+
82+
Cc = min(self.n_in_chns, self.n_out_chns)
83+
Co = self.n_out_chns
84+
if(self.n_in_chns > self.n_out_chns):
85+
f0 = self.conv_space_wise(f0) #[B*D, Cc, H, W]
86+
87+
f1 = self.conv_intra_plane1(f0)
88+
f2 = f1.contiguous().view([B, D, Cc, H, W])
89+
f2 = f2.permute(0, 3, 2, 1, 4) #[B, H, Cc, D, W]
90+
f2 = f2.contiguous().view([B*H, Cc, D, W])
91+
f2 = self.conv_intra_plane2(f2)
92+
f3 = f2.contiguous().view([B, H, Cc, D, W])
93+
f3 = f3.permute(0, 4, 2, 3, 1) #[B, W, Cc, D, H]
94+
f3 = f3.contiguous().view([B*W, Cc, D, H])
95+
f3 = self.conv_intra_plane3(f3)
96+
if(self.n_in_chns <= self.n_out_chns):
97+
f3 = self.conv_space_wise(f3) #[B*W, Co, D, H]
98+
99+
f3 = f3.contiguous().view([B, W, Co, D, H])
100+
f3 = f3.permute([0, 2, 3, 4, 1]) #[B, Co, D, H, W]
101+
102+
if(self.batch_norm):
103+
f3 = self.bn(f3)
104+
if(self.acti_func is not None):
105+
f3 = self.acti_func(f3)
106+
return f3

pymic/net2d/unet2d.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def __init__(self, params):
8181
kernel_size = 3, padding = 1)
8282

8383
def forward(self, x):
84+
x_shape = list(x.shape)
85+
if(len(x_shape) == 5):
86+
[N, C, D, H, W] = x_shape
87+
new_shape = [N*D, C, H, W]
88+
x = torch.transpose(x, 1, 2)
89+
x = torch.reshape(x, new_shape)
8490
f1 = self.block1(x); d1 = self.down1(f1)
8591
f2 = self.block2(d1); d2 = self.down2(f2)
8692
f3 = self.block3(d2); d3 = self.down3(f3)
@@ -104,6 +110,10 @@ def forward(self, x):
104110
f9 = self.block9(f1cat)
105111

106112
output = self.conv(f9)
113+
if(len(x_shape) == 5):
114+
new_shape = [N, D] + list(output.shape)[1:]
115+
output = torch.reshape(output, new_shape)
116+
output = torch.transpose(output, 1, 2)
107117
return output
108118

109119
if __name__ == "__main__":
@@ -114,7 +124,7 @@ def forward(self, x):
114124
Net = UNet2D(params)
115125
Net = Net.double()
116126

117-
x = np.random.rand(4, 4, 96, 96)
127+
x = np.random.rand(4, 4, 10, 96, 96)
118128
xt = torch.from_numpy(x)
119129
xt = torch.tensor(xt)
120130

0 commit comments

Comments
 (0)