TensorFlow:如何根据纪元设置学习率衰减? [英] TensorFlow: How to set learning rate decay based on epochs?

查看:45
本文介绍了TensorFlow:如何根据纪元设置学习率衰减?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

学习率衰减函数tf.train.exponential_decay 接受一个decay_steps 参数.要每 num_epochs 降低学习率,您需要设置 decay_steps = num_epochs * num_train_examples/batch_size.但是,当从 .tfrecords 文件中读取数据时,您不知道其中有多少训练示例.

The learning rate decay function tf.train.exponential_decay takes a decay_steps parameter. To decrease the learning rate every num_epochs, you would set decay_steps = num_epochs * num_train_examples / batch_size. However, when reading data from .tfrecords files, you don't know how many training examples there are inside them.

要获得 num_train_examples,您可以:

  • 使用 num_epochs=1 设置 tf.string_input_producer.
  • 通过 tf.TFRecordReader/tf.parse_single_example 运行它.
  • 循环并计算它在停止前产生输出的次数.
  • Set up a tf.string_input_producer with num_epochs=1.
  • Run this through tf.TFRecordReader/tf.parse_single_example.
  • Loop and count how many times it produces some output before stopping.

然而,这不是很优雅.

是否有更简单的方法可以从 .tfrecords 文件中获取训练示例的数量,或者基于 epochs 而不是步骤设置学习率衰减?

Is there an easier way to either get the number of training examples from a .tfrecords file or set the learning rate decay based on epochs instead of steps?

推荐答案

您可以使用以下代码获取.tfrecords文件中的记录数:

You can use the following code to get the number of records in a .tfrecords file :

def get_num_records(tf_record_file):
  return len([x for x in tf.python_io.tf_record_iterator(tf_record_file)])

这篇关于TensorFlow:如何根据纪元设置学习率衰减?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆