未启用急切执行时,张量对象不可迭代.要迭代这个张量使用 tf.map_fn [英] Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn

查看:39
本文介绍了未启用急切执行时,张量对象不可迭代.要迭代这个张量使用 tf.map_fn的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试创建自己的损失函数:

I am trying to create my own loss function:

def custom_mse(y_true, y_pred):
    tmp = 10000000000
    a = list(itertools.permutations(y_pred))
    for i in range(0, len(a)): 
     t = K.mean(K.square(a[i] - y_true), axis=-1)
     if t < tmp :
        tmp = t
     return tmp

它应该创建预测向量的排列,并返回最小的损失.

It should create permutations of predicted vector, and return the smallest loss.

   "Tensor objects are not iterable when eager execution is not "
TypeError: Tensor objects are not iterable when eager execution is not enabled. To iterate over this tensor use tf.map_fn.

错误.我找不到此错误的任何来源.为什么会发生这种情况?

error. I fail to find any source for this error. Why is this happening?

推荐答案

发生错误是因为 y_pred 是一个张量(非急切执行不可迭代),并且 itertools.permutations 需要一个迭代来创建排列.此外,您计算最小损失的部分也不起作用,因为张量 t 的值在图创建时是未知的.

The error is happening because y_pred is a tensor (non iterable without eager execution), and itertools.permutations expects an iterable to create the permutations from. In addition, the part where you compute the minimum loss would not work either, because the values of tensor t are unknown at graph creation time.

我会创建索引的排列而不是排列张量(这是您可以在创建图形时执行的操作),然后从张量中收集排列的索引.假设你的 Keras 后端是 TensorFlow 并且 y_true/y_pred 是二维的,你的损失函数可以如下实现:

Instead of permuting the tensor, I would create permutations of the indices (this is something you can do at graph creation time), and then gather the permuted indices from the tensor. Assuming that your Keras backend is TensorFlow and that y_true/y_pred are 2-dimensional, your loss function could be implemented as follows:

def custom_mse(y_true, y_pred):
    batch_size, n_elems = y_pred.get_shape()
    idxs = list(itertools.permutations(range(n_elems)))
    permutations = tf.gather(y_pred, idxs, axis=-1)  # Shape=(batch_size, n_permutations, n_elems)
    mse = K.square(permutations - y_true[:, None, :])  # Shape=(batch_size, n_permutations, n_elems)
    mean_mse = K.mean(mse, axis=-1)  # Shape=(batch_size, n_permutations)
    min_mse = K.min(mean_mse, axis=-1)  # Shape=(batch_size,)
    return min_mse

这篇关于未启用急切执行时,张量对象不可迭代.要迭代这个张量使用 tf.map_fn的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆