How To Create A Loss Function with MSE that Uses tf.where() to Ignore Certain Elements

Problem description

Here is the current function. It removes from the MSE any values where y_true is less than a threshold (here, 0.1).

def my_loss(y_true,y_pred):
    loss = tf.square(y_true-y_pred)
    # if any y_true is less than a threshold (say 0.1) 
    # the element is removed from loss, and does not affect MSE
    loss = tf.where(y_true<0.1)
    # return mean of losses
    return tf.reduce_mean(loss)
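For reference, here is a small sketch (toy values, not taken from the question) of what the single-argument form of tf.where used above actually returns. It gives the coordinates of the True elements as an int64 tensor, not the loss values that satisfy the condition. Selecting the values themselves needs something like tf.boolean_mask.

```python
import tensorflow as tf

# Toy labels chosen for illustration only.
y_true = tf.constant([0.05, 0.3, 0.02, 0.7])

# Single-argument tf.where returns the coordinates of the True elements
# (shape [num_true, rank], dtype int64), not the selected values.
idx = tf.where(y_true < 0.1)
print(idx.numpy().tolist())  # -> [[0], [2]]
print(idx.dtype)

# To select values instead of indices, gather them, e.g. with tf.boolean_mask.
kept = tf.boolean_mask(y_true, y_true >= 0.1)
print(kept.numpy())
```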

This one compiles, but the network never learns to predict 0 well. Instead, I would like to eliminate only the values where both y_true and y_pred are below some threshold. The reason is that the network first needs to learn how to predict 0 before those points can be ignored later in training.

However, this version does not compile:

def my_better_loss(y_true,y_pred):
    loss = tf.square(y_true-y_pred)
    # remove all elements where BOTH y_true & y_pred < threshold
    loss = tf.where(y_true<0.1 and y_pred<0.1)
    # return mean of losses
    return tf.reduce_mean(loss)

It fails with the following error:

  (0) Invalid argument:  The second input must be a scalar, but it has shape [25,60,60]
         [[{{node replica_1/customMSE/cond/switch_pred/_51}}]]
  (1) Invalid argument:  The second input must be a scalar, but it has shape [25,60,60]
         [[{{node replica_1/customMSE/cond/switch_pred/_51}}]]
         [[customMSE/cond/Squeeze/_59]]
  (2) Invalid argument:  The second input must be a scalar, but it has shape [25,60,60]
         [[{{node replica_1/customMSE/cond/replica_1/customMSE/Less/_55}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_4715]

Function call stack:
train_function -> train_function -> train_function

To be more specific, say the threshold is 0.5:

y_true = [0.3, 0.4, 0.6, 0.7]
y_pred = [0.2, 0.7, 0.5, 1]

The loss function would then compute the MSE with the first element removed, since both y_true[0] and y_pred[0] are below the threshold:

# MSE would be computed between
y_true = [0.4, 0.6, 0.7]
#and
y_pred = [0.7, 0.5, 1]
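As a quick check of the arithmetic, the intended result for this example is (0.09 + 0.01 + 0.09) / 3 ≈ 0.0633. The plain-Python sketch below (no TensorFlow, purely to verify the numbers) reproduces it. Python's "and" is fine here only because t and p are plain floats; the problem in the question arises with tensors.

```python
# Plain-Python check of the intended masked MSE on the example above.
y_true = [0.3, 0.4, 0.6, 0.7]
y_pred = [0.2, 0.7, 0.5, 1.0]
threshold = 0.5

# Keep a pair unless BOTH values are below the threshold.
kept = [(t, p) for t, p in zip(y_true, y_pred)
        if not (t < threshold and p < threshold)]

# Mean of squared errors over the kept pairs only.
mse = sum((t - p) ** 2 for t, p in kept) / len(kept)
print(len(kept), mse)
```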

Recommended answer

Using Python's short-circuit "and" operator on tensors in code that is converted to graph mode usually leads to errors or undesirable behaviour, because "and" cannot be overloaded. Here, AutoGraph appears to have turned y_true<0.1 and y_pred<0.1 into a conditional (note the cond/switch_pred nodes in the trace), and a conditional's predicate must be a scalar, hence "The second input must be a scalar, but it has shape [25,60,60]". For an element-wise AND on tensors, use tf.math.logical_and.
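A minimal sketch of the element-wise alternative, using the example values and the 0.5 threshold from the question. tf.math.logical_and yields one boolean per element with the same shape as its inputs, which is what a mask needs. The & operator is overloaded for boolean tensors and gives the same result.

```python
import tensorflow as tf

# Example values from the question, threshold 0.5.
y_true = tf.constant([0.3, 0.4, 0.6, 0.7])
y_pred = tf.constant([0.2, 0.7, 0.5, 1.0])

# Element-wise AND: one boolean per element, same shape as the inputs.
both_below = tf.math.logical_and(y_true < 0.5, y_pred < 0.5)
print(both_below.numpy().tolist())  # -> [True, False, False, False]

# The overloaded & operator on boolean tensors is equivalent.
same = (y_true < 0.5) & (y_pred < 0.5)
```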

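Besides, tf.where is not necessary here and is likely to be slower; masking is preferred. Keeping an element unless both values are below the threshold is, by De Morgan's law, the same as keeping it when either value is at or above the threshold, which is why the mask uses tf.logical_or. Example code: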

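import tensorflow as tf

@tf.function
def better_loss(y_true, y_pred):
  loss = tf.square(y_true - y_pred)
  # ignore elements where BOTH y_true & y_pred < 0.1,
  # i.e. keep elements where EITHER value is >= 0.1
  mask = tf.cast(tf.logical_or(y_true >= 0.1, y_pred >= 0.1), loss.dtype)
  loss *= mask
  # mean over the kept elements only; divide_no_nan returns 0
  # instead of NaN if every element in the batch is masked out
  return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))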
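If the threshold needs to vary (the question's example uses 0.5, while the answer hard-codes 0.1), one option is to build the masked loss in a closure. This is a sketch, not part of the original answer. make_masked_mse is a hypothetical helper name, and y_true is cast to y_pred's dtype on the assumption that labels may arrive in a different float dtype. tf.math.divide_no_nan is used so that a batch in which every element is masked gives 0 rather than NaN.

```python
import tensorflow as tf

def make_masked_mse(threshold=0.1):
    """Hypothetical helper: returns an MSE loss that ignores elements
    where BOTH y_true and y_pred are below `threshold`."""
    def masked_mse(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        loss = tf.square(y_true - y_pred)
        # Keep an element if EITHER value reaches the threshold
        # (equivalent to NOT(both below), by De Morgan's law).
        mask = tf.cast(
            tf.logical_or(y_true >= threshold, y_pred >= threshold), loss.dtype)
        # Mean over kept elements; 0 instead of NaN if nothing is kept.
        return tf.math.divide_no_nan(tf.reduce_sum(loss * mask),
                                     tf.reduce_sum(mask))
    return masked_mse

# The question's example with threshold 0.5: expected value is about 0.19 / 3.
loss_fn = make_masked_mse(threshold=0.5)
value = loss_fn(tf.constant([0.3, 0.4, 0.6, 0.7]),
                tf.constant([0.2, 0.7, 0.5, 1.0]))
print(float(value))
```

With Keras, this could be passed as model.compile(optimizer="adam", loss=make_masked_mse(0.1)), assuming model is an existing tf.keras model. Because the mean is taken over all kept elements in the batch, a batch with few kept elements gives each of them more weight.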
