TensorFlow Eager Mode: How to restore a model from a checkpoint?


Question


I've trained a CNN model in TensorFlow eager mode. Now I'm trying to restore the trained model from a checkpoint file, but without any success.

All the examples I've found (like the one below) restore a checkpoint into a Session. What I need is to restore the model in eager mode, i.e. without creating a session.

import tensorflow as tf

# `saver` is a tf.train.Saver created after the model graph is built.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Basically what I need is something like:

import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')  # hypothetical: the API I'm hoping for
model.predict(...)

and then I can use the model to make predictions.

Can someone please help?

Update

The example code can be found at: mnist eager mode demo

I've tried to follow the steps from @Jay Shah's answer, and it almost worked, but the restored model doesn't have any variables in it.

tfe.save_network_checkpoint(model,'./test/my_model.ckpt')

Out[58]:
'./test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables

Out[72]:
[]

The original model has lots of variables in it:

model.variables

[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
 array([[[[ -8.25184360e-02,   6.77833706e-03,   6.97569922e-02,...
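The empty `model2.variables` list is consistent with lazy variable creation: layers like those in the eager MNIST example appear to create their weights only on the first forward pass, so a freshly constructed model has nothing for a checkpoint to be restored into. A minimal sketch, assuming TensorFlow 2.x (where eager execution is the default) and using a single `tf.keras.layers.Dense` layer as a hypothetical stand-in for `MNISTModel`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for MNISTModel: one Dense layer with 10 output units.
layer = tf.keras.layers.Dense(10)

# Before the first call the layer owns no weights, just like model2 above.
print(len(layer.variables))  # 0

# One dummy forward pass with a flattened 28x28 image creates kernel and bias.
layer(np.zeros((1, 784), dtype=np.float32))
print(len(layer.variables))  # 2
```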

Solution

OK, after spending a few hours running the code line by line, I've figured out a way to restore a checkpoint into a new TensorFlow Eager Mode model.

Using the examples from TF Eager Mode MNIST:

Steps:

  1. After your model has been trained, find the latest checkpoint (or the checkpoint you want) in the checkpoint folder created during training by looking for its index file, such as 'ckpt-25800.index'. When restoring in step 5, use only the prefix 'ckpt-25800'.

  2. Start a new Python terminal and enable TensorFlow Eager mode by running:

    import tensorflow.contrib.eager as tfe
    tfe.enable_eager_execution()

  3. Create a new instance of MNISTModel:

    model_new = MNISTModel()

  4. Initialise the variables of model_new by running a dummy training pass once. (This step is important: without initialising the variables first, the following step cannot restore them. However, I couldn't find any other way to initialise variables in Eager mode than what I did below.)

    model_new(tfe.Variable(np.zeros((1,784),dtype=np.float32)), training=True)

  5. Restore the variables to model_new using the checkpoint identified in step 1.

    tfe.Saver(model_new.variables).restore('./tf_checkpoints/ckpt-25800')

  6. If the restore succeeds, you should see something like:

    INFO:tensorflow:Restoring parameters from ./tf_checkpoints/ckpt-25800

Now the checkpoint has been successfully restored to model_new and you can use it to make predictions on new data.
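The same build-then-restore procedure can be sketched in TensorFlow 2.x, where eager execution is on by default and `tf.contrib.eager` (`tfe`) no longer exists. This is an adaptation, not the answer's original code: it swaps `tfe.Saver` for the object-based `tf.train.Checkpoint`, uses `tf.train.latest_checkpoint` for step 1, and `TinyModel` is a hypothetical stand-in for `MNISTModel` that, like it, creates its weights on the first call. A temporary folder replaces `./tf_checkpoints`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for MNISTModel: weights are created lazily."""

    def __init__(self, units=10):
        super().__init__()
        self.units = units
        self.w = None
        self.b = None

    def __call__(self, x):
        if self.w is None:  # first call: create the variables
            self.w = tf.Variable(tf.random.normal([x.shape[-1], self.units]), name="w")
            self.b = tf.Variable(tf.zeros([self.units]), name="b")
        return tf.matmul(x, self.w) + self.b


dummy = np.zeros((1, 784), dtype=np.float32)

# Training side (stand-in): build the variables and write a checkpoint.
model = TinyModel()
model(dummy)
ckpt_dir = tempfile.mkdtemp()
tf.train.Checkpoint(model=model).save(os.path.join(ckpt_dir, "ckpt"))

# Step 1: find the newest checkpoint prefix, e.g. ".../ckpt-1" (no ".index").
latest = tf.train.latest_checkpoint(ckpt_dir)

# Steps 3-4: a fresh instance plus one dummy forward pass creates its variables.
model_new = TinyModel()
model_new(dummy)

# Step 5: copy the saved values into model_new's variables.
status = tf.train.Checkpoint(model=model_new).restore(latest)
status.assert_existing_objects_matched()

# The restored model reproduces the original model's predictions.
x = np.random.rand(4, 784).astype(np.float32)
print(np.allclose(model(x).numpy(), model_new(x).numpy()))  # True
```

One design difference worth noting: `tfe.Saver` adapts `tf.train.Saver`, which matches variables by name, whereas `tf.train.Checkpoint` matches them by their attribute path from the root object (here `model.w` and `model.b`).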
