实现多类骰子损失函数 [英] Implementing Multiclass Dice Loss Function

查看:40
本文介绍了实现多类骰子损失函数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用 UNet 进行多类分割.我对模型的输入是 HxWxC 而我的输出是,

I am doing multi class segmentation using UNet. My input to the model is HxWxC and my output is,

outputs = layers.Conv2D(n_classes, (1, 1), activation='sigmoid')(decoder0)

使用 SparseCategoricalCrossentropy 我可以很好地训练网络.现在我还想尝试将骰子系数作为损失函数.实现如下,

Using SparseCategoricalCrossentropy I can train the network fine. Now I would like to also try dice coefficient as the loss function. Implemented as follows,

def dice_loss(y_true, y_pred, smooth=1e-6):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.math.sigmoid(y_pred)

    numerator = 2 * tf.reduce_sum(y_true * y_pred) + smooth
    denominator = tf.reduce_sum(y_true + y_pred) + smooth

    return 1 - numerator / denominator

然而,我实际上得到了增加的损失而不是减少的损失.我检查了多个来源,但我发现的所有材料都使用 diceloss 进行二元分类而不是多类.所以我的问题是实施是否存在问题.

However I am actually getting an increasing loss instead of decreasing loss. I have checked multiple sources but all the material I find use diceloss for binary classification and not multiclass. So my question is is is there a problem with the implementation.

推荐答案

问题是你的骰子损失并没有解决你拥有的类的数量,而是假设了二进制情况,所以它可能解释了你损失的增加.

The problem is that your dice loss doesn't address the number of classes you have but rather assumes binary case, so it might explain the increase in your loss.

您应该实现对所有类进行计算的广义骰子损失并返回所有类的值.

You should implement generalized dice loss that accounts for all the classes and return the value for all of them.

类似于以下内容:

def dice_coef_9cat(y_true, y_pred, smooth=1e-7):
    '''
    Dice coefficient for 10 categories. Ignores background pixel label 0
    Pass to model as metric during compile statement
    '''
    y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'), num_classes=10)[...,1:])
    y_pred_f = K.flatten(y_pred[...,1:])
    intersect = K.sum(y_true_f * y_pred_f, axis=-1)
    denom = K.sum(y_true_f + y_pred_f, axis=-1)
    return K.mean((2. * intersect / (denom + smooth)))

def dice_coef_9cat_loss(y_true, y_pred):
    '''
    Dice loss to minimize. Pass to model as loss during compile statement
    '''
    return 1 - dice_coef_9cat(y_true, y_pred)

此片段摘自 https://github.com/keras-team/keras/issues/9395#issuecomment-370971561

这是针对 9 个类别的,而您应该根据您拥有的类别数量进行调整.

This is for 9 categories, while you should adjust to the number of categories you have.

这篇关于实现多类骰子损失函数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆