`*** RuntimeError: mat1 dim 1 must match mat2 dim 0` 每当我运行模型(图像) [英] `*** RuntimeError: mat1 dim 1 must match mat2 dim 0` whenever I run model(images)

查看:82
本文介绍了`*** RuntimeError: mat1 dim 1 must match mat2 dim 0` 每当我运行模型(图像)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=5, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 64, kernel_size=3, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            
            nn.Conv2d(64, 64, kernel_size=3, stride=2, bias=False),
            nn.BatchNorm2d(64),
  
        )

我该如何处理这个错误?我认为错误出在 self.fc 上,但我不知道如何修复它.

How can I deal with this error? I think the error is with self.fc, but I can't say how to fix it.

推荐答案

self.conv(x) 的输出是 torch.Size([32, 64, 2,2]):32*64*2*2=8192(这相当于(self.conv_out_size).全连接层的输入需要一个单维向量,即在传递到前向函数中的全连接层之前需要将其展平.

The output from self.conv(x) is of shape torch.Size([32, 64, 2, 2]): 32*64*2*2= 8192 (this is equivalent to (self.conv_out_size). The input to fully connected layer expects a single dimension vector i.e. you need to flatten it before passing to a fully connected layer in the forward function.

class Network():
    ...
    def foward():
    ...
        conv_out = self.conv(x)
        print(conv_out.shape)
        conv_out = conv_out.view(-1, 32*64*2*2)
        print(conv_out.shape)
        x = self.fc(conv_out)
        return x

输出

torch.Size([32, 64, 2, 2])
torch.Size([1, 8192])


我认为您使用的 self._get_conv_out 函数是错误的.

I think you're using self._get_conv_out function wrong.

应该是

    def _get_conv_out(self, shape):
        output = self.conv(torch.zeros(1, *shape)) # not (32, *size)
        return int(numpy.prod(output.size()))

然后,在前传中,你可以使用

then, in the forward pass, you can use

        conv_out = self.conv(x)
        # flatten the output of conv layers
        conv_out = conv_out.view(conv_out.size(0), -1)
        x = self.fc(conv_out)

对于(32, 1, 110, 110)的输入,输出应该是torch.Size([32, 2]).

For an input of (32, 1, 110, 110), the output should be torch.Size([32, 2]).

这篇关于`*** RuntimeError: mat1 dim 1 must match mat2 dim 0` 每当我运行模型(图像)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆