TensorFlow 2 如何在 tf.function 中使用 *args? [英] TensorFlow 2 How to use *args in tf.function?

查看:48
本文介绍了TensorFlow 2 如何在 tf.function 中使用 *args?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

进行了更多测试,但我无法通过以下方式重现该行为:

Did a bit more testing and I can't reproduce the behaviour with:

import tensorflow as tf
import numpy as np

@tf.function
def tf_being_unpythonic(an_input, another_input):
    return an_input + another_input

@tf.function
def example(*inputs, other_args = True):
    return tf_being_unpythonic(*inputs)

class TestClass(tf.keras.Model):
    def __init__(self, a, b):
        super().__init__()
        self.a= a
        self.b = b

    @tf.function
    def call(self, *inps, some_kwarg=False):
        if some_kwarg:
            return self.a(*inps)
        return self.b(*inps)

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.inps = tf.keras.layers.Flatten()
        self.hl1 = tf.keras.layers.Dense(5)
        self.hl2 = tf.keras.layers.Dense(4)
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self,observation):
        x = self.inps(observation)
        x = self.hl1(x)
        x = self.hl2(x)
        return self.out(x)


class Model2(Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()

    @tf.function
    def call(self,b,c):
        x = self.prein([b,c])
        return super().call(x)   

am = Model()
pm = Model2()
test = TestClass(am,pm)

a = np.random.normal(size=(1,2,3))
b = np.random.normal(size=(1,2,4))

test(a,some_kwarg=True)
test(a,b) 

所以这可能是其他地方的错误.

So it's probably a bug somewhere else.

@tf.function
def call(self, *inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    if target:
        return self._target_network(*inp, training)
    return self._network(*inp, training)

我明白了:

ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []

但是 print(inp) 给出:

But print(inp) gives:

(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,) 

我已经编辑过并且只是未提交的玩具代码,因此无法进一步调查.将问题留在这里,以便没有遇到此问题的每个人都没有可阅读的内容.

I've since edited and was just uncommited toy code so can't investigate further. Will leave the question here so that everyone who doesn't get this issue won't have something to read.

推荐答案

我不认为使用 *args 构造是 tf.function 的好习惯>.如您所见,大多数接受可变数量输入的 TF 函数都使用元组.

I don't think that using a *args construct is a good practice for a tf.function. As you can see, most of the TF functions accepting a variable number of inputs use a tuple.

因此,您可以将函数签名重写为:

So, you can rewrite your function signature as:

def call(self, inputs, target=False, training=False)

并调用它:

instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])

<小时>

编辑

顺便说一句,在使用 tf.function*args 构造时,我没有看到任何错误:


Edit

By the way, I don't see any error while using tf.function with a *args construct:

import tensorflow as tf

@tf.function
def call(*inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    return inp[0]

def main():
    print(call(1))
    print(call(2, 2))
    print(call(3, 3, 3))


if __name__ == '__main__':
    main()

tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

因此,您应该向我们提供有关您尝试执行的操作以及错误所在的更多信息.

So you should provide us more informations about what you try to do and where the error is.

这篇关于TensorFlow 2 如何在 tf.function 中使用 *args?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆