Speeding up Tensorflow 2.0 Gradient Tape


Problem description

I have been following the TF 2.0 tutorial for convolution VAE's, located here.

Since it is eager, the gradients are computed by hand and then applied manually, using tf.GradientTape().

for epoch in range(epochs):
  for x in x_train:
    with tf.GradientTape() as tape:   # record operations for autodiff
      loss = compute_loss(model, x)
    # differentiate the loss w.r.t. the weights and apply the update
    apply_gradients(tape.gradient(loss, model.trainable_variables))
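For context, the helper functions the loop calls could look roughly like this. This is a minimal sketch, not the tutorial's exact code: the model, optimizer, and mean-squared-error loss are placeholder assumptions (the tutorial computes the VAE's ELBO instead):

```python
import tensorflow as tf

# Placeholder model and optimizer -- the tutorial builds a CVAE instead.
model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(8,))])
optimizer = tf.keras.optimizers.Adam(1e-4)

def compute_loss(model, x):
    # Stand-in loss: mean squared reconstruction error.
    reconstruction = model(x)
    return tf.reduce_mean(tf.square(x - reconstruction))

def apply_gradients(gradients):
    # Pair each gradient with its variable and let the optimizer update them.
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
```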

The problem with that code is that it is pretty slow, taking around 40-50 seconds per epoch. If I increase the batch size by a lot (to around 2048), then it ends up taking around 8 seconds per epoch, but the model's performance decreases by quite a lot.

On the other hand, if I do a more traditional model (i.e., that uses the lazy graph-based model instead of eagerness), such as the one here, then it takes 8 seconds per epoch even with a small batch size.

model.add_loss(lazy_graph_loss)
model.fit(x_train, epochs=epochs)
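A self-contained sketch of that graph-based setup might look like the following. The tiny dense model and squared-error loss are placeholders standing in for the linked tutorial's VAE; `add_loss` attaches the symbolic loss tensor to the model so `fit()` can train without a separate labels argument:

```python
import tensorflow as tf

# Placeholder autoencoder-style model; the linked tutorial uses a VAE instead.
inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.Dense(8)(inputs)
model = tf.keras.Model(inputs, outputs)

# Attach the loss to the graph; fit() will minimize it directly.
lazy_graph_loss = tf.reduce_mean(tf.square(inputs - outputs))
model.add_loss(lazy_graph_loss)
model.compile(optimizer="adam")
```

With the loss attached, training is just `model.fit(x_train, epochs=epochs)`, and Keras runs the compiled graph rather than per-op eager execution.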

Based on this information, my guess would be that the problem with the TF2.0 code is the manual computation of losses and gradients.

Is there any way to speed up the TF2.0 code so that it comes closer to the normal code?

Accepted answer

I found the solution: TensorFlow 2.0 introduces the concept of functions, which translate eager code into graph code.

The usage is pretty straightforward. The only change needed is that all relevant functions (like compute_loss and apply_gradients) have to be annotated with @tf.function.
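Applied to the loop above, the change might look like the following sketch. The model, optimizer, and squared-error loss are again placeholder assumptions, not the tutorial's exact code; the point is only where the decorator goes:

```python
import tensorflow as tf

# Placeholder model and optimizer, standing in for the tutorial's CVAE.
model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(8,))])
optimizer = tf.keras.optimizers.Adam(1e-4)

@tf.function  # traces the eager Python below into a graph on first call
def train_step(x):
    with tf.GradientTape() as tape:
        # Stand-in loss; the tutorial computes the VAE's ELBO here.
        loss = tf.reduce_mean(tf.square(x - model(x)))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```

Note that the first call to `train_step` pays a one-time tracing cost; subsequent calls with the same input shapes and dtypes reuse the compiled graph, which is where the speedup comes from.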

