TensorFlow: How to set learning rate decay based on epochs?
Question
The learning rate decay function tf.train.exponential_decay takes a decay_steps parameter. To decrease the learning rate every num_epochs epochs, you would set decay_steps = num_epochs * num_train_examples / batch_size. However, when reading data from .tfrecords files, you don't know how many training examples they contain.
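For concreteness, here is a minimal sketch of that decay_steps calculation; the dataset size, batch size, and epoch count below are made-up illustration values, not anything from the question:

```python
# Hypothetical illustration values -- substitute your own dataset figures.
num_train_examples = 50_000   # would normally come from counting the .tfrecords
batch_size = 128
num_epochs = 10

# Steps per epoch, rounded up so a partial final batch still counts as a step.
steps_per_epoch = -(-num_train_examples // batch_size)  # ceiling division
decay_steps = num_epochs * steps_per_epoch
```

This decay_steps value is then what you would pass to tf.train.exponential_decay so that one decay period corresponds to num_epochs passes over the data.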
To get num_train_examples, you could:
- Set up a tf.string_input_producer with num_epochs=1.
- Run it through tf.TFRecordReader/tf.parse_single_example.
- Loop and count how many times it produces output before stopping.
However, this is not very elegant.
Is there an easier way to either get the number of training examples from a .tfrecords file, or to set the learning rate decay based on epochs instead of steps?
Answer
You can use the following code to get the number of records in a .tfrecords file:
import tensorflow as tf

def get_num_records(tf_record_file):
    # Iterate over the serialized records and count them without
    # materializing the whole file as a list.
    # (tf.python_io.tf_record_iterator is the TF 1.x API; in TF 2.x
    # use tf.data.TFRecordDataset instead.)
    return sum(1 for _ in tf.python_io.tf_record_iterator(tf_record_file))
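If importing TensorFlow just to count records is too heavyweight, the same count can be obtained by walking the TFRecord framing directly: each record in an uncompressed .tfrecords file is an 8-byte little-endian length, a 4-byte CRC of the length, the payload, and a 4-byte CRC of the payload. The helper name below is my own, and the CRCs are skipped rather than verified:

```python
import struct

def count_tfrecords_no_tf(path):
    """Count records in an uncompressed TFRecord file without TensorFlow.

    Assumes the standard TFRecord layout: 8-byte little-endian length,
    4-byte length CRC, payload, 4-byte payload CRC. CRCs are not checked.
    """
    count = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)
            if len(header) < 8:
                break  # end of file
            (length,) = struct.unpack("<Q", header)
            # Skip the length CRC, the payload, and the payload CRC.
            f.seek(4 + length + 4, 1)
            count += 1
    return count
```

This only works for uncompressed files; if the .tfrecords were written with GZIP or ZLIB compression options, you would need to decompress first.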