如何有条件地将值分配给张量[蒙版损失函数]? [英] How to conditionally assign values to tensor [masking for loss function]?

查看:131
本文介绍了如何有条件地将值分配给张量[蒙版损失函数]?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想创建一个L2损失函数,忽略标签值为0的值(=>像素).张量batch[1]包含标签,而output是净输出的张量,两者都有一个(None,300,300,1)的形状.

I want to create a L2 loss function that ignores values (=> pixels) where the label has the value 0. The tensor batch[1] contains the labels while output is a tensor for the net output, both have a shape of (None,300,300,1).

labels_mask = tf.identity(batch[1])
labels_mask[labels_mask > 0] = 1
loss = tf.reduce_sum(tf.square((output-batch[1])*labels_mask))/tf.reduce_sum(labels_mask)

我当前的代码显示为TypeError: 'Tensor' object does not support item assignment(在第二行).这样做的张量流是什么?我还尝试使用tf.reduce_sum(labels_mask)归一化损失,希望如此.

My current code yields to TypeError: 'Tensor' object does not support item assignment (on the second line). What's the tensorflow-way to do this? I also tried to normalize the loss with tf.reduce_sum(labels_mask), which I hope works like this.

推荐答案

如果要这样编写,则必须使用Tensorflow的scatter方法进行分配.不幸的是,tensorflow也不真正支持布尔索引(新的boolean_select使其成为可能,但很烦人).写起来很棘手,很难看.

If you wanted to write it that way, you would have to use Tensorflow's scatter method for assignment. Unfortunately, tensorflow doesn't really support boolean indexing either (the new boolean_select makes it possible, but annoying). It would be tricky to write and difficult to read.

您有两个烦人的选择:

  1. labels_mask > 0用作布尔掩码,并使用Tensorflow的最新布尔掩码功能.也许这是更张量流的方式,因为它调用了任意特定的函数.
  2. 投射labels_mask > 0以浮动:tf.cast(labels_mask > 0, tf.float32).然后,您可以在代码的最后一行中以所需的方式使用它.
  1. Use labels_mask > 0 as a boolean mask and use Tensorflow's recent boolean_mask function. Maybe this is the more tensorflow way, because it invokes arbitrarily specific functions.
  2. Cast labels_mask > 0 to float: tf.cast(labels_mask > 0, tf.float32). Then, you can use it the way you wanted to in the final line of your code.

这篇关于如何有条件地将值分配给张量[蒙版损失函数]?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆