Applying callbacks in a custom training loop in Tensorflow 2.0


Problem description


I'm writing a custom training loop using the code provided in the Tensorflow DCGAN implementation guide. I want to add callbacks to the training loop. In Keras, I know we pass them as an argument to the 'fit' method, but I can't find resources on how to use these callbacks in a custom training loop. I'm adding the code for the custom training loop from the Tensorflow documentation:

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Recommended answer


I've had this problem myself: (1) I want to use a custom training loop; (2) I don't want to lose the bells and whistles Keras gives me in terms of callbacks; (3) I don't want to re-implement them all myself. Tensorflow has a design philosophy of allowing developers to gradually opt in to its lower-level APIs. As @HyeonPhilYoun notes, the official documentation for tf.keras.callbacks.Callback gives an example of what we're looking for.
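For reference, that documented pattern boils down to instantiating a callback and invoking its lifecycle hooks yourself. A minimal sketch, assuming nothing beyond stock Tensorflow (the EpochLossLogger class and the dummy loss values are illustrative, not from the original post):

import tensorflow as tf

class EpochLossLogger(tf.keras.callbacks.Callback):
    """Toy callback: print the loss at the end of every epoch."""
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print("epoch {}: loss = {}".format(epoch, logs.get("loss")))

# In a custom loop you invoke the hooks yourself instead of relying on fit():
logger = EpochLossLogger()
logger.on_train_begin()
for epoch in range(3):
    logger.on_epoch_begin(epoch)
    # ... your train_step calls would go here ...
    logger.on_epoch_end(epoch, logs={"loss": 1.0 / (epoch + 1)})
logger.on_train_end()

The CallbackList approach below generalizes this to a whole list of callbacks at once.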


The following has worked for me, but it can be improved by reverse-engineering tf.keras.Model.


The trick is to use tf.keras.callbacks.CallbackList and then manually trigger its lifecycle events from within your custom training loop. This example uses tqdm to provide attractive progress bars, but CallbackList has an add_progbar initialization argument that lets you fall back on the default Keras progress bar. training_model is a typical instance of tf.keras.Model.

import tensorflow as tf
from tqdm.notebook import tqdm, trange

# Assumes training_model (a tf.keras.Model), max_epochs, the batches()
# helper, and the data x, Y, test_x, test_Y are defined elsewhere.

# Populate with typical keras callbacks
_callbacks = []

callbacks = tf.keras.callbacks.CallbackList(
    _callbacks, add_history=True, model=training_model)

logs = {}
callbacks.on_train_begin(logs=logs)

# Presentation
epochs = trange(
    max_epochs,
    desc="Epoch",
    unit="Epoch",
    postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}")
epochs.set_postfix(loss=0, accuracy=0)

# Get a stable test set so epoch results are comparable
test_batches = batches(test_x, test_Y)

for epoch in epochs:
    callbacks.on_epoch_begin(epoch, logs=logs)

    # I like to formulate new batches each epoch
    # if there are data augmentation methods in play
    training_batches = batches(x, Y)

    # Presentation
    enumerated_batches = tqdm(
        enumerate(training_batches),
        desc="Batch",
        unit="batch",
        postfix="loss = {loss:.4f}, accuracy = {accuracy:.4f}",
        position=1,
        leave=False)

    for (batch, (x, y)) in enumerated_batches:
        training_model.reset_states()
        
        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_train_batch_begin(batch, logs=logs)
        
        logs = training_model.train_on_batch(x=x, y=y, return_dict=True)

        callbacks.on_train_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)

        # Presentation
        enumerated_batches.set_postfix(
            loss=float(logs["loss"]),
            accuracy=float(logs["accuracy"]))

    for (batch, (x, y)) in enumerate(test_batches):
        training_model.reset_states()

        callbacks.on_batch_begin(batch, logs=logs)
        callbacks.on_test_batch_begin(batch, logs=logs)

        logs = training_model.test_on_batch(x=x, y=y, return_dict=True)

        callbacks.on_test_batch_end(batch, logs=logs)
        callbacks.on_batch_end(batch, logs=logs)

    # Presentation
    epochs.set_postfix(
        loss=float(logs["loss"]),
        accuracy=float(logs["accuracy"]))

    callbacks.on_epoch_end(epoch, logs=logs)

    # NOTE: This is a decent place to check on early stopping:
    # tf.keras.callbacks.EarlyStopping sets model.stop_training
    # when its patience runs out.
    if training_model.stop_training:
        break


callbacks.on_train_end(logs=logs)

# Fetch the history object we normally get from keras.fit
history_object = None
for cb in callbacks:
    if isinstance(cb, tf.keras.callbacks.History):
        history_object = cb
assert history_object is not None
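To actually get the Keras bells and whistles, populate _callbacks before the loop runs. A sketch of a typical configuration; the paths and patience value here are illustrative, and the monitored metric must match a key in the logs dicts produced above ("loss" or "accuracy"):

_callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="loss", patience=3),
    tf.keras.callbacks.ModelCheckpoint(
        filepath="checkpoints/model_{epoch:02d}.h5",
        monitor="loss",
        save_best_only=True),
    tf.keras.callbacks.TensorBoard(log_dir="logs"),
]

Because the CallbackList was built with add_history=True, the History callback accumulates the per-epoch logs, so history_object.history gives you the same metrics dict you would normally read off the return value of keras' fit.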

