在 Tensorflow 2.0 的 tf.function input_signature 中使用字典 [英] Use dictionary in tf.function input_signature in Tensorflow 2.0

查看:58
本文介绍了在 Tensorflow 2.0 的 tf.function input_signature 中使用字典的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用 Tensorflow 2.0 并面临以下情况:

I am using Tensorflow 2.0 and facing the following situation:

@tf.function
def my_fn(items):
    .... #do stuff
    return

如果 items 是张量的字典,例如:

If items is a dict of Tensors like for example:

item1 = tf.zeros([1, 1])
item2 = tf.zeros(1)
items = {"item1": item1, "item2": item2}

有没有办法使用 tf.function 的 input_signature 参数,以便当 item1 是例如 tf.zeros([2,1]) 时,我可以强制 tf2 避免创建多个图形?

Is there a way of using input_signature argument of tf.function so I can force tf2 to avoid creating multiple graphs when item1 is for example tf.zeros([2,1]) ?

推荐答案

输入签名必须是一个列表,但列表中的元素可以是字典或 Tensor Specs 列表.在你的情况下,我会尝试:(name 属性是可选的)

The input signature has to be a list, but elements in the list can be dictionaries or lists of Tensor Specs. In your case I would try: (the name attributes are optional)

signature_dict = { "item1": tf.TensorSpec(shape=[2], dtype=tf.int32, name="item1"),
                   "item2": tf.TensorSpec(shape=[], dtype=tf.int32, name="item2") } 
              

# don't forget the brackets around the 'signature_dict'
@tf.function(input_signature = [signature_dict])
def my_fn(items):
    .... # do stuff
    return

# calling the TensorFlow function
my_fun(items)

但是,如果要调用由 my_fn 创建的特定具体函数,则必须解压缩字典.您还必须在 tf.TensorSpec 中提供 name 属性.

However, if you want to call a particular concrete function created by my_fn, you have to unpack the dictionary. You also have to provide the name attribute in tf.TensorSpec.

# creating a concrete function with an input signature as before but without
# brackets and with mandatory 'name' attributes in the TensorSpecs 
my_concrete_fn = my_fn.get_concrete_function(signature_dict)
                                             
# calling the concrete function with the unpacking operator
my_concrete_fn(**items)

这很烦人,但应该在 TensorFlow 2.3 中解决.(参见具体函数"的 TF 指南末尾)

This is annoying but should be resolved in TensorFlow 2.3. (see the end of the TF Guide to 'Concrete functions')

这篇关于在 Tensorflow 2.0 的 tf.function input_signature 中使用字典的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆