TensorFlow - Dynamic Input Batch Size?


Question

In my case, I need to change the batch_size dynamically during training. For example, I need to double the batch_size every 10 epochs. The problem is that, although I know how to make other values dynamic, the input pipeline requires the batch size to be fixed, as the following code shows. That is, to use tf.train.shuffle_batch I have to set the batch_size argument up front, and I cannot find any way to change it afterward. I would appreciate any suggestions! How do you make the input batch size dynamic?

  # Queue the TFRecord file and read serialized examples from it.
  filename_queue = tf.train.string_input_producer([self.tfrecords_file])
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)

  # Parse each serialized example into its image features.
  features = tf.parse_single_example(
      serialized_example,
      features={
        'image/file_name': tf.FixedLenFeature([], tf.string),
        'image/encoded_image': tf.FixedLenFeature([], tf.string),
      })

  # Decode and preprocess the image, then batch it with a fixed batch_size.
  image_buffer = features['image/encoded_image']
  image = tf.image.decode_jpeg(image_buffer, channels=3)
  image = self._preprocess(image)
  images = tf.train.shuffle_batch(
      [image], batch_size=self.batch_size, num_threads=self.num_threads,
      capacity=self.min_queue_examples + 3*self.batch_size,
      min_after_dequeue=self.min_queue_examples)

Answer

I believe what you want to do is the following (I haven't tried this, so correct me if I make a mistake).

Create a placeholder for your batch size:

batch_size_placeholder = tf.placeholder(tf.int64)

Create your shuffle batch using the placeholder:

images = tf.train.shuffle_batch(
    [image], batch_size=batch_size_placeholder, num_threads=self.num_threads,
    # capacity and min_after_dequeue are fixed when the queue is built,
    # so size them with the largest batch size you expect to feed.
    capacity=self.min_queue_examples + 3*self.batch_size,
    min_after_dequeue=self.min_queue_examples)

Pass in the batch size with your call to sess.run:

sess.run(my_optimizer, feed_dict={batch_size_placeholder: my_dynamic_batch_size})

I expect shuffle_batch will accept tensors.
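
Putting the pieces together for the original goal of doubling the batch size every 10 epochs might look roughly like the sketch below (untested; initial_batch_size, num_epochs, and batches_per_epoch are assumed names, and my_optimizer is your training op from above):

# Rough sketch of a training loop that doubles the batch size every 10 epochs.
# initial_batch_size, num_epochs, and batches_per_epoch are assumed values.
initial_batch_size = 32
num_epochs = 100
batches_per_epoch = 500

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for epoch in range(num_epochs):
        current_batch_size = initial_batch_size * (2 ** (epoch // 10))
        for _ in range(batches_per_epoch):
            sess.run(my_optimizer,
                     feed_dict={batch_size_placeholder: current_batch_size})

    coord.request_stop()
    coord.join(threads)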

If there's any issue with that, you might consider using the Dataset pipeline. It is the newer, fancier, shinier way to do data pipelining.

https://www.tensorflow.org/programmers_guide/datasets
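
For reference, here is a minimal Dataset-API sketch of the same idea, assuming a recent TF 1.x with tf.data; the file name data.tfrecords and the fixed 256x256 resize are assumptions, while the feature keys match the question's code:

import tensorflow as tf

# Feedable batch size for the tf.data pipeline (scalar int64 placeholder).
batch_size_placeholder = tf.placeholder(tf.int64, shape=[])

def _parse(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image/file_name': tf.FixedLenFeature([], tf.string),
            'image/encoded_image': tf.FixedLenFeature([], tf.string),
        })
    image = tf.image.decode_jpeg(features['image/encoded_image'], channels=3)
    # Resize to a fixed shape (assumed 256x256) so examples can be batched.
    image = tf.image.resize_images(image, [256, 256])
    return image

dataset = (tf.data.TFRecordDataset(['data.tfrecords'])  # assumed file name
           .map(_parse)
           .shuffle(buffer_size=10000)
           .batch(batch_size_placeholder)
           .repeat())

iterator = dataset.make_initializable_iterator()
images = iterator.get_next()

# Re-initialize the iterator whenever the batch size should change,
# e.g. at the start of every 10th epoch.
with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={batch_size_placeholder: 64})

Re-running iterator.initializer with a new value for batch_size_placeholder is how the batch size changes between epochs; the rest of the graph stays untouched.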
