What is the parameter "max_q_size" used for in "model.fit_generator"?

Problem description

I built a simple generator that yields a tuple(inputs, targets) with only single items in the inputs and targets lists. Basically, it is crawling the data set, one sample item at a time.

I pass this generator into:

  model.fit_generator(my_generator(),
                      nb_epoch=10,
                      samples_per_epoch=1,
                      max_q_size=1  # defaults to 10
                      )
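A hypothetical minimal sketch of the kind of generator described above (the dataset contents are illustrative): it walks the data set one sample at a time and yields a tuple whose inputs and targets lists each hold a single item. Note that fit_generator expects the generator to loop indefinitely.

```python
def my_generator(dataset=(([0.0, 1.0], [1]), ([1.0, 0.0], [0]))):
    # Crawl the data set one sample item at a time, forever.
    while True:
        for x, y in dataset:
            # Yield (inputs, targets) with single-item lists.
            yield ([x], [y])
```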

I understand that:

  • nb_epoch is the number of times the training batch will be run
  • samples_per_epoch is the number of samples trained with per epoch

But what is max_q_size for, and why does it default to 10? I thought the purpose of using a generator was to batch data sets into reasonable chunks, so why the additional queue?

Recommended answer

This simply defines the maximum size of the internal training queue, which is used to "precache" your samples from the generator. It is used during construction of the queue:

import queue
import threading
import time

def generator_queue(generator, max_q_size=10,
                    wait_time=0.05, nb_worker=1):
    '''Builds a threading queue out of a data generator.
    Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
    '''
    q = queue.Queue()
    _stop = threading.Event()

    def data_generator_task():
        while not _stop.is_set():
            try:
                if q.qsize() < max_q_size:
                    try:
                        generator_output = next(generator)
                    except ValueError:
                        continue
                    q.put(generator_output)
                else:
                    time.sleep(wait_time)
            except Exception:
                _stop.set()
                raise

    generator_threads = [threading.Thread(target=data_generator_task)
                         for _ in range(nb_worker)]

    for thread in generator_threads:
        thread.daemon = True
        thread.start()

    return q, _stop

In other words, you have a thread filling the queue up to the given maximum capacity directly from your generator, while (for example) the training routine consumes its elements (and sometimes waits for completion):

 while samples_seen < samples_per_epoch:
     generator_output = None
     while not _stop.is_set():
         if not data_gen_queue.empty():
             generator_output = data_gen_queue.get()
             break
         else:
             time.sleep(wait_time)

And why a default of 10? No particular reason; like most defaults, it simply makes sense, but you could use different values too.

A construction like this suggests that the authors had expensive data generators in mind, which might take time to execute. For example, consider downloading data over a network in the generator call: then it makes sense to precache the next few batches, and download upcoming ones in parallel, both for efficiency and for robustness to network errors.
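That idea can be sketched in a self-contained way (this is not Keras code; the names slow_generator and precache are illustrative): one daemon thread precaches items from a slow generator into a size-capped queue, mirroring the producer/consumer pattern of generator_queue above, so the consumer rarely has to wait.

```python
import queue
import threading
import time

def slow_generator(n=5):
    # Stands in for an expensive generator, e.g. one downloading over a network.
    for i in range(n):
        time.sleep(0.01)
        yield i

def precache(gen, max_q_size=2, wait_time=0.005):
    # Fill a queue from `gen` in a background thread, respecting a size cap.
    q = queue.Queue()
    def worker():
        for item in gen:
            while q.qsize() >= max_q_size:  # queue full: wait, as in generator_queue
                time.sleep(wait_time)
            q.put(item)
        q.put(None)  # sentinel marking exhaustion
    t = threading.Thread(target=worker)
    t.daemon = True
    t.start()
    return q

# Consumer: drain the queue until the sentinel arrives.
q = precache(slow_generator())
consumed = []
item = q.get()
while item is not None:
    consumed.append(item)
    item = q.get()
```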
