Understanding input shape to PyTorch LSTM

Problem description

This seems to be one of the most common questions about LSTMs in PyTorch, but I am still unable to figure out what should be the input shape to PyTorch LSTM.

Even after following several posts (1, 2, 3) and trying out the solutions, it doesn't seem to work.

Background: I have encoded text sequences (variable length) in a batch of size 12 and the sequences are padded and packed using pad_packed_sequence functionality. MAX_LEN for each sequence is 384 and each token (or word) in the sequence has a dimension of 768. Hence my batch tensor could have one of the following shapes: [12, 384, 768] or [384, 12, 768].
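
For context, a minimal sketch of one way such a batch could be padded and packed with torch.nn.utils.rnn (the sequence lengths and the exact padding scheme here are illustrative assumptions, not the asker's actual code):

import torch
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Hypothetical stand-in for the encoded batch: 12 variable-length
# sequences, each token a 768-dimensional vector.
seqs = [torch.randn(int(torch.randint(10, 384, (1,))), 768) for _ in range(12)]
lengths = torch.tensor([s.size(0) for s in seqs])

# Pad to the longest sequence in the batch -> [12, max_len, 768]
padded = pad_sequence(seqs, batch_first=True)

# Pack so the LSTM can skip the padded positions
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)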

The batch will be my input to the PyTorch rnn module (lstm here).

According to the PyTorch documentation for LSTMs, its input dimensions are (seq_len, batch, input_size) which I understand as following.
seq_len - the number of time steps in each input stream (feature vector length).
batch - the size of each batch of input sequences.
input_size - the dimension for each input token or time step.

lstm = nn.LSTM(input_size=?, hidden_size=?, batch_first=True)

What should be the exact input_size and hidden_size values here?

Answer

You have explained the structure of your input, but you haven't made the connection between your input dimensions and the LSTM's expected input dimensions.

Let's break down your input (assigning names to the dimensions):

  • batch_size: 12
  • seq_len: 384
  • input_size / num_features: 768

That means the input_size of the LSTM needs to be 768.

The hidden_size does not depend on your input; it determines how many features the LSTM creates, which are then used for the hidden state as well as for the output (the output at each time step is that step's hidden state, so the final output is the last hidden state). You have to decide how many features you want the LSTM to use.
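
For example, a small sketch comparing two LSTMs with different hidden_size values (256 and 512 are arbitrary illustrative choices) shows that the input does not constrain this number, it only changes the feature size of the output and the hidden state:

import torch
import torch.nn as nn

lstm_a = nn.LSTM(input_size=768, hidden_size=256, batch_first=True)
lstm_b = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

x = torch.randn(12, 384, 768)  # [batch_size, seq_len, input_size]

out_a, (h_a, _) = lstm_a(x)
out_b, (h_b, _) = lstm_b(x)

out_a.size(), h_a.size()  # => torch.Size([12, 384, 256]), torch.Size([1, 12, 256])
out_b.size(), h_b.size()  # => torch.Size([12, 384, 512]), torch.Size([1, 12, 512])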

Finally, for the input shape, setting batch_first=True requires the input to have the shape [batch_size, seq_len, input_size]; in your case that would be [12, 384, 768].

import torch
import torch.nn as nn

# Size: [batch_size, seq_len, input_size]
input = torch.randn(12, 384, 768)

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

output, _ = lstm(input)
output.size()  # => torch.Size([12, 384, 512])
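
Since the question describes a packed batch, the same LSTM also accepts a PackedSequence built with batch_first=True; a minimal sketch assuming three made-up sequence lengths:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=768, hidden_size=512, batch_first=True)

seqs = [torch.randn(n, 768) for n in (384, 200, 57)]  # variable-length sequences
lengths = torch.tensor([s.size(0) for s in seqs])

padded = pad_sequence(seqs, batch_first=True)  # [3, 384, 768]
packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

packed_out, (h_n, c_n) = lstm(packed)  # PackedSequence in, PackedSequence out
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
output.size()  # => torch.Size([3, 384, 512])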
