TF LSTM: Save State from training session for prediction session later


Problem description


I am trying to save the latest LSTM State from training to be reused during the prediction stage later. The problem I am encountering is that in the TF LSTM model the State is passed around from one training iteration to next via a combination of a placeholder and a numpy array -- neither of which seems to be included in the Graph by default when the session is saved.
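For context, the state plumbing looks roughly like this (a simplified sketch with illustrative names and sizes, not my actual model):

import numpy as np
import tensorflow as tf

# Hypothetical sizes, for illustration only:
BATCHSIZE, SEQLEN, CELL_SIZE, LAYERS = 4, 5, 8, 2

X = tf.placeholder(tf.float32, [None, SEQLEN, 1])
# The state enters the graph through a placeholder...
Hin = tf.placeholder(tf.float32, [None, CELL_SIZE * LAYERS], name='Hin')
cells = [tf.nn.rnn_cell.GRUCell(CELL_SIZE) for _ in range(LAYERS)]
multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
Yr, H = tf.nn.dynamic_rnn(multicell, X, initial_state=Hin)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ...and lives in a plain numpy array between iterations:
    istate = np.zeros([BATCHSIZE, CELL_SIZE * LAYERS])
    for step in range(3):
        x = np.random.rand(BATCHSIZE, SEQLEN, 1).astype(np.float32)
        ostate = sess.run(H, feed_dict={X: x, Hin: istate})
        istate = ostate  # neither Hin nor istate is part of any checkpoint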

To work around this, I am creating a dedicated TF variable to hold the latest version of the state so as to add it to the Session graph, like so:

# latest State from last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now add to TF variable:
savedState = tf.Variable(ostate, dtype=tf.float32, name='savedState')
tf.variables_initializer([savedState]).run()
save_path = saver.save(sess, pathModel + '/my_model.ckpt')

This seems to add the savedState variable to the saved session graph well, and is easily recoverable later with the rest of the Session.

The problem, though, is that the only way I have managed to actually use that variable later in the restored Session is to initialize all variables in the session AFTER I recover it (which seems to reset all trained variables, including the weights/biases/etc.!). If I initialize variables first and THEN recover the session (which works fine in terms of preserving the trained variables), then I get an error that I'm trying to access an uninitialized variable.

I know there is a way to initialize a specific individual variable (which I am using while saving it originally), but the problem is that when we recover them, we refer to them by name as strings; we don't just pass the variable itself?!

# This produces an error: 'trying to use an uninitialized variable'
gInit = tf.global_variables_initializer().run()
new_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
new_saver.restore(sess, pathModel + 'my_model.ckpt')
fullState = sess.run('savedState:0')
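For reference, the Variable object can still be looked up by name after import_meta_graph, via the global variables collection, so in principle a single variable can be initialized on its own (a sketch; note that initializing savedState this way resets it from its initializer rather than loading it from the checkpoint):

# Hypothetical lookup of one Variable object in the restored graph:
saved_state_var = [v for v in tf.global_variables()
                   if v.name == 'savedState:0'][0]
sess.run(tf.variables_initializer([saved_state_var]))  # initializes only this one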

What is the right way to get this done? As a workaround, I am currently saving the State to CSV just as a numpy array and then recovering it the same way. That works OK, but it's clearly not the cleanest solution given that every other aspect of saving/restoring the TF session works perfectly.
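For completeness, the CSV workaround amounts to something like this (the path is illustrative), given the 2-D ostate array from the training snippet above:

import numpy as np

np.savetxt('/tmp/lstm_state.csv', ostate, delimiter=',')      # end of training
fullState = np.loadtxt('/tmp/lstm_state.csv', delimiter=',')  # prediction program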

Any suggestions appreciated!

EDIT: Here's the code that works well, as described in the accepted answer below:

# make sure to define the State variable before constructing the Saver:
savedState = tf.get_variable('savedState', shape=[BATCHSIZE, CELL_SIZE * LAYERS])
saver = tf.train.Saver(max_to_keep=1)
# last training iteration:
_, y, ostate, smm = sess.run([train_step, Y, H, summaries], feed_dict=feed_dict)
# now save the State and the whole model:
assignOp = tf.assign(savedState, ostate)
sess.run(assignOp)
save_path = saver.save(sess, pathModel + '/my_model.ckpt')


# later on, in some other program, recover the model and the State:
# make sure to initialize all variables BEFORE recovering the model!
gInit = tf.global_variables_initializer().run()
local_saver = tf.train.import_meta_graph(pathModel + 'my_model.ckpt.meta')
local_saver.restore(sess, pathModel + 'my_model.ckpt')
# recover the state from training and take its last row (the last batch entry)
fullState = sess.run('savedState:0')
h = fullState[-1]
h = np.reshape(h, [1, -1])
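The reshaped h (shape [1, CELL_SIZE * LAYERS]) can then be fed back through the state placeholder for single-sequence prediction, e.g. by tensor name (assuming the training graph named its tensors X, Hin and Y; adjust to your own graph):

# Hypothetical prediction step; x_input is an input batch of one sequence,
# and the tensor names must match the training graph:
y_pred = sess.run('Y:0', feed_dict={'X:0': x_input, 'Hin:0': h})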

I haven't tested yet whether this approach unintentionally initializes any other variables in the saved Session, but I don't see why it should, since we only run the specific one.

Solution

The issue is that creating a new tf.Variable after the Saver was constructed means that the Saver has no knowledge of the new variable. It still gets saved in the metagraph, but not saved in the checkpoint:

import tensorflow as tf
with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  saver = tf.train.Saver()
  var_b = tf.get_variable("b", shape=[])
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  initializer = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run([initializer])
    saver.save(session, "/tmp/model", global_step=0)
with tf.Graph().as_default():
  new_saver = tf.train.import_meta_graph("/tmp/model-0.meta")
  print(saver._var_list) # [<tf.Variable 'a:0' shape=() dtype=float32_ref>]
  with tf.Session() as session:
    new_saver.restore(session, "/tmp/model-0") # Only var_a gets restored!

I've annotated the quick reproduction of your issue above with the variables that the Saver knows about.

Now, the solution is relatively easy. I would suggest creating the Variable before the Saver, then using tf.assign to update its value (make sure you run the op returned by tf.assign). The assigned value will be saved in checkpoints and restored just like other variables.
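Applied to the toy reproduction above, the fix looks roughly like this (a sketch; the paths and the assigned value are illustrative):

import tensorflow as tf

with tf.Graph().as_default():
  var_a = tf.get_variable("a", shape=[])
  var_b = tf.get_variable("b", shape=[])  # now created BEFORE the Saver
  saver = tf.train.Saver()                # picks up both a and b
  assign_op = tf.assign(var_b, 42.0)
  with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    session.run(assign_op)  # run the op, or the assignment never happens
    saver.save(session, "/tmp/model_fixed", global_step=0)

with tf.Graph().as_default():
  new_saver = tf.train.import_meta_graph("/tmp/model_fixed-0.meta")
  with tf.Session() as session:
    new_saver.restore(session, "/tmp/model_fixed-0")
    print(session.run("b:0"))  # 42.0 -- this time var_b was in the checkpoint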

This could be handled better by the Saver as a special case when None is passed to its var_list constructor argument (i.e. it could pick up new variables automatically). Feel free to open a feature request on GitHub for this.

