PyTorch 运行时错误:断言 `cur_target &gt;= 0 &amp;&amp;cur_target <n_classes' 失败 [英] PyTorch RuntimeError: Assertion `cur_target &gt;= 0 &amp;&amp; cur_target &lt; n_classes&#39; failed

查看:34
本文介绍了PyTorch 运行时错误:断言 `cur_target &gt;= 0 &amp;&amp;cur_target <n_classes' 失败的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试在 Pytorch 中创建一个基本的二元分类器,用于对我的玩家在 Pong 游戏中是在右侧还是左侧进行分类.输入是一个 1x42x42 的图像,标签是我播放器的一侧(右 = 1 或左 = 2).代码:

I’m trying to create a basic binary classifier in Pytorch that classifies whether my player plays on the right or the left side in the game Pong. The input is an 1x42x42 image and the label is my player's side (right = 1 or left = 2). The code:

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

net = Net(42 * 42, 100, 2)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer_net = torch.optim.Adam(net.parameters(), 0.001)
net.train()

while True:
    state = get_game_img()
    state = torch.from_numpy(state)

    # right = 1, left = 2
    current_side = get_player_side()
    target = torch.LongTensor(current_side)
    x = Variable(state.view(-1, 42 * 42))
    y = Variable(target)
    optimizer_net.zero_grad()
    y_pred = net(x)
    loss = criterion(y_pred, y)
    loss.backward()
    optimizer.step()

我得到的错误:

  File "train.py", line 109, in train
    loss = criterion(y_pred, y)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/modules/loss.py", line 321, in forward
    self.weight, self.size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 533, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/functional.py", line 501, in nll_loss
    return f(input, target)
  File "/home/shani/anaconda2/lib/python2.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed.  at /py/conda-bld/pytorch_1493676237139/work/torch/lib/THNN/generic/ClassNLLCriterion.c:57

推荐答案

对于大多数深度学习库,目标(或标签)应该从 0 开始.

For most of deeplearning library, target(or label) should start from 0.

这意味着你的目标应该在 [0,n) 的范围内,有 n 个类.

It means that your target should be in the range of [0,n) with n-classes.

这篇关于PyTorch 运行时错误:断言 `cur_target &gt;= 0 &amp;&amp;cur_target <n_classes' 失败的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆