TensorFlow 2.0 dataset.__iter__() 仅在启用 Eager Execution 时才受支持 [英] TensorFlow 2.0 dataset.__iter__() is only supported when eager execution is enabled

查看:35
本文介绍了TensorFlow 2.0 dataset.__iter__() 仅在启用 Eager Execution 时才受支持的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在 TensorFlow 2 中使用以下自定义训练代码:

I'm using the following custom training code in TensorFlow 2:

def parse_function(filename, filename2):
    image = read_image(fn)
    def ret1(): return image, read_image(fn2), 0
    def ret2(): return image, preprocess(image), 1
    return tf.case({tf.less(tf.random.uniform([1])[0], tf.constant(0.5)): ret2}, default=ret1)

dataset = tf.data.Dataset.from_tensor_slices((train,shuffled_train))
dataset = dataset.shuffle(len(train))
dataset = dataset.map(parse_function, num_parallel_calls=4)
dataset = dataset.batch(1)
dataset = dataset.prefetch(buffer_size=4)

@tf.function
def train(model, dataset, optimizer):
    for x1, x2, y in enumerate(dataset):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2])
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

siamese_net.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-3))
train(siamese_net, dataset, tf.keras.optimizers.RMSprop(learning_rate=1e-3))

这段代码给了我错误:

dataset.__iter__() is only supported when eager execution is enabled.

但是,它在 TensorFlow 2.0 中,因此默认情况下启用了eager.tf.executing_eagerly() 也返回真".

However, it's in TensorFlow 2.0 so eager is enabled by default. tf.executing_eagerly() also returns 'True'.

推荐答案

我通过将 train 函数更改为以下内容来解决此问题:

I fixed this by changing the train function to the following:

def train(model, dataset, optimizer):
    for step, (x1, x2, y) in enumerate(dataset):
        with tf.GradientTape() as tape:
            left, right = model([x1, x2])
            loss = contrastive_loss(left, right, tf.cast(y, tf.float32))
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

两个变化是删除@tf.function 和修复枚举.

The two changes are removing the @tf.function and fixing the enumeration.

这篇关于TensorFlow 2.0 dataset.__iter__() 仅在启用 Eager Execution 时才受支持的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆