How to get the filename of a sample from a DataLoader?

Problem description

I need to write a file with the test results of a convolutional neural network that I trained. The data comes from a speech data collection. The file format needs to be "file name, prediction", but I am having a hard time extracting the file name. I load the data like this:

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

I then try to write the results to the file as follows:

import os
import torch

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # os.listdir() order is arbitrary, so this may not be the file of sample i
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        line = file + ", " + str(predicted.item()) + "\n"
        f.write(line)

The problem with os.listdir(TEST_DATA_PATH + "/all")[i] is that os.listdir() returns the entries in arbitrary order, so it is not synchronized with the order in which test_loader loads the files. What can I do?

Recommended answer

Well, it depends on how your Dataset is implemented. For instance, with torchvision.datasets.MNIST(...) you cannot retrieve the filename, simply because individual samples have no filename: MNIST samples are read from a few binary files into tensors rather than loaded one file at a time.
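When there is no filename, as with MNIST, the sample's position in the dataset can serve as a stable identifier instead. Below is a minimal sketch, assuming a hypothetical wrapper IndexedDataset (not part of PyTorch) around any map-style dataset whose items are (input, label) pairs. Because the index travels with each sample, it stays correct with shuffle=True and any batch_size:

```python
from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Hypothetical wrapper: returns (input, label, index) for each sample."""

    def __init__(self, base):
        # Any map-style dataset whose items are (input, label) pairs,
        # e.g. torchvision.datasets.MNIST(...)
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        inputs, label = self.base[index]
        # Returning the index alongside the sample keeps it aligned with the
        # data even when the DataLoader shuffles or batches.
        return inputs, label, index
```

With the default collate function, each batch unpacks as images, labels, indices, where indices is a tensor of dataset positions. For an ImageFolder-style base dataset, an index idx can also be mapped back to its path with base.samples[idx][0].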

Since you did not show your Dataset implementation, I'll show how this can be done with torchvision.datasets.ImageFolder(...) (or any torchvision.datasets.DatasetFolder(...)):

import torch

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # DatasetFolder.samples holds (path, class_index) pairs in loading order
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))

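You can see this in DatasetFolder.__getitem__(self, index), which looks up path, target = self.samples[index] before loading the file. Note that indexing samples with the loop counter i is only correct when batch_size=1 and shuffle=False, because only then is batch i exactly dataset sample i.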
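When shuffle=False but batch_size > 1, the DataLoader (which then uses a SequentialSampler) still visits the dataset in order, so batch k holds indices k * batch_size up to (k + 1) * batch_size - 1, and the matching filenames can be sliced out of samples. A minimal sketch, using a hypothetical helper batch_filenames that assumes the default drop_last=False:

```python
def batch_filenames(samples, batch_size):
    """Yield the file paths of each batch, in DataLoader order.

    Assumes the DataLoader was built with shuffle=False and drop_last=False,
    and that `samples` is a list of (path, class_index) pairs such as
    ImageFolder's `.samples` attribute.
    """
    for start in range(0, len(samples), batch_size):
        # Batch k covers dataset indices [k * batch_size, (k + 1) * batch_size);
        # the slice is automatically shorter for the final, partial batch.
        yield [path for path, _ in samples[start:start + batch_size]]
```

It could then be paired with the loader as for (images, labels), names in zip(test_loader, batch_filenames(test_loader.dataset.samples, test_loader.batch_size)), where names[j] is the file of row j in images. This breaks as soon as shuffle=True, which is why returning the filename from the Dataset itself is more robust.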

If you implemented your own Dataset (and perhaps want to support shuffle and batch_size > 1), then I would return sample_fname from __getitem__(...) and do something like this:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]
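A minimal sketch of such a Dataset, assuming the data is a list of (path, label) pairs and that load_fn is a caller-supplied function turning a path into a tensor (FileListDataset and load_fn are hypothetical names, not part of PyTorch). With PyTorch's default collate function, the strings returned by __getitem__ come back from the DataLoader as a sequence of file names per batch, aligned with the rows of the data tensor:

```python
from torch.utils.data import Dataset


class FileListDataset(Dataset):
    """Hypothetical dataset with one sample per file.

    __getitem__ returns (data, label, filename), so every batch produced by
    a DataLoader carries the file names of its samples.
    """

    def __init__(self, samples, load_fn):
        self.samples = samples  # list of (path, label) pairs
        self.load_fn = load_fn  # assumed callable: path -> tensor

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        # The path is returned with the sample, so shuffling cannot misalign it.
        return self.load_fn(path), label, path
```

Each batch then unpacks as images, labels, sample_fname, and sample_fname[j] belongs to row j of the batch even with shuffle=True.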

This way you don't need to care about shuffle. And if batch_size is greater than 1, you need to change the body of the loop to something more generic, e.g.:

import torch

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels, samples_fname) in enumerate(test_loader, 0):
        outputs = model(images)
        pred = torch.max(outputs, 1)[1]  # predicted class index per sample
        # One "file name, prediction" line per sample in the batch
        f.write("\n".join(
            ", ".join(x)
            for x in zip(samples_fname, map(str, pred.cpu().tolist()))
        ) + "\n")
