How to set the variables of LSTMCell as input instead of letting it create them in TensorFlow?


Question

When I create a tf.contrib.rnn.LSTMCell, it creates its own trainable kernel and bias variables during initialisation.

What the code looks like now:

cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_char,
                        state_is_tuple=True)

What I would like it to look like:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
cell_fw = tf.contrib.rnn.LSTMCell(kernel, bias, hidden_size,
                        state_is_tuple=True)

What I want to do is create those variables myself and pass them to the LSTMCell class as arguments to its __init__ when instantiating it.

Is there an easy way to do this? I looked at the class source code, but it sits inside a complex hierarchy of classes.

Recommended answer

I subclassed the LSTMCell class and changed its __init__ and build methods so that they accept given variables. If the variables are passed to __init__, build no longer creates them with add_variable and uses the given kernel and bias variables instead.

There might be cleaner ways to do it, though.

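# TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.x.
import tensorflow as tf
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"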

class MyLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, num_units,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None, name=None, var_given=False, kernel=None, bias=None):

        super(MyLSTMCell, self).__init__(num_units,
                 use_peepholes=use_peepholes, cell_clip=cell_clip,
                 initializer=initializer, num_proj=num_proj, proj_clip=proj_clip,
                 num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards,
                 forget_bias=forget_bias, state_is_tuple=state_is_tuple,
                 activation=activation, reuse=reuse, name=name)

        self.var_given = var_given
        if self.var_given:
            self._kernel = kernel
            self._bias = bias


    def build(self, inputs_shape):
        if inputs_shape[1].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)

        input_depth = inputs_shape[1].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None)
        if self.var_given:
            # self._kernel and self._bias were already set in __init__
            pass
        else:
            self._kernel = self.add_variable(
                _WEIGHTS_VARIABLE_NAME,
                shape=[input_depth + h_depth, 4 * self._num_units],
                initializer=self._initializer,
                partitioner=maybe_partitioner)
            self._bias = self.add_variable(
                _BIAS_VARIABLE_NAME,
                shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer(dtype=self.dtype))
        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/%s" % _WEIGHTS_VARIABLE_NAME,
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True

So the code will look like this:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
lstm_fw = MyLSTMCell(....., var_given=True, kernel=kernel, bias=bias)
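
Note that when var_given is True, build skips variable creation, so nothing checks that the supplied kernel and bias have the shapes the cell expects. Going by the build method above, the kernel should be [input_depth + h_depth, 4 * num_units] and the bias [4 * num_units], where h_depth is num_proj if a projection is used and num_units otherwise. The following is only a rough sketch of how those shapes could be computed and checked before calling get_variable; the helper names expected_lstm_shapes and check_given_shapes are hypothetical and not part of TensorFlow.

```python
def expected_lstm_shapes(input_depth, num_units, num_proj=None):
    """Return (kernel_shape, bias_shape) that build() would otherwise create.

    Hypothetical helper: mirrors the shape arithmetic in MyLSTMCell.build.
    """
    # The recurrent state fed back into the cell has num_proj columns
    # when a projection layer is used, otherwise num_units columns.
    h_depth = num_units if num_proj is None else num_proj
    kernel_shape = [input_depth + h_depth, 4 * num_units]
    bias_shape = [4 * num_units]
    return kernel_shape, bias_shape


def check_given_shapes(kernel_shape, bias_shape, input_depth, num_units,
                       num_proj=None):
    """Raise ValueError if the supplied shapes differ from the expected ones."""
    want_kernel, want_bias = expected_lstm_shapes(
        input_depth, num_units, num_proj)
    if list(kernel_shape) != want_kernel:
        raise ValueError("kernel shape %s, expected %s"
                         % (list(kernel_shape), want_kernel))
    if list(bias_shape) != want_bias:
        raise ValueError("bias shape %s, expected %s"
                         % (list(bias_shape), want_bias))


# Example: 100-dimensional inputs and 64 hidden units, no projection.
kernel_shape, bias_shape = expected_lstm_shapes(input_depth=100, num_units=64)
print(kernel_shape, bias_shape)  # [164, 256] [256]
```

The returned lists could then be passed as the shape argument of the two get_variable calls, with input_depth taken from the last dimension of the inputs fed to the cell.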
