33
44import torch
55import json
6+ import math
67import random
78import numpy as np
89
@@ -59,7 +60,11 @@ def inverse_transform_for_prediction(self, sample):
5960 different elemenets in the batch.
6061
6162 origin_shape is a 4D or 3D vector as saved in __call__().'''
62- origin_shape = json .loads (sample ['Rescale_origin_shape' ][0 ])
63+ if (isinstance (sample ['Rescale_origin_shape' ], list ) or \
64+ isinstance (sample ['Rescale_origin_shape' ], tuple )):
65+ origin_shape = json .loads (sample ['Rescale_origin_shape' ][0 ])
66+ else :
67+ origin_shape = json .loads (sample ['Rescale_origin_shape' ])
6368 origin_dim = len (origin_shape ) - 1
6469 predict = sample ['predict' ]
6570 input_shape = predict .shape
@@ -116,7 +121,11 @@ def inverse_transform_for_prediction(self, sample):
116121 different elemenets in the batch.
117122
118123 flip_axis is a list as saved in __call__().'''
119- flip_axis = json .loads (sample ['RandomFlip_Param' ][0 ])
124+ if (isinstance (sample ['RandomFlip_Param' ], list ) or \
125+ isinstance (sample ['RandomFlip_Param' ], tuple )):
126+ flip_axis = json .loads (sample ['RandomFlip_Param' ][0 ])
127+ else :
128+ flip_axis = json .loads (sample ['RandomFlip_Param' ])
120129 if (len (flip_axis ) > 0 ):
121130 sample ['predict' ] = np .flip (sample ['predict' ] , flip_axis )
122131 return sample
@@ -180,7 +189,11 @@ def inverse_transform_for_prediction(self, sample):
180189
181190 transform_param_list is a list as saved in __call__().'''
182191 # get the paramters for invers transformation
183- transform_param_list = json .loads (sample ['RandomRotate_Param' ][0 ])
192+ if (isinstance (sample ['RandomRotate_Param' ], list ) or \
193+ isinstance (sample ['RandomRotate_Param' ], tuple )):
194+ transform_param_list = json .loads (sample ['RandomRotate_Param' ][0 ])
195+ else :
196+ transform_param_list = json .loads (sample ['RandomRotate_Param' ])
184197 transform_param_list .reverse ()
185198 for i in range (len (transform_param_list )):
186199 transform_param_list [i ][0 ] = - transform_param_list [i ][0 ]
@@ -196,8 +209,9 @@ class Pad(object):
196209 output_size (tuple/list): the size along each spatial axis.
197210
198211 """
199- def __init__ (self , output_size , inverse = True ):
212+ def __init__ (self , output_size , ceil_mode = False , inverse = True ):
200213 self .output_size = output_size
214+ self .ceil_mode = ceil_mode
201215 self .inverse = inverse
202216
203217
@@ -206,7 +220,14 @@ def __call__(self, sample):
206220 input_shape = image .shape
207221 input_dim = len (input_shape ) - 1
208222 assert (len (self .output_size ) == input_dim )
209- margin = [max (0 , self .output_size [i ] - input_shape [1 + i ]) \
223+ if (self .ceil_mode ):
224+ multiple = [int (math .ceil (float (input_shape [1 + i ])/ self .output_size [i ]))\
225+ for i in range (input_dim )]
226+ output_size = [multiple [i ] * self .output_size [i ] \
227+ for i in range (input_dim )]
228+ else :
229+ output_size = self .output_size
230+ margin = [max (0 , output_size [i ] - input_shape [1 + i ]) \
210231 for i in range (input_dim )]
211232
212233 margin_lower = [int (margin [i ] / 2 ) for i in range (input_dim )]
@@ -233,7 +254,10 @@ def inverse_transform_for_prediction(self, sample):
233254
234255 origin_shape is a 4D or 3D vector as saved in __call__().'''
235256 # raise ValueError("not implemented")
236- params = json .loads (sample ['Pad_Param' ][0 ])
257+ if (isinstance (sample ['Pad_Param' ], list ) or isinstance (sample ['Pad_Param' ], tuple )):
258+ params = json .loads (sample ['Pad_Param' ][0 ])
259+ else :
260+ params = json .loads (sample ['Pad_Param' ])
237261 margin_lower = params [0 ]
238262 margin_upper = params [1 ]
239263 predict = sample ['predict' ]
@@ -310,7 +334,11 @@ def inverse_transform_for_prediction(self, sample):
310334 different elemenets in the batch.
311335
312336 origin_shape is a 4D or 3D vector as saved in __call__().'''
313- params = json .loads (sample ['CropWithBoundingBox_Param' ][0 ])
337+ if (isinstance (sample ['CropWithBoundingBox_Param' ], list ) or \
338+ isinstance (sample ['CropWithBoundingBox_Param' ], tuple )):
339+ params = json .loads (sample ['CropWithBoundingBox_Param' ][0 ])
340+ else :
341+ params = json .loads (sample ['CropWithBoundingBox_Param' ])
314342 origin_shape = params [0 ]
315343 crop_min = params [1 ]
316344 crop_max = params [2 ]
@@ -333,10 +361,13 @@ class RandomCrop(object):
333361 the output channel is the same as the input channel.
334362 """
335363
336- def __init__ (self , output_size , inverse = True ):
364+ def __init__ (self , output_size , fg_focus = False , fg_ratio = 0.0 , mask_label = None , inverse = True ):
337365 assert isinstance (output_size , (list , tuple ))
338366 self .output_size = output_size
339- self .inverse = inverse
367+ self .inverse = inverse
368+ self .fg_focus = fg_focus
369+ self .fg_ratio = fg_ratio
370+ self .mask_label = mask_label
340371
341372 def __call__ (self , sample ):
342373 image = sample ['image' ]
@@ -347,6 +378,19 @@ def __call__(self, sample):
347378 crop_margin = [input_shape [i + 1 ] - self .output_size [i ]\
348379 for i in range (input_dim )]
349380 crop_min = [random .randint (0 , item ) for item in crop_margin ]
381+ if (self .fg_focus and random .random () < self .fg_ratio ):
382+ label = sample ['label' ]
383+ mask = np .zeros_like (label )
384+ for temp_lab in self .mask_label :
385+ mask = np .maximum (mask , label == temp_lab )
386+ bb_min , bb_max = get_ND_bounding_box (mask )
387+ bb_min , bb_max = bb_min [1 :], bb_max [1 :]
388+ crop_min = [random .randint (bb_min [i ], bb_max [i ]) - int (self .output_size [i ]/ 2 ) \
389+ for i in range (input_dim )]
390+ crop_min = [max (0 , item ) for item in crop_min ]
391+ crop_min = [min (crop_min [i ], input_shape [i + 1 ] - self .output_size [i ]) \
392+ for i in range (input_dim )]
393+
350394 crop_max = [crop_min [i ] + self .output_size [i ] \
351395 for i in range (input_dim )]
352396 crop_min = [0 ] + crop_min
@@ -368,7 +412,11 @@ def inverse_transform_for_prediction(self, sample):
368412 different elemenets in the batch.
369413
370414 origin_shape is a 4D or 3D vector as saved in __call__().'''
371- params = json .loads (sample ['RandomCrop_Param' ][0 ])
415+ if (isinstance (sample ['RandomCrop_Param' ], list ) or \
416+ isinstance (sample ['RandomCrop_Param' ], tuple )):
417+ params = json .loads (sample ['RandomCrop_Param' ][0 ])
418+ else :
419+ params = json .loads (sample ['RandomCrop_Param' ])
372420 origin_shape = params [0 ]
373421 crop_min = params [1 ]
374422 crop_max = params [2 ]
@@ -416,27 +464,40 @@ class ChannelWiseNormalize(object):
416464 mean (None or tuple/list): The mean values along each channel.
417465 std (None or tuple/list): The std values along each channel.
418466 if mean and std are None, calculate them from non-zero region
419- zero_to_random (bool): If true, replace zero values with random values.
467+ chns (None, or tuple/list): The list of channel indices
468+ zero_to_random (bool, or tuple/list or bool): indicate whether zero values
469+ in each channel is replaced with random values.
420470 """
421- def __init__ (self , mean , std , zero_to_random = False , inverse = False ):
471+ def __init__ (self , mean , std , chns = None , zero_to_random = False , inverse = False ):
422472 self .mean = mean
423473 self .std = std
474+ self .chns = chns
424475 self .zero_to_random = zero_to_random
425476 self .inverse = inverse
426477
427478 def __call__ (self , sample ):
428479 image = sample ['image' ]
429480 mask = image [0 ] > 0
430- for chn in range (image .shape [0 ]):
481+ chns = self .chns
482+ if (chns is None ):
483+ chns = range (image .shape [0 ])
484+ zero_to_random = self .zero_to_random
485+ if (isinstance (zero_to_random , bool )):
486+ zero_to_random = [zero_to_random ]* len (chns )
487+ if (not (self .mean is None and self .std is None )):
488+ assert (len (self .mean ) == len (self .std ))
489+ assert (len (self .mean ) == len (chns ))
490+ for i in range (len (chns )):
491+ chn = chns [i ]
431492 if (self .mean is None and self .std is None ):
432493 pixels = image [chn ][mask > 0 ]
433494 chn_mean = pixels .mean ()
434495 chn_std = pixels .std ()
435496 else :
436- chn_mean = self .mean [chn ]
437- chn_std = self .std [chn ]
497+ chn_mean = self .mean [i ]
498+ chn_std = self .std [i ]
438499 chn_norm = (image [chn ] - chn_mean )/ chn_std
439- if (self . zero_to_random ):
500+ if (zero_to_random [ i ] ):
440501 chn_random = np .random .normal (0 , 1 , size = chn_norm .shape )
441502 chn_norm [mask == 0 ] = chn_random [mask == 0 ]
442503 image [chn ] = chn_norm
@@ -490,6 +551,49 @@ def __call__(self, sample):
490551 def inverse_transform_for_prediction (self , sample ):
491552 raise (ValueError ("not implemented" ))
492553
554+ class LabelToProbability (object ):
555+ """
556+ Convert one-channel label map to multi-channel probability map
557+ Args:
558+ class_num (int): the class number in the label map
559+ """
560+ def __init__ (self , class_num , inverse = False ):
561+ self .class_num = class_num
562+ self .inverse = inverse
563+
564+ def __call__ (self , sample ):
565+ label = sample ['label' ][0 ]
566+ label_prob = []
567+ for i in range (self .class_num ):
568+ temp_prob = label == i * np .ones_like (label )
569+ label_prob .append (temp_prob )
570+ label_prob = np .asarray (label_prob , np .float32 )
571+
572+ sample ['label_prob' ] = label_prob
573+ return sample
574+
575+ def inverse_transform_for_prediction (self , sample ):
576+ raise (ValueError ("not implemented" ))
577+
578+ class ProbabilityToDistance (object ):
579+ """
580+ get distance transform for each label
581+ """
582+ def __init__ (self , inverse = False ):
583+ self .inverse = inverse
584+
585+
586+ def __call__ (self , sample ):
587+ label_prob = sample ['label_prob' ]
588+ label_distance = []
589+ for i in range (label_prob .shape [0 ]):
590+ temp_lab = label_prob [i ]
591+ temp_dis = get_euclidean_distance (temp_lab , dim = 3 , spacing = [1.0 , 1.0 , 1.0 ])
592+ label_distance .append (temp_dis )
593+ label_distance = np .asarray (label_distance )
594+ sample ['label_distance' ] = label_distance
595+ return sample
596+
493597class RegionSwop (object ):
494598 """
495599 Swop a subregion randomly between two images and their corresponding label
@@ -567,8 +671,9 @@ def get_transform(name, params):
567671
568672 elif (name == "Pad" ):
569673 output_size = params ["Pad_output_size" .lower ()]
674+ ceil_mode = params ["Pad_ceil_mode" .lower ()]
570675 inverse = params ["Pad_inverse" .lower ()]
571- return Pad (output_size , inverse )
676+ return Pad (output_size , ceil_mode , inverse )
572677
573678 elif (name == "ChannelWiseGammaCorrection" ):
574679 gamma_min = params ['ChannelWiseGammaCorrection_gamma_min' .lower ()]
@@ -577,11 +682,12 @@ def get_transform(name, params):
577682 return ChannelWiseGammaCorrection (gamma_min , gamma_max , inverse )
578683
579684 elif (name == 'ChannelWiseNormalize' ):
685+ chns = params ['ChannelWiseNormalize_channels' .lower ()]
580686 mean = params ['ChannelWiseNormalize_mean' .lower ()]
581687 std = params ['ChannelWiseNormalize_std' .lower ()]
582688 zero_to_random = params ['ChannelWiseNormalize_zero_to_random' .lower ()]
583689 inverse = params ['ChannelWiseNormalize_inverse' .lower ()]
584- return ChannelWiseNormalize (mean , std , zero_to_random , inverse )
690+ return ChannelWiseNormalize (mean , std , chns , zero_to_random , inverse )
585691
586692 elif (name == 'ChannelWiseThreshold' ):
587693 threshold = params ['ChannelWiseThreshold_threshold' .lower ()]
@@ -594,10 +700,22 @@ def get_transform(name, params):
594700 inverse = params ['LabelConvert_inverse' .lower ()]
595701 return LabelConvert (source_list , target_list , inverse )
596702
703+ elif (name == 'LabelToProbability' ):
704+ class_num = params ['LabelToProbability_class_num' .lower ()]
705+ inverse = params ['LabelToProbability_inverse' .lower ()]
706+ return LabelToProbability (class_num , inverse )
707+
708+ elif (name == 'ProbabilityToDistance' ):
709+ inverse = params ['ProbabilityToDistance_inverse' .lower ()]
710+ return ProbabilityToDistance (inverse )
711+
597712 elif (name == 'RandomCrop' ):
598713 output_size = params ['RandomCrop_output_size' .lower ()]
714+ fg_focus = params ['RandomCrop_foreground_focus' .lower ()]
715+ fg_ratio = params ['RandomCrop_foreground_ratio' .lower ()]
716+ mask_label = params ['RandomCrop_mask_label' .lower ()]
599717 inverse = params ['RandomCrop_inverse' .lower ()]
600- return RandomCrop (output_size , inverse )
718+ return RandomCrop (output_size , fg_focus , fg_ratio , mask_label , inverse )
601719
602720 elif (name == 'RegionSwop' ):
603721 spatial_axes = params ['RegionSwop_spatial_axes' .lower ()]
0 commit comments