循环张量并将函数应用于每个元素 [英] Loop over a tensor and apply function to each element

查看:36
本文介绍了循环张量并将函数应用于每个元素的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想遍历一个包含 Int 列表的张量,并对每个元素应用一个函数.在函数中,每个元素都将从 python 的字典中获取值.我已经尝试过使用 tf.map_fn 的简单方法,它适用于 add 函数,例如以下代码:

I want to loop over a tensor which contains a list of Int, and apply a function to each of the elements. In the function every element will get the value from a dict of python. I have tried the easy way with tf.map_fn, which will work on add function, such as the following code:

import tensorflow as tf

def trans_1(x):
    return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]

但是下面的代码抛出了 KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32 异常:

But the following code throw the KeyError: tf.Tensor'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32 exception:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

我的 tensorflow 版本是 1.13.1.先谢谢了.

My tensorflow version is 1.13.1. Thanks ahead.

推荐答案

有一个简单的方法可以实现,您正在尝试什么.

There is a simple way to achieve, what you are trying.

问题是传递给map_fn的函数必须有张量作为参数,张量作为返回值.但是,您的函数 trans_2 将普通 python int 作为参数并返回另一个 python int.这就是为什么您的代码不起作用的原因.

The problem is that the function passed to map_fn must have tensors as its parameters and tensor as the return value. However, your function trans_2 takes plain python int as parameter and returns another python int. That's why your code doesn't work.

但是,TensorFlow 提供了一种简单的方法来包装普通的 Python 函数,即 tf.py_func,您可以在您的情况下使用它,如下所示:

However, TensorFlow provides a simple way to wrap ordinary python functions, which is tf.py_func, you can use it in your case as follows:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

你可以看到我添加了一个包装函数,它需要张量参数并返回一个张量,这就是它可以在 map_fn 中使用的原因.使用强制转换是因为 Python 默认使用 64 位整数,而 TensorFlow 使用 32 位整数.

you can see I have added a wrapper function, which expects tensor parameter and returns a tensor, that's why it can be used in map_fn. The cast is used because python by default uses 64-bit integers, whereas TensorFlow uses 32-bit integers.

这篇关于循环张量并将函数应用于每个元素的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆