Implementing an “infinite loop” Dataset & DataLoader in PyTorch


Problem description

I’d like to implement an infinite loop Dataset & DataLoader. Here’s what I tried:

import numpy as np
from torch.utils.data import Dataset, DataLoader

class Infinite(Dataset):
    def __len__(self):
        return HPARAMS.batch_size
        # return 1 << 30  # This causes huge memory usage.

    def __getitem__(self, idx):
        """Randomly generates one new example."""
        return sample_func_to_be_parallelized()

infinite_loader = DataLoader(
    dataset=Infinite(), 
    batch_size=HPARAMS.batch_size, 
    num_workers=16,
    worker_init_fn=lambda worker_id: np.random.seed(worker_id),  
)

while True:
    for idx, data in enumerate(infinite_loader):
        ...  # forward + backward on "data"

As you can see, the main challenge here is the __len__() method. If I put a large enough number there, like 1<<30, memory usage jumps to 10+ GB on the first iteration of the training loop, and after a while the workers are killed, presumably due to OOM.

If I put a small number there, like 1 or BATCH_SIZE, the sampled "data" in the training loop is periodically duplicated. That is not what I want: I’d like new data to be generated and trained on at every iteration.

I’m guessing the culprit of the excessive memory usage is somewhere in the stack, where a bunch of things are being cached. On a casual look at the Python side of things I can’t pinpoint where.

Can someone advise on the best way to implement what I want? (Use DataLoader’s parallel loading while guaranteeing that every batch loaded is entirely new.)

Recommended answer

This seems to work without periodically duplicating the data:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        return torch.randint(0, 10, (3,))


data_loader = DataLoader(Infinite(), batch_size=BATCH_SIZE, num_workers=16)

batch_count = 0
while True:
    batch_count += 1
    print(f'Batch {batch_count}:')

    data = next(iter(data_loader))
    print(data)
    # forward + backward on "data"  

    if batch_count == 5:
        break

Result:

Batch 1:
tensor([[4, 7, 7],
        [0, 8, 0]])
Batch 2:
tensor([[6, 8, 6],
        [2, 6, 7]])
Batch 3:
tensor([[6, 6, 2],
        [8, 7, 0]])
Batch 4:
tensor([[9, 4, 8],
        [2, 4, 1]])
Batch 5:
tensor([[9, 6, 1],
        [2, 7, 5]])

So I think the problem lies in your function sample_func_to_be_parallelized().

Edit: If, instead of torch.randint(0, 10, (3,)), I use np.random.randint(10, size=3) in __getitem__ (as an example of sample_func_to_be_parallelized()), then the data is indeed duplicated in each batch. See this issue.
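The duplication happens because DataLoader workers are forked from the main process (on Linux), and every fork inherits the same NumPy RNG state, so without per-worker seeding they all draw identical samples. Below is a minimal, PyTorch-independent sketch of that effect; the draw() helper and the worker count are illustrative only, not part of the original answer:

import multiprocessing as mp
import numpy as np

def draw(queue):
    # Each child process inherits the parent's NumPy RNG state via fork,
    # so every child draws the same "random" numbers.
    queue.put(np.random.randint(10, size=3).tolist())

if __name__ == '__main__':
    ctx = mp.get_context('fork')  # the start method DataLoader workers use on Linux
    queue = ctx.Queue()
    procs = [ctx.Process(target=draw, args=(queue,)) for _ in range(4)]
    for p in procs:
        p.start()
    results = [queue.get() for _ in procs]  # one triple per child
    for p in procs:
        p.join()
    print(results)  # prints four identical triples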

So if you are using NumPy’s RNG somewhere in your sample_func_to_be_parallelized(), the workaround is to use

worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id) 

and to reset the seed with np.random.seed() before each call of data = next(iter(data_loader)).
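Putting the two pieces together, here is a minimal sketch of that workaround, with np.random.randint(10, size=3) standing in for sample_func_to_be_parallelized(); the stand-in, batch size and worker count are placeholders rather than part of the original answer:

import numpy as np
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 2

class Infinite(Dataset):

    def __len__(self):
        return BATCH_SIZE

    def __getitem__(self, idx):
        # Stand-in for sample_func_to_be_parallelized(); it uses NumPy's global RNG.
        return np.random.randint(10, size=3)


data_loader = DataLoader(
    Infinite(),
    batch_size=BATCH_SIZE,
    num_workers=4,
    # Give each worker a distinct NumPy seed derived from the main process' RNG state.
    worker_init_fn=lambda worker_id: np.random.seed(np.random.get_state()[1][0] + worker_id),
)

batch_count = 0
while True:
    batch_count += 1
    np.random.seed()                # re-randomize the base state before the next set of workers starts
    data = next(iter(data_loader))  # each call starts fresh workers, which re-run worker_init_fn
    print(f'Batch {batch_count}:')
    print(data)
    # forward + backward on "data"

    if batch_count == 5:
        break

Because next(iter(data_loader)) builds a new iterator each time, the worker processes are restarted for every batch, which is what lets the fresh base seed take effect (at the cost of worker startup overhead).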

