了解 PyTorch LSTM 的输入形状 [英] Understanding input shape to PyTorch LSTM

查看:27
本文介绍了了解 PyTorch LSTM 的输入形状的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

这似乎是 PyTorch 中 LSTM 最常见的问题之一,但我仍然无法弄清楚 PyTorch LSTM 的输入形状应该是什么.

This seems to be one of the most common questions about LSTMs in PyTorch, but I am still unable to figure out what should be the input shape to PyTorch LSTM.

即使关注了几个帖子(123) 并尝试解决方案,它似乎不起作用.

Even after following several posts (1, 2, 3) and trying out the solutions, it doesn't seem to work.

背景:我已经编码了一批大小为 12 的文本序列(可变长度),并且使用 pad_packed_sequence 功能填充和打包序列.每个序列的 MAX_LEN 为 384,序列中的每个标记(或单词)的维度为 768.因此,我的批处理张量可能具有以下形状之一:[12, 384, 768][384, 12, 768].

Background: I have encoded text sequences (variable length) in a batch of size 12 and the sequences are padded and packed using pad_packed_sequence functionality. MAX_LEN for each sequence is 384 and each token (or word) in the sequence has a dimension of 768. Hence my batch tensor could have one of the following shapes: [12, 384, 768] or [384, 12, 768].

批处理将是我对 PyTorch rnn 模块(此处为 lstm)的输入.

The batch will be my input to the PyTorch rnn module (lstm here).

根据LSTM,它的输入维度是 (seq_len, batch, input_size) 我理解如下.
seq_len - 每个输入流中的时间步数(特征向量长度).
batch - 每批输入序列的大小.
input_size - 每个输入标记或时间步长的维度.

According to the PyTorch documentation for LSTMs, its input dimensions are (seq_len, batch, input_size) which I understand as following.
seq_len - the number of time steps in each input stream (feature vector length).
batch - the size of each batch of input sequences.
input_size - the dimension for each input token or time step.

lstm = nn.LSTM(input_size=?, hidden_​​size=?, batch_first=True)

此处的 input_sizehidden_​​size 确切值应该是多少?

What should be the exact input_size and hidden_size values here?

推荐答案

您已经解释了输入的结构,但是您还没有在输入维度和 LSTM 的预期输入维度之间建立联系.

You have explained the structure of your input, but you haven't made the connection between your input dimensions and the LSTM's expected input dimensions.

让我们分解您的输入(为维度分配名称):

Let's break down your input (assigning names to the dimensions):

  • batch_size:12
  • seq_len:384
  • input_size/num_features: 768
  • batch_size: 12
  • seq_len: 384
  • input_size / num_features: 768

这意味着 LSTM 的 input_size 需要是 768.

That means the input_size of the LSTM needs to be 768.

hidden_​​size 不依赖于你的输入,而是 LSTM 应该创建多少特征,然后用于隐藏状态和输出,因为这是最后一个隐藏状态.您必须决定要为 LSTM 使用多少个特征.

The hidden_size is not dependent on your input, but rather how many features the LSTM should create, which is then used for the hidden state as well as the output, since that is the last hidden state. You have to decide how many features you want to use for the LSTM.

最后,对于输入形状,设置 batch_first=True 要求输入具有形状 [batch_size, seq_len, input_size],在您的情况下,将 <代码>[12, 384, 768].

Finally, for the input shape, setting batch_first=True requires the input to have the shape [batch_size, seq_len, input_size], in your case that would be [12, 384, 768].

import torch
import torch.nn as nn

# Size: [batch_size, seq_len, input_size]
input = torch.randn(12, 384, 768)

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(input)
output.size()  # => torch.Size([12, 384, 512])

这篇关于了解 PyTorch LSTM 的输入形状的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆