Adding class objects to PyTorch DataLoader: batch must contain tensors

Problem description

I have a custom PyTorch dataset that returns a dictionary containing a class object, "queries".

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return self.values.shape[0]

    def __getitem__(self, idx):
        # DeviceDict is the asker's own dict-like class (definition not shown)
        sample = DeviceDict({"query": self.queries[idx],
                             "values": self.values[idx],
                             "targets": self.targets[idx]})
        return sample

The problem is that when I put the queries in a data loader I get default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'>. Is there a way to have a class object in my data loader? It blows up at next(iterator) in the code below.

# train_values / train_targets are assumed to be defined alongside train_queries
train_queries = QueryDataset(train_queries, train_values, train_targets)
train_loader = torch.utils.data.DataLoader(train_queries,
                                           batch_size=10,
                                           shuffle=True,
                                           drop_last=False)
for epoch in range(epochs):
    iterator = iter(train_loader)
    for step in range(len(train_loader)):
        batch = next(iterator)  # default_collate runs here and raises the error
        out = model(batch)
        loss = criterion(out["pred"], batch["targets"])
        self.optimizer.zero_grad()
        loss.sum().backward()
        self.optimizer.step()
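For reference, the failure can be reproduced without the asker's own classes. This is a minimal sketch under stated assumptions: Query and ToyQueryDataset are hypothetical stand-ins for query.Query and QueryDataset, and a plain dict replaces DeviceDict. With the default collate_fn, the DataLoader tries to collate every value of the sample dicts and raises a TypeError when it reaches the Query objects, because it has no rule for batching arbitrary Python objects.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class Query:
    """Hypothetical stand-in for the asker's query.Query class."""

    def __init__(self, text):
        self.text = text


class ToyQueryDataset(Dataset):
    """Hypothetical dataset: each sample is a dict with a Query plus two tensors."""

    def __init__(self, n=4):
        self.queries = [Query(f"q{i}") for i in range(n)]
        self.values = torch.randn(n, 3)
        self.targets = torch.arange(n)

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        return {"query": self.queries[idx],
                "values": self.values[idx],
                "targets": self.targets[idx]}


def reproduce_error():
    """Return the error message raised by the default collate_fn, or None."""
    loader = DataLoader(ToyQueryDataset(), batch_size=2)  # no collate_fn given
    try:
        next(iter(loader))  # default_collate is called on the first batch here
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    print(reproduce_error())
```

The message names the offending type (here the hypothetical Query class), matching the error in the question.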

Recommended answer

You need to define your own collate_fn to do this. A sloppy example, just to show how it works, would be something like this:

import torch

class DeviceDict:
    # Minimal stand-in for the asker's DeviceDict: it simply wraps whatever it receives
    def __init__(self, data):
        self.data = data

    def print_data(self):
        print(self.data)

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        # one sample per entry in values
        return len(self.values)

    def __getitem__(self, idx):
        sample = {"query": self.queries[idx],
                  "values": self.values[idx],
                  "targets": self.targets[idx]}
        return sample

def custom_collate(batch):
    # batch is the list of samples __getitem__ returned for one batch;
    # wrap it as-is instead of letting default_collate try to turn it into tensors
    return DeviceDict(batch)

dt = QueryDataset("q", "v", "t")
dl = torch.utils.data.DataLoader(dt, batch_size=1, collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()  # prints [{'query': 'q', 'values': 'v', 'targets': 't'}]

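Basically, collate_fn lets you implement custom batching or add support for custom data types, as explained in the link I previously provided. The DataLoader calls it with the list of samples returned by __getitem__ for each batch, and whatever it returns is what next(iterator) yields.
As you can see, this only demonstrates the concept; you need to adapt it to your own needs.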
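One possible adaptation, assuming the goal is to batch the tensors as usual while carrying the query objects along unchanged, is to pull the custom objects out into a plain list and hand everything else to PyTorch's default_collate. In this sketch Query, ToyQueryDataset and query_collate are hypothetical names, not part of the original code.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class Query:
    """Hypothetical stand-in for the asker's query.Query class."""

    def __init__(self, text):
        self.text = text


class ToyQueryDataset(Dataset):
    """Hypothetical dataset: each sample is a dict with a Query plus two tensors."""

    def __init__(self, n=5):
        self.queries = [Query(f"q{i}") for i in range(n)]
        self.values = torch.randn(n, 3)
        self.targets = torch.arange(n)

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        return {"query": self.queries[idx],
                "values": self.values[idx],
                "targets": self.targets[idx]}


def query_collate(batch):
    """Collate a list of sample dicts into one batch dict.

    The "query" entries are kept as a plain Python list (no stacking);
    all remaining keys are stacked into tensors by default_collate.
    """
    queries = [sample["query"] for sample in batch]
    rest = [{k: v for k, v in sample.items() if k != "query"} for sample in batch]
    collated = default_collate(rest)  # dict of stacked tensors
    collated["query"] = queries
    return collated


if __name__ == "__main__":
    loader = DataLoader(ToyQueryDataset(), batch_size=2, collate_fn=query_collate)
    for batch in loader:
        print(len(batch["query"]), tuple(batch["values"].shape), batch["targets"].tolist())
```

With this collate_fn, batch["values"] and batch["targets"] are ordinary tensors, while batch["query"] is a list of Query objects, so the model code has to iterate over that list itself.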