Applying callbacks in a custom training loop in TensorFlow 2.0
Question
I'm writing a custom training loop based on the code in the TensorFlow DCGAN implementation guide, and I'd like to add callbacks to it. In Keras, I know callbacks are passed as an argument to the 'fit' method, but I can't find any resources on how to use them in a custom training loop. Here is the custom training loop from the TensorFlow documentation:
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator,
                                 epoch + 1,
                                 seed)

        # Save the model every 15 epochs
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epochs,
                             seed)
Recommended answer
The simplest way is to check whether the loss has changed over a window of epochs that you choose, and to break out of, or otherwise adjust, the training process if it has not. Here is one way you could implement a custom early stopping callback:
import numpy as np

def Callback_EarlyStopping(LossList, min_delta=0.1, patience=20):
    # No early stopping until at least 2*patience epochs have been recorded
    if len(LossList) // patience < 2:
        return False
    # Mean loss over the last `patience` epochs and over the `patience` epochs before them
    mean_previous = np.mean(LossList[::-1][patience:2*patience])  # second-last window
    mean_recent = np.mean(LossList[::-1][:patience])  # last window
    # You can use the relative or the absolute change; the relative change is used here
    delta_abs = np.abs(mean_recent - mean_previous)  # absolute change
    delta_rel = np.abs(delta_abs / mean_previous)  # relative change
    if delta_rel < min_delta:
        print("*CB_ES* Loss didn't change much over the last %d epochs" % patience)
        print("*CB_ES* Percent change in loss value:", delta_rel * 1e2)
        return True
    else:
        return False
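To see how this kind of two-window comparison behaves, here is a small self-contained sketch that runs it on a made-up loss curve. It uses only the standard library; the helper name `loss_plateaued`, the synthetic losses, and the guard against a zero mean are my assumptions for illustration, not part of the function above.

```python
from statistics import mean


def loss_plateaued(losses, min_delta=0.1, patience=20):
    """Return True when the mean of the last `patience` losses differs from the
    mean of the `patience` losses before them by less than `min_delta`
    (measured as a relative change)."""
    if len(losses) < 2 * patience:
        return False  # not enough history for two full windows yet
    mean_previous = mean(losses[-2 * patience:-patience])
    mean_recent = mean(losses[-patience:])
    # Assumption: guard against dividing by zero when the previous mean is 0
    denominator = max(abs(mean_previous), 1e-12)
    return abs(mean_recent - mean_previous) / denominator < min_delta


# Synthetic loss curve: halves every epoch for 10 epochs, then stays flat
losses = [1.0 / 2 ** epoch for epoch in range(10)] + [0.001] * 10

history = []
stop_epoch = None
for epoch, loss in enumerate(losses):
    history.append(loss)
    if loss_plateaued(history, min_delta=0.1, patience=3):
        stop_epoch = epoch
        break

print("stopped at epoch index:", stop_epoch)
```

With a small `patience` the check fires only once both windows sit on the flat part of the curve, which is why a larger `patience` delays stopping.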
Callback_EarlyStopping checks your metric/loss every epoch. It compares the mean loss over the last patience epochs with the mean over the patience epochs before them, and returns True if the relative change between the two is less than min_delta. You can then capture this True signal and break the training loop. To completely answer your question, within your sample training loop you can use it as:
gen_loss_seq = []
for epoch in range(epochs):
    # In your example, make sure train_step ends with `return gen_loss`
    batch_losses = []
    for image_batch in dataset:
        batch_losses.append(float(train_step(image_batch)))
    # Ideally, you would have a validation_step and track gen_valid_loss instead
    gen_loss_seq.append(np.mean(batch_losses))

    # Checked every epoch: stop if the mean loss over the last 20 epochs differs
    # from the mean over the 20 epochs before them by less than 10%
    stopEarly = Callback_EarlyStopping(gen_loss_seq, min_delta=0.1, patience=20)
    if stopEarly:
        print("Callback_EarlyStopping signal received at epoch %d/%d" % (epoch + 1, epochs))
        print("Terminating training")
        break
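If you would rather have something closer to the hook style of Keras callbacks, one option is to give each callback an `on_epoch_end` method and let the loop check a shared stop flag. The following is only a minimal sketch of that pattern: the classes `Callback`, `StopBelowThreshold` and `LossRecorder`, the `state` dictionary, and `fake_train_epoch` (a stand-in for a real epoch of `train_step` calls that returns made-up losses) are all hypothetical names, not TensorFlow APIs.

```python
class Callback:
    """Hypothetical base class: subclasses override the hooks they need."""

    def on_epoch_end(self, epoch, logs, state):
        pass


class StopBelowThreshold(Callback):
    """Requests a stop once the monitored value drops below a threshold."""

    def __init__(self, monitor, threshold):
        self.monitor = monitor
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs, state):
        if logs[self.monitor] < self.threshold:
            state["stop_training"] = True


class LossRecorder(Callback):
    """Keeps a copy of the logs from every epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, logs, state):
        self.history.append(dict(logs))


def fake_train_epoch(epoch):
    """Stand-in for one epoch of train_step calls; the losses are made up."""
    return {"gen_loss": 1.0 / (epoch + 1), "disc_loss": 0.5}


def train(epochs, callbacks):
    """Run the loop and return the index of the last epoch that was executed."""
    state = {"stop_training": False}
    for epoch in range(epochs):
        logs = fake_train_epoch(epoch)
        for callback in callbacks:
            callback.on_epoch_end(epoch, logs, state)
        if state["stop_training"]:
            return epoch
    return epochs - 1


recorder = LossRecorder()
last_epoch = train(epochs=50, callbacks=[recorder, StopBelowThreshold("gen_loss", 0.15)])
print("last epoch index:", last_epoch, "epochs recorded:", len(recorder.history))
```

Keeping the stop decision in a flag rather than a return value lets several callbacks run on the same epoch before the loop exits.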
Of course, you can increase the complexity in numerous ways: which loss or metric you track, whether you look at the loss at a particular epoch or at a moving average of it, whether you care about the relative or the absolute change in value, and so on. For reference, see the TensorFlow 2.x implementation of tf.keras.callbacks.EarlyStopping, which is generally used with the popular tf.keras.Model.fit method.
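For comparison, the patience scheme usually associated with tf.keras.callbacks.EarlyStopping tracks the best value seen so far instead of comparing two windows. Below is a rough pure-Python sketch of that idea. The class `PatienceStopper` and its `update` method are hypothetical names of mine, and this is not the Keras source, just the general "no improvement for `patience` epochs" rule.

```python
class PatienceStopper:
    """Stops when the loss has not improved by more than `min_delta`
    for `patience` consecutive epochs."""

    def __init__(self, min_delta=0.0, patience=3):
        self.min_delta = min_delta
        self.patience = patience
        self.best = float("inf")
        self.wait = 0

    def update(self, loss):
        """Record one epoch's loss; return True when training should stop."""
        if loss < self.best - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best = loss
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


stopper = PatienceStopper(min_delta=0.01, patience=3)
losses = [0.9, 0.7, 0.6, 0.61, 0.605, 0.6, 0.2]  # made-up values

stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper.update(loss):
        stopped_at = epoch
        break

print("stopped at epoch index:", stopped_at, "best loss:", stopper.best)
```

Note that the loop stops before reaching the final, much lower loss, which is the usual trade-off of a short `patience`.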