TensorFlow - Dynamic Input Batch Size?
Question
In my case, I need to dynamically change the batch_size during training. For example, I need to double the batch_size every 10 epochs. The problem is that, although I know how to make it dynamic elsewhere, in the input pipeline I have to fix the batch size, as the following code shows. That is, to use tf.train.shuffle_batch I have to set the batch_size argument, and I cannot find any way to change it afterwards. I would appreciate any suggestions! How do you make the input batch size dynamic?
filename_queue = tf.train.string_input_producer([self.tfrecords_file])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
    serialized_example,
    features={
        'image/file_name': tf.FixedLenFeature([], tf.string),
        'image/encoded_image': tf.FixedLenFeature([], tf.string),
    })
image_buffer = features['image/encoded_image']
image = tf.image.decode_jpeg(image_buffer, channels=3)
image = self._preprocess(image)
images = tf.train.shuffle_batch(
    [image], batch_size=self.batch_size, num_threads=self.num_threads,
    capacity=self.min_queue_examples + 3*self.batch_size,
    min_after_dequeue=self.min_queue_examples
)
Recommended answer
I believe what you want to do is the following (I haven't tried this, so correct me if I make a mistake).
Create a placeholder for your batch size:
batch_size_placeholder = tf.placeholder(tf.int64)
Create your shuffle batch using the placeholder:
images = tf.train.shuffle_batch(
    [image], batch_size=batch_size_placeholder, num_threads=self.num_threads,
    capacity=self.min_queue_examples + 3*batch_size_placeholder,
    min_after_dequeue=self.min_queue_examples
)
Pass in the batch size with your call to sess.run:
sess.run(my_optimizer, feed_dict={batch_size_placeholder: my_dynamic_batch_size})
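The doubling schedule from the question can then be computed outside the graph and fed through the placeholder each step. A minimal sketch, where the helper name and the base size of 32 are illustrative values, not from the original post:

```python
def batch_size_for_epoch(epoch, base_batch_size=32, double_every=10):
    """Return the batch size for a given epoch, doubling it once
    every `double_every` epochs (base_batch_size is an example value)."""
    return base_batch_size * (2 ** (epoch // double_every))

# Feed the current value alongside the training op, e.g.:
# sess.run(my_optimizer,
#          feed_dict={batch_size_placeholder: batch_size_for_epoch(epoch)})
```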
I expect shuffle_batch will accept tensors.
If there's any issue with that, you might consider using the Dataset pipeline. It is the newer, fancier, shinier way to do data pipelining.
https://www.tensorflow.org/programmers_guide/datasets