TensorFlow LSTM/GRU重置状态每个纪元一次,而不是每个新批次 [英] Tensorflow LSTM/GRU reset states once per epoch and not for each new batch

查看:0
本文介绍了TensorFlow LSTM/GRU重置状态每个纪元一次,而不是每个新批次的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我基于GRU训练以下模型,请注意,我将参数stateful=True传递给GRU构建器。

class LearningToSurpriseModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__(self)
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True  
                                   )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

  @tf.function
  def train_step(self, inputs):
    [defining here my training step]

我实例化我的模型

model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
    )

[编译并做事情]

EPOCHS纪元

for i in range(EPOCHS):
  model.fit(train_dataset, validation_data=validation_dataset, epochs=1, callbacks = [EarlyS], verbose=1)
  model.reset_states()
此代码关于GRU状态的行为是什么:状态是针对每个新的数据批次进行更新,还是仅针对每个新时期进行更新?所需的行为是仅对每个新纪元进行重置。如果没有完成,如何实现?

编辑

TensorFlow为Models实现reset_states函数

  def reset_states(self):
    for layer in self.layers:
      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
        layer.reset_states()
是否意味着(与文档的其他含义相反)只有在stateful=False的情况下才能重置状态?这是我从getattr(layer, 'stateful', False)上的条件推断的。

推荐答案

您可以尝试重置自定义Callback中的状态:

model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
    )

gru_layer = model.layers[1]

class CustomCallback(tf.keras.callbacks.Callback):
   def __init__(self, gru_layer):
        self.gru_layer = gru_layer
   def on_epoch_end(self, epoch, logs=None):
        self.gru_layer.reset_states()

model.fit(train_dataset, validation_data=validation_dataset, epochs=1, callbacks = [EarlyS, CustomCallback(gru_layer)], verbose=1)

另请参阅post有关何时重置状态的信息。

这篇关于TensorFlow LSTM/GRU重置状态每个纪元一次,而不是每个新批次的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆