|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +from __future__ import print_function, division |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.nn as nn |
| 6 | +import numpy as np |
| 7 | +from pymic.loss.util import reshape_tensor_to_2D, get_classwise_dice |
| 8 | + |
| 9 | +class MyFocalDiceLoss(nn.Module): |
| 10 | + """ |
| 11 | + Focal Dice loss proposed in the following paper: |
| 12 | + P. Wang et al. Focal dice loss and image dilatin for brain tumor segmentation. |
| 13 | + in Deep Learning in Medical Image Analysis and Multimodal Learning for Clinical |
| 14 | + Decision Support, 2018. |
| 15 | + """ |
| 16 | + def __init__(self, params): |
| 17 | + super(MyFocalDiceLoss, self).__init__() |
| 18 | + self.enable_pix_weight = params['MyFocalDiceLoss_Enable_Pixel_Weight'.lower()] |
| 19 | + self.enable_cls_weight = params['MyFocalDiceLoss_Enable_Class_Weight'.lower()] |
| 20 | + self.beta = params['MyFocalDiceLoss_beta'.lower()] |
| 21 | + assert(self.beta >= 1.0) |
| 22 | + |
| 23 | + def forward(self, loss_input_dict): |
| 24 | + predict = loss_input_dict['prediction'] |
| 25 | + soft_y = loss_input_dict['ground_truth'] |
| 26 | + pix_w = loss_input_dict['pixel_weight'] |
| 27 | + cls_w = loss_input_dict['class_weight'] |
| 28 | + softmax = loss_input_dict['softmax'] |
| 29 | + |
| 30 | + if(softmax): |
| 31 | + predict = nn.Softmax(dim = 1)(predict) |
| 32 | + predict = reshape_tensor_to_2D(predict) |
| 33 | + soft_y = reshape_tensor_to_2D(soft_y) |
| 34 | + |
| 35 | + if(self.enable_pix_weight): |
| 36 | + if(pix_w is None): |
| 37 | + raise ValueError("Pixel weight is enabled but not defined") |
| 38 | + pix_w = reshape_tensor_to_2D(pix_w) |
| 39 | + dice_score = get_classwise_dice(predict, soft_y, pix_w) |
| 40 | + dice_score = 0.01 + dice_score * 0.98 |
| 41 | + dice_loss = 1.0 - torch.pow(dice_score, 1.0 / self.beta) |
| 42 | + |
| 43 | + if(self.enable_cls_weight): |
| 44 | + if(cls_w is None): |
| 45 | + raise ValueError("Class weight is enabled but not defined") |
| 46 | + weighted_loss = dice_loss * cls_w |
| 47 | + avg_loss = weighted_loss.sum() / cls_w.sum() |
| 48 | + else: |
| 49 | + avg_loss = torch.mean(dice_loss) |
| 50 | + return avg_loss |
0 commit comments