将tensorflow模型保存到文件 [英] Save tensorflow model to file

查看:455
本文介绍了将tensorflow模型保存到文件的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我创建了一个tensorflow模型,我想将其保存到文件中,以便以后可以对其进行预测.特别是,我需要保存:

I create a tensorflow model which I would like to save to file so that I can predict against it later. In particular, I need to save the:

  • input_placeholder
    (= tf.placeholder(tf.float32, [None, iVariableLen]))
  • solution_space
    (= tf.nn.sigmoid(tf.matmul(input_placeholder, weight_variable) + bias_variable))
  • 会话
    (= tf.Session())
  • input_placeholder
    (= tf.placeholder(tf.float32, [None, iVariableLen]))
  • solution_space
    (= tf.nn.sigmoid(tf.matmul(input_placeholder, weight_variable) + bias_variable))
  • session
    (= tf.Session())

我试过使用可在其他对象(例如sklearn二进制化器等)上使用的pickle,但不能在上面的对象上使用,因此我在底部发现了错误.

I've tried using pickle which works on other objects like sklearn binarizers etc, but not on the above, for which I get the error at the bottom.

我如何腌制:

import pickle
with open(sModelSavePath, 'w') as fiModel:
    pickle.dump(dModel, fiModel)

其中dModel是一个字典,其中包含我要保留的所有对象,这些对象用于调整.

where dModel is a dictionary that contains all the objects I want to persist, which I use for fitting against.

关于如何腌制张量流对象的任何建议吗?

Any suggestions on how to pickle tensorflow objects?

错误消息:

pickle.dump(dModel, fiModel)
...
    raise TypeError, "can't pickle %s objects" % base.__name__
TypeError: can't pickle module objects

推荐答案

我解决此问题的方法是 pickleing Sklearn对象(例如二进制化器),并使用

The way I solved this was by pickleing Sklearn objects like binarizers, and using tensorflow's inbuilt save functions for the actual model:

保存张量流模型:
1)像往常一样建立模型
2)用tf.train.Saver()保存会话.例如:

Saving tensorflow model:
1) Build the model as you usually would
2) Save the session with tf.train.Saver(). For example:

oSaver = tf.train.Saver()

oSess = oSession
oSaver.save(oSess, sModelPath)  #filename ends with .ckpt

3)这会将该会话中的所有可用变量等保存到其变量名.

3) This saves all available variables etc in that session to their variable names.

加载张量流模型:
1)整个流程需要重新初始化.换句话说,需要声明变量,权重,偏差,损失函数等,然后通过将tf.initialize_all_variables()传递给oSession.run()
进行初始化 2)现在需要将该会话传递给加载程序.我对流进行了抽象,因此我的加载程序如下所示:

Loading tensorflow model:
1) The entire flow needs to be re-initialized. In other words, variables, weights, bias, loss function etc need to be declared, and then initialized with tf.initialize_all_variables() being passed into oSession.run()
2) That session now needs to be passed to the loader. I abstracted the flow, so my loader looks like this:

dAlg = tf_training_algorithm()  #defines variables etc and initializes session

oSaver = tf.train.Saver()
oSaver.restore(dAlg['oSess'], sModelPath)

return {
    'oSess': dAlg['oSess'],
    #the other stuff I need from my algorithm, like my solution space etc
}

3)您需要从预测中删除所有需要进行预测的对象,在我的情况下,这些对象位于dAlg

3) All objects you need for prediction need to be gotten out of your initialisation, which in my case sit in dAlg

PS:像这样的泡菜:

PS: Pickle like this:

with open(sSavePathFilename, 'w') as fiModel:
    pickle.dump(dModel, fiModel)

with open(sFilename, 'r') as fiModel:
    dModel = pickle.load(fiModel)

这篇关于将tensorflow模型保存到文件的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆