Loading data from a custom data loader in PyTorch only if the data meets a certain condition

Problem description

I have a CSV file with the file path in the first column and a label for that file in the second column. A third column specifies whether the data meets a specific condition. It looks something like this:

+------------+---------+-----+
| Filepath 1 | Label 1 | 'n' |
| Filepath 2 | Label 2 | 'n' |
| Filepath 3 | Label 3 | 'n' |
| Filepath 4 | Label 4 | 'y' |
+------------+---------+-----+

I want `__getitem__` to load a sample from the custom dataset only when the attribute column == 'y'. However, I get the following error:

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>

My code is as follows:

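```python
import cv2
import pandas as pd
from torch.utils.data import Dataset


class InterDataset(Dataset):
    def __init__(self, csv_file, mode, root_dir=None, transform=None, run=None):
        # Column 0: file path, column 1: label, column 2: 'y'/'n' condition flag
        self.annotations = pd.read_csv(csv_file, header=None)
        self.root_dir = root_dir
        self.transform = transform
        self.mode = mode
        self.run = run

    def __len__(self):
        # Counts every row, including rows that __getitem__ will not load
        return len(self.annotations)

    def __getitem__(self, index):
        if self.mode == 'train':
            if self.annotations.iloc[index, 2] == 'n':
                img_path = self.annotations.iloc[index, 0]
                image = cv2.imread(img_path, 1)  # 1 = load as a 3-channel BGR image
                y_label = self.annotations.iloc[index, 1]

                if self.transform:
                    image = self.transform(image)
                if (index + 1) % 300 == 0:
                    print('Loop {0} done'.format(index))
                return [image, y_label]
        # Any row that fails the checks above falls through to here, so the
        # method implicitly returns None, which default_collate rejects.
```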
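For reference, the failure can be reproduced without any image files. The sketch below is hypothetical: `ConditionalDataset` and its in-memory rows stand in for `InterDataset` and the CSV, and only rows flagged 'y' return a sample. Rows flagged 'n' reach the end of `__getitem__` without a `return`, so Python returns `None`, and the `DataLoader`'s default collate function raises the `TypeError` when it tries to batch that `None`.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ConditionalDataset(Dataset):
    """Hypothetical stand-in for InterDataset: rows are (value, label, flag)."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        value, label, flag = self.rows[index]
        if flag == 'y':
            return torch.tensor([value]), label
        # No return statement is reached for 'n' rows, so Python returns None


def reproduce():
    rows = [(0.0, 0, 'n'), (1.0, 1, 'y'), (2.0, 0, 'n'), (3.0, 1, 'y')]
    loader = DataLoader(ConditionalDataset(rows), batch_size=2, shuffle=False)
    try:
        next(iter(loader))  # first batch contains [None, (tensor, 1)]
    except TypeError as err:
        return str(err)
    return None


message = reproduce()
print(message)  # the message mentions <class 'NoneType'>
```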

Recommended answer

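You get that error because `__getitem__` has to return something: for rows that fail the condition it falls through and implicitly returns `None`, which the default collate function cannot put into a batch. Here are three solutions: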

  1. There is a library called nonechucks that lets you create data loaders which can skip samples.
  2. Usually you could preprocess/clean your data and remove the unwanted samples.
  3. You could return some indicator that the sample is unwanted, for example

if self.annotations.iloc[index, 2] == 'y':
    return data, target
else:
    return -1

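And then you could check in your train loop whether the "data" is -1 and skip that iteration. With the default collate function this only works with `batch_size=1`, because a batch that mixes -1 with (data, target) pairs cannot be collated; for larger batches, filter the markers out in a custom `collate_fn`. I hope this was helpful :)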
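Option 2 can be done inside the dataset itself: filter the DataFrame once in `__init__`, so `__len__` and `__getitem__` only ever see the wanted rows and the sampler never requests an unwanted one. This is a minimal sketch, assuming the CSV layout from the question (read with `header=None`, so the columns are labelled 0, 1, 2). `FilteredDataset`, `load_fn` and `fake_load` are hypothetical names; `fake_load` replaces `cv2.imread` so the example runs without image files.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class FilteredDataset(Dataset):
    """Keep only rows whose third column is 'y', decided once in __init__.

    Expects a DataFrame shaped like pd.read_csv(csv_file, header=None):
    column 0 = file path, column 1 = label, column 2 = 'y'/'n' flag.
    """

    def __init__(self, annotations, load_fn):
        keep = annotations[2] == 'y'
        # reset_index so positions 0..len-1 map onto the kept rows only
        self.annotations = annotations[keep].reset_index(drop=True)
        self.load_fn = load_fn  # hypothetical loader, e.g. a cv2.imread wrapper

    def __len__(self):
        # Only kept rows are counted, so no index ever points at an 'n' row
        return len(self.annotations)

    def __getitem__(self, index):
        path = self.annotations.iloc[index, 0]
        label = self.annotations.iloc[index, 1]
        return self.load_fn(path), label


def fake_load(path):
    # Stand-in for reading an image from disk
    return torch.zeros(3, 4, 4)


df = pd.DataFrame([
    ['img1.png', 0, 'n'],
    ['img2.png', 1, 'n'],
    ['img3.png', 0, 'n'],
    ['img4.png', 1, 'y'],
])
dataset = FilteredDataset(df, fake_load)
batches = list(DataLoader(dataset, batch_size=2))
print(len(dataset))  # 1
```

In the question's setting the DataFrame would come from `pd.read_csv(csv_file, header=None)` and `load_fn` would call `cv2.imread`; the `mode == 'train'` check could decide whether to apply the filter at all.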
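Option 3 also works with batches larger than one if the markers are removed before collation. The sketch below is a variant of the answer's idea under stated assumptions: it uses `None` as the marker (which is what the question's `__getitem__` already returns for skipped rows) rather than -1, and a hypothetical `skip_none_collate`, passed as `collate_fn`, drops those samples before handing the rest to PyTorch's `default_collate`. `ConditionalDataset` is likewise a hypothetical in-memory stand-in.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


def skip_none_collate(batch):
    """Drop samples that __getitem__ marked as unwanted (None), then collate."""
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        # Every sample in this batch was unwanted; the train loop skips it
        return None
    return default_collate(batch)


class ConditionalDataset(Dataset):
    """Hypothetical dataset: rows are (value, label, flag); only 'y' rows load."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        value, label, flag = self.rows[index]
        if flag != 'y':
            return None  # explicit marker instead of an implicit fall-through
        return torch.tensor([value]), label


rows = [(0.0, 0, 'n'), (1.0, 1, 'y'), (2.0, 0, 'n'), (3.0, 1, 'n')]
loader = DataLoader(ConditionalDataset(rows), batch_size=2,
                    collate_fn=skip_none_collate)

kept_labels = []
for batch in loader:
    if batch is None:
        continue  # the whole batch was filtered out
    data, target = batch
    kept_labels.extend(target.tolist())
print(kept_labels)  # [1]
```

Batches can come out smaller than `batch_size` or empty (hence the `None` check in the loop), and `len(loader)` still counts the skipped rows, which is one reason option 2 is usually simpler.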
