PyTorch Validating Model Error: Expected input batch_size (3) to match target batch_size (4)
Problem Description
I'm building a NN in PyTorch that is supposed to classify across 102 classes.
I have the following validation function:
def validation(model, testloader, criterion):
    test_loss = 0
    accuracy = 0
    for inputs, classes in testloader:
        inputs = inputs.to('cuda')
        output = model.forward(inputs)
        test_loss += criterion(output, labels).item()

        ps = torch.exp(output)
        equality = (labels.data == ps.max(dim=1)[1])
        accuracy += equality.type(torch.FloatTensor).mean()

    return test_loss, accuracy
Code for training (calls validation):
epochs = 3
print_every = 40
steps = 0
running_loss = 0
testloader = dataloaders['test']

# change to cuda
model.to('cuda')

for e in range(epochs):
    running_loss = 0
    for ii, (inputs, labels) in enumerate(dataloaders['train']):
        steps += 1

        inputs, labels = inputs.to('cuda'), labels.to('cuda')

        optimizer.zero_grad()

        # Forward and backward passes
        outputs = model.forward(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if steps % print_every == 0:
            model.eval()
            with torch.no_grad():
                test_loss, accuracy = validation(model, testloader, criterion)

            print("Epoch: {}/{}.. ".format(e+1, epochs),
                  "Training Loss: {:.3f}.. ".format(running_loss/print_every),
                  "Test Loss: {:.3f}.. ".format(test_loss/len(testloader)),
                  "Test Accuracy: {:.3f}".format(accuracy/len(testloader)))

            running_loss = 0
            model.train()
I get this error message:
ValueError: Expected input batch_size (3) to match target batch_size (4).
Full traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-63-f9f67ed13b94> in <module>()
28 model.eval()
29 with torch.no_grad():
---> 30 test_loss, accuracy = validation(model, testloader, criterion)
31
32 print("Epoch: {}/{}.. ".format(e+1, epochs),
<ipython-input-62-dbc77acbda5e> in validation(model, testloader, criterion)
6 inputs = inputs.to('cuda')
7 output = model.forward(inputs)
----> 8 test_loss += criterion(output, labels).item()
9
10 ps = torch.exp(output)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
489 result = self._slow_forward(*input, **kwargs)
490 else:
--> 491 result = self.forward(*input, **kwargs)
492 for hook in self._forward_hooks.values():
493 hook_result = hook(self, input, result)
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/loss.py in forward(self, input, target)
191 _assert_no_grad(target)
192 return F.nll_loss(input, target, self.weight, self.size_average,
--> 193 self.ignore_index, self.reduce)
194
195
/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce)
1328 if input.size(0) != target.size(0):
1329 raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 1330 .format(input.size(0), target.size(0)))
1331 if dim == 2:
1332 return torch._C._nn.nll_loss(input, target, weight, size_average, ignore_index, reduce)
ValueError: Expected input batch_size (3) to match target batch_size (4).
I don't understand where the error comes from. Indeed, without the validation code the training part works perfectly.
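The traceback ends in the batch-size comparison inside nll_loss: the loss needs exactly one target per row of the model's output. The following is a minimal sketch that triggers the same check in isolation. It assumes a recent PyTorch, where the check may no longer live in functional.py and the exception type may differ from the ValueError shown in the old traceback, so both ValueError and RuntimeError are caught. The tensors are random placeholders; only the 102 classes come from the question.

```python
import torch
import torch.nn.functional as F

# Log-probabilities for a batch of 3 samples over 102 classes, as a model
# ending in LogSoftmax would produce (random placeholder values).
log_probs = F.log_softmax(torch.randn(3, 102), dim=1)

# Class indices for a *different* batch of 4 samples.
targets = torch.randint(0, 102, (4,))

try:
    F.nll_loss(log_probs, targets)
except (ValueError, RuntimeError) as err:
    # The exception type can vary between PyTorch versions, so catch both
    # and print the batch-size mismatch message.
    print(err)
```

Giving both tensors the same first dimension makes the call succeed; that equality is the condition the validation loop violates.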
Recommended Answer
In your validation function,
def validation(model, testloader, criterion):
    test_loss = 0
    accuracy = 0
    for inputs, classes in testloader:
        inputs = inputs.to('cuda')
        output = model.forward(inputs)
        test_loss += criterion(output, labels).item()

        ps = torch.exp(output)
        equality = (labels.data == ps.max(dim=1)[1])
        accuracy += equality.type(torch.FloatTensor).mean()

    return test_loss, accuracy
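you are iterating over testloader and unpacking each batch into the variables inputs, classes, but you are passing labels to your criterion.

Because validation never assigns labels, Python looks the name up in the global scope, where the training loop (which runs at module level, as the <module> frame in the traceback shows) has left the labels of the current training batch. That is why the sizes differ: the input batch_size (3) comes from the model's output on a test batch, while the target batch_size (4) comes from a training batch. Iterate with for inputs, labels in testloader instead, and move labels to 'cuda' as well, so that the targets are on the same device as output.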
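Putting the fix together, a corrected validation function might look like the sketch below. It assumes, as the use of torch.exp in the question and the NLLLoss frames in the traceback suggest, that the model ends in LogSoftmax. The device parameter, the small model, and the random dataset in the demo are hypothetical additions for illustration, and the demo falls back to the CPU when CUDA is unavailable.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validation(model, testloader, criterion, device="cuda"):
    """Sum the loss and the per-batch accuracy of `model` over `testloader`."""
    test_loss = 0.0
    accuracy = 0.0
    # Unpack each batch into `labels`, the name used below, so the loss is
    # computed against the test targets instead of a leftover global.
    for inputs, labels in testloader:
        # Inputs and targets must both be on the same device as the model.
        inputs, labels = inputs.to(device), labels.to(device)
        output = model(inputs)  # calling the module runs forward() plus hooks
        test_loss += criterion(output, labels).item()
        # The model outputs log-probabilities, so exp() gives probabilities.
        ps = torch.exp(output)
        equality = labels == ps.argmax(dim=1)
        accuracy += equality.float().mean().item()
    return test_loss, accuracy


if __name__ == "__main__":
    # Hypothetical stand-ins for the question's model and data: a linear
    # classifier over 102 classes and 10 random samples in batches of 4,
    # so the last batch holds only 2 samples.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Sequential(nn.Linear(8, 102), nn.LogSoftmax(dim=1)).to(device)
    data = TensorDataset(torch.randn(10, 8), torch.randint(0, 102, (10,)))
    testloader = DataLoader(data, batch_size=4)

    model.eval()
    with torch.no_grad():
        test_loss, accuracy = validation(model, testloader, nn.NLLLoss(), device)
    print("Test Loss: {:.3f}.. ".format(test_loss / len(testloader)),
          "Test Accuracy: {:.3f}".format(accuracy / len(testloader)))
```

Besides the renaming and the device move, the sketch calls model(inputs) rather than model.forward(inputs) so that any registered hooks run, and it accumulates accuracy as a plain Python float, which the training loop's format strings accept unchanged.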