TensorFlow 2 tf.function 装饰器 [英] TensorFlow 2 tf.function decorator

查看:54
本文介绍了TensorFlow 2 tf.function 装饰器的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有 TensorFlow 2.0 和 Python 3.7.5.

I have TensorFlow 2.0 and Python 3.7.5.

我编写了以下代码来执行小批量梯度下降,即:

I have written the following code for performing mini-batch gradient descent which is:

@tf.function
def train_one_step(model, mask_model, optimizer, x, y):
    '''
    Function to compute one step of gradient descent optimization
    '''
    with tf.GradientTape() as tape:
        # Make predictions using defined model-
        y_pred = model(x)

        # Compute loss-
        loss = loss_fn(y, y_pred)

    # Compute gradients wrt defined loss and weights and biases-
    grads = tape.gradient(loss, model.trainable_variables)

    # type(grads)
    # list

    # List to hold element-wise multiplication between-
    # computed gradient and masks-
    grad_mask_mul = []

    # Perform element-wise multiplication between computed gradients and masks-
    for grad_layer, mask in zip(grads, mask_model.trainable_weights):
        grad_mask_mul.append(tf.math.multiply(grad_layer, mask))

    # Apply computed gradients to model's weights and biases-
    optimizer.apply_gradients(zip(grad_mask_mul, model.trainable_variables))

    # Compute accuracy-
    train_loss(loss)
    train_accuracy(y, y_pred)

    return None

在代码中,mask_model"是一个掩码,它要么是 0,要么是 1.mask_model"的用途是控制训练哪些参数(因为,0 *梯度下降 = 0).

In the code, "mask_model" is a mask which is either 0 or 1. The use of "mask_model" is to control which parameters are trained (since, 0 * gradient descent = 0).

我的问题是,我在train_one_step()"TensorFlow 装饰函数中使用grad_mask_mul"列表变量.这是否会导致任何问题,例如:

My question is, I am using "grad_mask_mul" list variable inside "train_one_step()" TensorFlow decorated function. Can this cause any problems, such as:

ValueError: tf.function-decorated 函数试图创建变量非首次通话.

ValueError: tf.function-decorated function tried to create variables on non-first call.

或者你们是否看到在 tensorflow 装饰函数中使用列表变量的一些问题?

Or do you guys see some problem of using a list variable inside a tensorflow decorated function?

谢谢!

推荐答案

这是 TensorFlow 2 中的一个错误.您可以在此处阅读有关它的更多信息 TF2 错误

this is a bug in TensorFlow 2. You can read more about it here TF2 bug

这篇关于TensorFlow 2 tf.function 装饰器的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆