Tensor indexing in custom loss function


Problem description

Basically, I want my custom loss function to alternate between the usual MSE and a custom MSE that subtracts values taken from different indexes.

To clarify, let's say I have a y_pred tensor that is [1, 2, 4, 5] and a y_true tensor that is [2, 5, 1, 3]. In usual MSE, we should get:

return K.mean(K.square(y_pred - y_true))

That does the following:

[1, 2, 4, 5] - [2, 5, 1, 3] = [-1, -3, 3, 2]

[-1, -3, 3, 2]² = [1, 9, 9, 4]

mean([1, 9, 9, 4]) = 5.75
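Checked with plain numpy (outside any Keras graph), the arithmetic above works out as follows:

```python
import numpy as np

y_pred = np.array([1, 2, 4, 5], dtype=float)
y_true = np.array([2, 5, 1, 3], dtype=float)

# element-wise difference, square, then mean
plain_mse = np.mean(np.square(y_pred - y_true))
print(plain_mse)  # 5.75
```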

I need my custom loss function to select the minimum between this mean and another one that swaps indexes 1 and 3 of the y_pred tensor, i.e.:

[1, 5, 4, 2] - [2, 5, 1, 3] = [-1, 0, 3, 1]

[-1, 0, 3, 1]² = [1, 0, 9, 1]

mean([1, 0, 9, 1]) = 2.75
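The swapped variant and the final minimum can be verified the same way; a small numpy sketch with the example values (numpy stands in for the backend here):

```python
import numpy as np

y_pred = np.array([1, 2, 4, 5], dtype=float)
y_true = np.array([2, 5, 1, 3], dtype=float)

# swap indexes 1 and 3 of y_pred: [1, 2, 4, 5] -> [1, 5, 4, 2]
swapped = y_pred[[0, 3, 2, 1]]

plain_mse = np.mean(np.square(y_pred - y_true))     # 5.75
swapped_mse = np.mean(np.square(swapped - y_true))  # 2.75

# the custom loss should return the smaller of the two means
print(min(plain_mse, swapped_mse))  # 2.75
```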

So, my custom loss would return 2.75, the minimum of the two means. To do this, I tried to convert the y_true and y_pred tensors into numpy arrays and do all the related math as follows:

import numpy as np
import tensorflow as tf

def new_mse(y_true, y_pred):
    sess = tf.Session()

    with sess.as_default():
        np_y_true = y_true.eval()
        np_y_pred = y_pred.eval()

        np_err_mse = np.empty(np_y_true.shape)
        np_err_mse = np.square(np_y_pred - np_y_true)

        np_err_new_mse = np.empty(np_y_true.shape)
        l0 = np.square(np_y_pred[:, 2] - np_y_true[:, 0])   
        l1 = np.square(np_y_pred[:, 3] - np_y_true[:, 1])
        l2 = np.square(np_y_pred[:, 0] - np_y_true[:, 2])
        l3 = np.square(np_y_pred[:, 1] - np_y_true[:, 3])   
        l4 = np.square(np_y_pred[:, 4] - np_y_true[:, 4])
        l5 = np.square(np_y_pred[:, 5] - np_y_true[:, 5])
        np_err_new_mse = np.transpose(np.vstack((l0, l1, l2, l3, l4, l5)))

        np_err_mse = np.mean(np_err_mse)
        np_err_new_mse = np.mean(np_err_new_mse)

        return np.amin([np_err_mse, np_err_new_mse])

Problem is that I can't use the eval() method with the y_true and y_pred tensors, and I'm not sure why. Finally, my questions are:

  1. Is there an easier way to handle indexes in tensors inside a loss function? I'm new to TensorFlow and Keras, and I'm fairly sure that converting everything to numpy arrays is not the optimal way at all.
  2. Not entirely related to the question, but when I tried to print the shape of the y_true tensor with K.shape(y_true), I got "Tensor("Shape_1:0", shape=(2,), dtype=int32)". That confused me, since I'm working with a y whose shape is (7032, 6), i.e. 7032 images with 6 labels each. There is probably some misunderstanding about the y and y_pred used by the loss function.

Answer

Often you work just with backend functions, and you never try to know the actual values of the tensors.

from keras import backend as K
from keras.losses import mean_squared_error

def new_mse(y_true, y_pred):

    # swapping elements 1 and 3 - concatenate slices of the original tensor
    swapped = K.concatenate([y_pred[:1], y_pred[3:4], y_pred[2:3], y_pred[1:2]])
    # actually, if the tensors are shaped like (batchSize, 4), use this:
    # swapped = K.concatenate([y_pred[:, :1], y_pred[:, 3:4], y_pred[:, 2:3], y_pred[:, 1:2]])

    # losses
    regularLoss = mean_squared_error(y_true, y_pred)
    swappedLoss = mean_squared_error(y_true, swapped)

    # concatenate them so we can take the minimum
    concat = K.concatenate([regularLoss, swappedLoss])

    # take the minimum
    return K.min(concat)
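The slice-and-concatenate swap can be sanity-checked outside the graph with plain numpy, which supports the same slicing (the sample values here are hypothetical):

```python
import numpy as np

# two hypothetical samples shaped (batchSize, 4)
y_pred = np.array([[1., 2., 4., 5.],
                   [0., 1., 2., 3.]])

# numpy mirror of
# K.concatenate([y_pred[:, :1], y_pred[:, 3:4], y_pred[:, 2:3], y_pred[:, 1:2]])
swapped = np.concatenate([y_pred[:, :1], y_pred[:, 3:4],
                          y_pred[:, 2:3], y_pred[:, 1:2]], axis=1)

print(swapped[0])  # [1. 5. 4. 2.] - columns 1 and 3 swapped
```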


So, for your items:

  1. You're totally right. Avoid numpy at all costs in tensor operations (loss functions, activations, custom layers, etc.)

  2. K.shape() also returns a tensor. It probably has shape (2,) because it holds two values: one will be 7032, the other 6. But you can only see those values when you evaluate the tensor, and doing that inside a loss function is usually a bad idea.
