Applying callbacks in a custom training loop in Tensorflow 2.0


Problem description

I'm writing a custom training loop using the code provided in the Tensorflow DCGAN implementation guide. I wanted to add callbacks in the training loop. In Keras I know we pass them as an argument to the 'fit' method, but can't find resources on how to use these callbacks in the custom training loop. I'm adding the code for the custom training loop from the Tensorflow documentation:

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Recommended answer

The simplest way would be to check whether the loss has changed over your expected window and, if not, break out of (or otherwise manipulate) the training process. Here is one way you could implement a custom early-stopping callback:

import numpy as np

def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
    # No early stopping for the first 2*patience epochs
    if len(LossList) // patience < 2:
        return False
    # Mean loss over the last `patience` epochs and the `patience` epochs before that
    mean_previous = np.mean(LossList[::-1][patience:2*patience])  # second-last window
    mean_recent = np.mean(LossList[::-1][:patience])              # last window
    # Relative change between the two windows; drop the division for an absolute criterion
    delta_abs = np.abs(mean_recent - mean_previous)
    delta_rel = delta_abs / np.abs(mean_previous)
    if delta_rel < min_delta:
        print("*CB_ES* Loss didn't change much over the last %d epochs" % patience)
        print("*CB_ES* Percent change in loss value: %.2f%%" % (delta_rel * 1e2))
        return True
    else:
        return False

This Callback_EarlyStopping checks your metric/loss every epoch and returns True if its relative change is smaller than min_delta, computed by comparing moving averages over the last two patience-sized windows of epochs. You can then capture this True signal and break the training loop. To answer your question fully, within your sample training loop you can use it as:

gen_loss_seq = []
for epoch in range(epochs):
  # in your example, make sure train_step returns gen_loss
  gen_loss = train_step(dataset)
  # ideally, run a validation_step as well and track gen_valid_loss
  gen_loss_seq.append(gen_loss)
  # check every 20 epochs and stop if the loss hasn't changed by at least 10%
  stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
  if stopEarly:
    print("Callback_EarlyStopping signal received at epoch = %d/%d" % (epoch + 1, epochs))
    print("Terminating training")
    break
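For the loop above to work, the question's train_step has to return gen_loss. A minimal, self-contained sketch of that pattern (the tiny regression model, data, and names below are placeholders, not the DCGAN from the question):

```python
import tensorflow as tf

# Sketch: a train_step that *returns* its loss so the outer loop can record it.
# The tiny regression model here stands in for the question's generator/discriminator pair.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss  # the key change: hand the loss back to the caller

x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])
loss_value = float(train_step(x, y))
print("loss:", loss_value)
```

Converting the returned tensor with float() (or .numpy()) keeps the loss history as plain Python numbers, which is what the moving-average helper above expects.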
       

Of course, you can increase the complexity in numerous ways, for example, which loss or metric you would like to track, your interest in the loss at a particular epoch versus a moving average of the loss, your interest in relative versus absolute change in value, etc. You can also refer to the Tensorflow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is generally used with the popular tf.keras.Model.fit method.
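Building on that, the built-in tf.keras.callbacks.EarlyStopping can also be driven by hand inside a custom loop by calling its on_train_begin/on_epoch_end hooks yourself, instead of re-implementing the logic. A hedged, self-contained sketch (the bare Sequential model and the hard-coded loss values are placeholders, not the DCGAN from the question):

```python
import tensorflow as tf

# Sketch: reusing the built-in EarlyStopping callback inside a custom loop.
model = tf.keras.Sequential()
model.stop_training = False  # flag that EarlyStopping flips when patience runs out

early_stop = tf.keras.callbacks.EarlyStopping(monitor="loss", min_delta=0.1, patience=3)
early_stop.set_model(model)

early_stop.on_train_begin()
fake_losses = [1.0, 0.5, 0.45, 0.44, 0.44, 0.44]  # improves, then plateaus
stopped_at = None
for epoch, loss in enumerate(fake_losses):
    # ... run your train_step over the dataset here and collect the epoch loss ...
    early_stop.on_epoch_end(epoch, logs={"loss": loss})
    if model.stop_training:
        stopped_at = epoch
        break
print("stopped at epoch:", stopped_at)
```

This way min_delta, patience, and mode handling come for free from Keras, at the cost of wiring the callback hooks into your loop yourself.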

