Initializing LSTM hidden state Tensorflow/Keras


Question

Can someone explain how I can initialize the hidden state of an LSTM in TensorFlow? I am trying to build an LSTM recurrent auto-encoder, so once that model is trained I want to transfer the learned hidden state of the unsupervised model to the hidden state of the supervised model. Is that even possible with the current API? This is the paper I am trying to recreate:

http://papers.nips.cc/paper/5949-semi-supervised-sequence-learning.pdf

Recommended answer

Yes, this is possible, but truly cumbersome. Let's go through an example.

  1. Defining a model:

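from keras.layers import LSTM, Input
from keras.models import Model

# batch_shape = (batch_size, timesteps, features); a stateful layer needs a fixed batch size.
inputs = Input(batch_shape=(32, 10, 1))
# stateful=True keeps the layer's states between batches instead of resetting them.
lstm_layer = LSTM(10, stateful=True)(inputs)

model = Model(inputs, lstm_layer)
model.compile(optimizer="adam", loss="mse")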

It's important to build and compile the model first, because the initial states are reset during compilation. Moreover, you need to specify a batch_shape that fixes the batch_size, because in this scenario the network must be stateful (which is done by setting stateful=True).
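To make the role of stateful=True more concrete, here is a minimal NumPy sketch of the idea (it is not Keras code). It assumes a plain LSTM cell; the function lstm_forward and the weight names W, U and b are hypothetical and only serve the illustration. A stateless layer starts every batch from zero states, while a stateful one starts each batch from the states left by the previous batch. Those states have shape (batch_size, units), which is why the batch size has to be fixed.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_forward(x, h, c, W, U, b):
    """Run a plain LSTM over x of shape (batch, timesteps, features).

    h and c are the initial hidden and cell states, each (batch, units).
    Returns the final (h, c).
    """
    units = h.shape[1]
    for t in range(x.shape[1]):
        z = x[:, t, :] @ W + h @ U + b           # (batch, 4 * units)
        i = sigmoid(z[:, :units])                # input gate
        f = sigmoid(z[:, units:2 * units])       # forget gate
        g = np.tanh(z[:, 2 * units:3 * units])   # candidate cell values
        o = sigmoid(z[:, 3 * units:])            # output gate
        c = f * c + i * g                        # new cell state
        h = o * np.tanh(c)                       # new hidden state
    return h, c

# Hypothetical random weights; sizes mirror the example (32, 10, 1) and 10 units.
rng = np.random.default_rng(0)
batch_size, timesteps, features, units = 32, 10, 1, 10
W = rng.normal(scale=0.5, size=(features, 4 * units))
U = rng.normal(scale=0.5, size=(units, 4 * units))
b = np.zeros(4 * units)

batch_1 = rng.normal(size=(batch_size, timesteps, features))
batch_2 = rng.normal(size=(batch_size, timesteps, features))
zeros = np.zeros((batch_size, units))

# Stateless behaviour: batch_2 starts from zero states.
h_stateless, _ = lstm_forward(batch_2, zeros, zeros, W, U, b)

# Stateful behaviour: batch_2 starts from the states left by batch_1.
h1, c1 = lstm_forward(batch_1, zeros, zeros, W, U, b)
h_stateful, c_stateful = lstm_forward(batch_2, h1, c1, W, U, b)

# Carrying the states over is the same as reading both batches as one long sequence.
joined = np.concatenate([batch_1, batch_2], axis=1)
h_joined, _ = lstm_forward(joined, zeros, zeros, W, U, b)
print(np.allclose(h_stateful, h_joined))  # True
```

Row k of the carried state belongs to sample k of the batch, so the samples of consecutive batches have to line up one to one.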

Now we can set the values of the initial states:

import numpy
import keras.backend as K

# Each state has shape (batch_size, units) = (32, 10).
hidden_states = K.variable(value=numpy.random.normal(size=(32, 10)))
cell_states = K.variable(value=numpy.random.normal(size=(32, 10)))

# layers[0] is the input layer, layers[1] is the LSTM.
model.layers[1].states[0] = hidden_states
model.layers[1].states[1] = cell_states

Note that you need to provide the states as Keras variables. states[0] holds the hidden states and states[1] holds the cell states.
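For the use case in the question (handing states from one model to another), a different mechanism may be easier to work with: the initial_state argument accepted when calling a Keras recurrent layer, combined with return_state=True on the layer that produces the states. The sketch below is only an illustration under assumptions. The encoder is an untrained stand-in for the unsupervised model, and all variable names and the Dense output layer are hypothetical. It is not the stateful approach described above.

```python
import numpy as np
from keras.layers import LSTM, Dense, Input
from keras.models import Model

timesteps, features, units = 10, 1, 10  # sizes mirror the example above

# Stand-in for the trained unsupervised encoder; it exposes its final states.
enc_inputs = Input(shape=(timesteps, features))
_, enc_h, enc_c = LSTM(units, return_state=True)(enc_inputs)
encoder = Model(enc_inputs, [enc_h, enc_c])

# Supervised model: the initial hidden and cell states are ordinary inputs.
seq_inputs = Input(shape=(timesteps, features))
h0 = Input(shape=(units,))
c0 = Input(shape=(units,))
lstm_out = LSTM(units)(seq_inputs, initial_state=[h0, c0])
predictions = Dense(1, activation="sigmoid")(lstm_out)
classifier = Model([seq_inputs, h0, c0], predictions)
classifier.compile(optimizer="adam", loss="binary_crossentropy")

# Hypothetical data: one array to produce the states, one to classify.
x_context = np.random.normal(size=(32, timesteps, features)).astype("float32")
x = np.random.normal(size=(32, timesteps, features)).astype("float32")

h_init, c_init = encoder.predict(x_context, verbose=0)  # each of shape (32, units)
y_pred = classifier.predict([x, h_init, c_init], verbose=0)
```

Because the states are passed in as regular inputs, neither stateful=True nor a fixed batch_shape is needed in this variant.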

Hope that helps.
