您如何使用动态“zero_state"创建一个 dynamic_rnn?(推理失败) [英] How do you create a dynamic_rnn with dynamic "zero_state" (Fails with Inference)

查看:30
本文介绍了您如何使用动态“zero_state"创建一个 dynamic_rnn?(推理失败)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在使用dynamic_rnn"来创建模型.

I have been working with the "dynamic_rnn" to create a model.

该模型基于一个 80 时间周期的信号,我想在每次运行之前将initial_state"归零,因此我设置了以下代码片段来完成此操作:

The model is based upon a 80 time period signal, and I want to zero the "initial_state" before each run so I have setup the following code fragment to accomplish this:

state = cell_L1.zero_state(self.BatchSize,Xinputs.dtype)
outputs, outState = rnn.dynamic_rnn(cell_L1,Xinputs,initial_state=state,  dtype=tf.float32)

这对训练过程非常有用.问题是,一旦我进行推理,其中我的 BatchSize = 1,我就会收到错误消息,因为 rnn状态"与新的 Xinputs 形状不匹配.所以我想的是我需要根据输入批量大小而不是硬编码来制作self.BatchSize".我尝试了许多不同的方法,但都没有奏效.我宁愿不通过 feed_dict 传递一堆零,因为它是基于批量大小的常量.

This works great for the training process. The problem is once I go to the inference, where my BatchSize = 1, I get an error back as the rnn "state" doesn't match the new Xinputs shape. So what I figured is I need to make "self.BatchSize" based upon the input batch size rather than hard code it. I tried many different approaches, and none of them have worked. I would rather not pass a bunch of zeros through the feed_dict as it is a constant based upon the batch size.

这是我的一些尝试.它们通常都会失败,因为在构建图形时输入大小是未知的:

Here are some of my attempts. They all generally fail since the input size is unknown upon building the graph:

state = cell_L1.zero_state(Xinputs.get_shape()[0],Xinputs.dtype)

.....

state = tf.zeros([Xinputs.get_shape()[0], self.state_size], Xinputs.dtype, name="RnnInitializer")

另一种方法,认为初始化器可能在运行时才被调用,但在图形构建时仍然失败:

Another approach, thinking the initializer might not get called until run-time, but still failed at graph build:

init = lambda shape, dtype: np.zeros(*shape)
state = tf.get_variable("state", shape=[Xinputs.get_shape()[0], self.state_size],initializer=init)

有没有办法动态创建这个恒定的初始状态,或者我是否需要使用张量服务代码通过 feed_dict 重置它?有没有一种聪明的方法可以在图表中使用 tf.Variable.assign 只执行一次?

Is there a way to get this constant initial state to be created dynamically or do I need to reset it through the feed_dict with tensor-serving code? Is there a clever way to do this only once within the graph maybe with an tf.Variable.assign?

推荐答案

问题的解决方案是如何获取batch_size"使得变量不是硬编码的.

The solution to the problem was how to obtain the "batch_size" such that the variable is not hard coded.

这是给定示例中的正确方法:

This was the correct approach from the given example:

Xinputs = tf.placeholder(tf.int32, (None, self.sequence_size, self.num_params), name="input")
state = cell_L1.zero_state(Xinputs.get_shape()[0],Xinputs.dtype)

问题在于get_shape()[0]"的使用,它返回张量的形状"并在[0]处获取batch_size值.文档似乎不太清楚,但这似乎是一个常量值,因此当您将图形加载到推理中时,该值仍然是硬编码的(可能仅在创建图形时进行评估?).

The problem is the use of "get_shape()[0]", this returns the "shape" of the tensor and takes the batch_size value at [0]. The documentation doesn't seem to be that clear, but this appears to be a constant value so when you load the graph into an inference, this value is still hard coded (maybe only evaluated at graph creation?).

使用tf.shape()"函数似乎可以解决问题.这不会返回形状,而是一个张量.所以这似乎在运行时更新得更多.使用此代码片段解决了训练批次为 128 的问题,然后将图加载到 TensorFlow-Service 推理中,处理批次仅为 1.

Using the "tf.shape()" function, seems to do the trick. This doesn't return the shape, but a tensor. So this seems to be updated more at run-time. Using this code fragment solved the problem of a training batch of 128 and then loading the graph into TensorFlow-Service inference handling a batch of just 1.

Xinputs = tf.placeholder(tf.int32, (None, self.sequence_size, self.num_params), name="input")
batch_size = tf.shape(Xinputs)[0]
state = self.cell_L1.zero_state(batch_size,Xinputs.dtype)

这是 TensorFlow 常见问题解答的一个很好的链接,它描述了这种方法如何构建适用于可变批量大小的图表?":https://www.tensorflow.org/resources/faq

Here is a good link to TensorFlow FAQ which describes this approach 'How do I build a graph that works with variable batch sizes?': https://www.tensorflow.org/resources/faq

这篇关于您如何使用动态“zero_state"创建一个 dynamic_rnn?(推理失败)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆