Skip to content

Commit 50f8194

Browse files
committed
set default parameters
1 parent aca5935 commit 50f8194

12 files changed

Lines changed: 30 additions & 30 deletions

File tree

pymic/loss/seg/ce.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
class CrossEntropyLoss(nn.Module):
99
def __init__(self, params):
1010
super(CrossEntropyLoss, self).__init__()
11-
self.enable_pix_weight = params['CrossEntropyLoss_Enable_Pixel_Weight'.lower()]
12-
self.enable_cls_weight = params['CrossEntropyLoss_Enable_Class_Weight'.lower()]
11+
self.enable_pix_weight = params.get('CrossEntropyLoss_Enable_Pixel_Weight'.lower(), False)
12+
self.enable_cls_weight = params.get('CrossEntropyLoss_Enable_Class_Weight'.lower(), False)
1313

1414
def forward(self, loss_input_dict):
1515
predict = loss_input_dict['prediction']

pymic/net_run/agent_seg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,9 @@ def infer(self):
364364
device = torch.device("cuda:{0:}".format(device_ids[0]))
365365
self.net.to(device)
366366

367-
if(self.config['testing']['evaluation_mode'] == True):
367+
if(self.config['testing'].get('evaluation_mode', True)):
368368
self.net.eval()
369-
if(self.config['testing']['test_time_dropout'] == True):
369+
if(self.config['testing'].get('test_time_dropout', False)):
370370
def test_time_dropout(m):
371371
if(type(m) == nn.Dropout):
372372
print('dropout layer')
@@ -432,7 +432,7 @@ def infer_with_multiple_checkpoints(self):
432432
"""
433433
device_ids = self.config['testing']['gpus']
434434
device = torch.device("cuda:{0:}".format(device_ids[0]))
435-
435+
436436
ckpt_names = self.config['testing']['ckpt_name']
437437
infer_cfg = self.config['testing']
438438
infer_cfg['class_num'] = self.config['network']['class_num']

pymic/transform/crop.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, params):
2222
If D is None, then the z-axis is not cropped
2323
"""
2424
self.output_size = params['CenterCrop_output_size'.lower()]
25-
self.inverse = params['CenterCrop_inverse'.lower()]
25+
self.inverse = params.get('CenterCrop_inverse'.lower(), True)
2626
self.task = params['Task'.lower()]
2727

2828
def get_crop_param(self, sample):
@@ -113,7 +113,7 @@ def __init__(self, params):
113113
"""
114114
self.start = params['CropWithBoundingBox_start'.lower()]
115115
self.output_size = params['CropWithBoundingBox_output_size'.lower()]
116-
self.inverse = params['CropWithBoundingBox_inverse'.lower()]
116+
self.inverse = params.get('CropWithBoundingBox_inverse'.lower(), True)
117117
self.task = params['task']
118118

119119
def get_crop_param(self, sample):
@@ -170,10 +170,10 @@ def __init__(self, params):
170170
"""
171171
# super(RandomCrop, self).__init__(params)
172172
self.output_size = params['RandomCrop_output_size'.lower()]
173-
self.fg_focus = params['RandomCrop_foreground_focus'.lower()]
174-
self.fg_ratio = params['RandomCrop_foreground_ratio'.lower()]
175-
self.mask_label = params['RandomCrop_mask_label'.lower()]
176-
self.inverse = params['RandomCrop_inverse'.lower()]
173+
self.fg_focus = params.get('RandomCrop_foreground_focus'.lower(), False)
174+
self.fg_ratio = params.get('RandomCrop_foreground_ratio'.lower(), 0.5)
175+
self.mask_label = params.get('RandomCrop_mask_label'.lower(), [1])
176+
self.inverse = params.get('RandomCrop_inverse'.lower(), True)
177177
self.task = params['Task'.lower()]
178178
assert isinstance(self.output_size, (list, tuple))
179179
if(self.mask_label is not None):
@@ -238,7 +238,7 @@ def __init__(self, params):
238238
self.output_size = params['RandomResizedCrop_output_size'.lower()]
239239
self.scale = params['RandomResizedCrop_scale'.lower()]
240240
self.ratio = params['RandomResizedCrop_ratio'.lower()]
241-
self.inverse = params['RandomResizedCrop_inverse'.lower()]
241+
self.inverse = params.get('RandomResizedCrop_inverse'.lower(), True)
242242
self.task = params['Task'.lower()]
243243
assert isinstance(self.output_size, (list, tuple))
244244
assert isinstance(self.scale, (list, tuple))

pymic/transform/flip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, params):
2323
self.flip_depth = params['RandomFlip_flip_depth'.lower()]
2424
self.flip_height = params['RandomFlip_flip_height'.lower()]
2525
self.flip_width = params['RandomFlip_flip_width'.lower()]
26-
self.inverse = params['RandomFlip_inverse'.lower()]
26+
self.inverse = params.get('RandomFlip_inverse'.lower(), True)
2727

2828
def __call__(self, sample):
2929
image = sample['image']

pymic/transform/gamma_correction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, params):
2222
super(ChannelWiseGammaCorrection, self).__init__(params)
2323
self.gamma_min = params['ChannelWiseGammaCorrection_gamma_min'.lower()]
2424
self.gamma_max = params['ChannelWiseGammaCorrection_gamma_max'.lower()]
25-
self.inverse = params['ChannelWiseGammaCorrection_inverse'.lower()]
25+
self.inverse = params.get('ChannelWiseGammaCorrection_inverse'.lower(), False)
2626

2727
def __call__(self, sample):
2828
image= sample['image']

pymic/transform/gray2rgb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, params):
2020
(gamma_min, gamma_max) specify the range of gamma
2121
"""
2222
super(GrayscaleToRGB, self).__init__(params)
23-
self.inverse = params['GrayscaleToRGB_inverse'.lower()]
23+
self.inverse = params.get('GrayscaleToRGB_inverse'.lower(), False)
2424

2525
def __call__(self, sample):
2626
image= sample['image']

pymic/transform/label_convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ReduceLabelDim(AbstractTransform):
1616
"""
1717
def __init__(self, params):
1818
super(ReduceLabelDim, self).__init__(params)
19-
self.inverse = params['ReduceLabelDim_inverse'.lower()]
19+
self.inverse = params.get('ReduceLabelDim_inverse'.lower(), False)
2020

2121
def __call__(self, sample):
2222
label = sample['label']
@@ -34,7 +34,7 @@ def __init__(self, params):
3434
super(LabelConvert, self).__init__(params)
3535
self.source_list = params['LabelConvert_source_list'.lower()]
3636
self.target_list = params['LabelConvert_target_list'.lower()]
37-
self.inverse = params['LabelConvert_inverse'.lower()]
37+
self.inverse = params.get('LabelConvert_inverse'.lower(), False)
3838
assert(len(self.source_list) == len(self.target_list))
3939

4040
def __call__(self, sample):
@@ -47,7 +47,7 @@ class LabelConvertNonzero(AbstractTransform):
4747
""" Convert label into binary (nonzero as 1)"""
4848
def __init__(self, params):
4949
super(LabelConvertNonzero, self).__init__(params)
50-
self.inverse = params['LabelConvertNonzero_inverse'.lower()]
50+
self.inverse = params.get('LabelConvertNonzero_inverse'.lower(), False)
5151

5252
def __call__(self, sample):
5353
label = sample['label']
@@ -63,7 +63,7 @@ def __init__(self, params):
6363
"""
6464
super(LabelToProbability, self).__init__(params)
6565
self.class_num = params['LabelToProbability_class_num'.lower()]
66-
self.inverse = params['LabelToProbability_inverse'.lower()]
66+
self.inverse = params.get('LabelToProbability_inverse'.lower(), False)
6767

6868
def __call__(self, sample):
6969
if(self.task == 'segmentation'):

pymic/transform/normalize.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ def __init__(self, params):
2323
"""
2424
super(NormalizeWithMeanStd, self).__init__(params)
2525
self.chns = params['NormalizeWithMeanStd_channels'.lower()]
26-
self.mean = params['NormalizeWithMeanStd_mean'.lower()]
27-
self.std = params['NormalizeWithMeanStd_std'.lower()]
28-
self.ingore_np = params['NormalizeWithMeanStd_ignore_non_positive'.lower()]
29-
self.inverse = params['NormalizeWithMeanStd_inverse'.lower()]
26+
self.mean = params.get('NormalizeWithMeanStd_mean'.lower(), None)
27+
self.std = params.get('NormalizeWithMeanStd_std'.lower(), None)
28+
self.ingore_np = params.get('NormalizeWithMeanStd_ignore_non_positive'.lower(), False)
29+
self.inverse = params.get('NormalizeWithMeanStd_inverse'.lower(), False)
3030

3131
def __call__(self, sample):
3232
image= sample['image']
@@ -69,7 +69,7 @@ def __init__(self, params):
6969
self.chns = params['NormalizeWithMinMax_channels'.lower()]
7070
self.thred_lower = params['NormalizeWithMinMax_threshold_lower'.lower()]
7171
self.thred_upper = params['NormalizeWithMinMax_threshold_upper'.lower()]
72-
self.inverse = params['NormalizeWithMinMax_inverse'.lower()]
72+
self.inverse = params.get('NormalizeWithMinMax_inverse'.lower(), False)
7373

7474
def __call__(self, sample):
7575
image= sample['image']
@@ -104,7 +104,7 @@ def __init__(self, params):
104104
self.chns = params['NormalizeWithPercentiles_channels'.lower()]
105105
self.percent_lower = params['NormalizeWithPercentiles_percentile_lower'.lower()]
106106
self.percent_upper = params['NormalizeWithPercentiles_percentile_upper'.lower()]
107-
self.inverse = params['NormalizeWithPercentiles_inverse'.lower()]
107+
self.inverse = params.get('NormalizeWithPercentiles_inverse'.lower(), False)
108108

109109
def __call__(self, sample):
110110
image= sample['image']

pymic/transform/pad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, params):
2424
super(Pad, self).__init__(params)
2525
self.output_size = params['Pad_output_size'.lower()]
2626
self.ceil_mode = params['Pad_ceil_mode'.lower()]
27-
self.inverse = params['Pad_inverse'.lower()]
27+
self.inverse = params.get('Pad_inverse'.lower(), True)
2828

2929
def __call__(self, sample):
3030
image = sample['image']

pymic/transform/rescale.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, params):
2424
"""
2525
super(Rescale, self).__init__(params)
2626
self.output_size = params["Rescale_output_size".lower()]
27-
self.inverse = params["Rescale_inverse".lower()]
27+
self.inverse = params.get("Rescale_inverse".lower(), True)
2828
assert isinstance(self.output_size, (int, list, tuple))
2929

3030
def __call__(self, sample):
@@ -89,7 +89,7 @@ def __init__(self, params):
8989
super(RandomRescale, self).__init__(params)
9090
self.ratio0 = params["RandomRescale_lower_bound".lower()]
9191
self.ratio1 = params["RandomRescale_upper_bound".lower()]
92-
self.inverse = params["RandomRescale_inverse".lower()]
92+
self.inverse = params.get("RandomRescale_inverse".lower(), True)
9393
assert isinstance(self.ratio0, (float, list, tuple))
9494
assert isinstance(self.ratio1, (float, list, tuple))
9595

0 commit comments

Comments
 (0)