Higher loss penalty for true non-zero predictions


Problem Description


I am building a deep regression network (CNN) to predict a (1000,1) target vector from images (7,11). The target usually consists of about 90% zeros and only 10% non-zero values. The distribution of (non-)zero values in the targets varies from sample to sample (i.e. there is no global class imbalance).

Using mean squared error loss, the network ended up predicting only zeros, which I don't find surprising.

My best guess is to write a custom loss function that penalizes errors on non-zero values more heavily than errors on zero values.

I have tried the following loss function, intended to implement what I guessed could work above. It is a mean squared error loss in which errors on zero targets are penalized less (w=0.1):

import tensorflow as tf
from tensorflow.keras import backend as K

def my_loss(y_true, y_pred):
    # weight errors on true zero targets less than errors on true non-zero targets
    w = 0.1
    # zero out the predictions wherever the target is zero
    y_pred_of_nonzeros = tf.where(tf.equal(y_true, 0), y_pred - y_pred, y_pred)
    return K.mean(K.square(y_true - y_pred_of_nonzeros)) + K.mean(K.square(y_true - y_pred)) * w
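
For reference, a minimal sketch of how this loss is plugged in (the optimizer choice here is a placeholder, not part of the original setup):

model.compile(optimizer='adam', loss=my_loss)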

The network is able to learn without getting stuck with only-zero predictions. However, this solution seems quite unclean. Is there a better way to deal with this type of problem? Any advice on improving the custom loss function? Any suggestions are welcome, thank you in advance!

Best, Lukas

Solution

I'm not sure there is anything better than a custom loss like the one you wrote, but there is a cleaner way:

from tensorflow.keras import backend as K

def weightedLoss(w):

    def loss(true, pred):
        # elementwise squared error, down-weighted where the target is zero
        error = K.square(true - pred)
        error = K.switch(K.equal(true, 0), w * error, error)
        return error

    return loss

You may also return K.mean(error), but leaving out the mean lets you keep benefiting from other Keras options, such as adding sample weights.
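
As a quick sanity check, here is a minimal sketch with made-up numbers (K.constant and K.eval are standard Keras backend utilities) showing the down-weighting:

y_true = K.constant([[0., 0., 3., 0., 5.]])
y_pred = K.constant([[1., 0., 1., 2., 2.]])
print(K.eval(weightedLoss(0.1)(y_true, y_pred)))
# raw squared errors are [1, 0, 4, 4, 9]; the entries with zero targets are
# scaled by 0.1, giving [0.1, 0., 4., 0.4, 9.]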

Select the weight when compiling:

model.compile(loss = weightedLoss(0.1), ...)

If you have the entire data in an array, you can do:

w = K.mean(y_train)
w = w / (1 - w)  # this line compensates for the lack of the 90% weights for class 1
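
Worked through for the 90/10 split in the question (this assumes the targets are roughly binary, so that K.mean(y_train) approximates the non-zero fraction; for general continuous targets one might use the fraction of non-zeros instead):

import numpy as np

p = np.mean(y_train != 0)   # fraction of non-zero targets, about 0.1 here
w = p / (1 - p)             # about 0.111

With w ≈ 0.111, the 90% zero entries contribute roughly 0.9 × 0.111 ≈ 0.1 of the total loss weight, matching the 0.1 contributed by the 10% non-zero entries.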


Another solution, which avoids a custom loss but requires changes to the data and the model, is:

  • Transform your y into a 2-class problem for each output. Shape = (batch, originalClasses, 2).
  • For the zero values, make the first of the two classes = 1.
  • For the one values, make the second of the two classes = 1.

newY = np.stack([1-oldY, oldY], axis=-1)    
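
For the shapes in the question this looks like the following (a hypothetical batch of 32 samples; it assumes oldY has already been binarized to 0/1):

import numpy as np

oldY = (np.random.rand(32, 1000) < 0.1).astype('float32')  # ~10% ones
newY = np.stack([1 - oldY, oldY], axis=-1)
print(newY.shape)   # (32, 1000, 2)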

Adjust the model to output this new shape.

...
model.add(Dense(2*classes))
model.add(Reshape((classes,2)))
model.add(Activation('softmax'))

Make sure you are using a softmax activation and categorical_crossentropy as the loss.

Then use the argument class_weight={0: w, 1: 1} in fit.
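
Putting the pieces together, a minimal sketch of the reformulated model might look like this (the dense front end, layer sizes, and optimizer are placeholders, not from the original post):

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Activation

classes = 1000   # length of the original target vector

model = Sequential()
model.add(Dense(128, activation='relu', input_shape=(77,)))  # placeholder front end
model.add(Dense(2 * classes))
model.add(Reshape((classes, 2)))
model.add(Activation('softmax'))   # softmax over the last axis (the two classes)

model.compile(optimizer='adam', loss='categorical_crossentropy')
# model.fit(x, newY, class_weight={0: w, 1: 1})
# note: some tf.keras versions reject class_weight for 3-dimensional targets;
# per-element sample weights are an alternative in that case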
