在 tensorflow 中使用保存的模型进行预测 [英] Using saved model for prediction in tensorflow

查看:80
本文介绍了在 tensorflow 中使用保存的模型进行预测的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我用这段代码来恢复我的模型,但是我不知道恢复后如何预测,我可以使用哪个功能?我是 tensorflow 的初学者,我不知道将保存哪些参数或函数.

在元模型中:

sess = tf.Session()saver = tf.train.import_meta_graph("/home/MachineLearning/model.ckpt.meta")saver.restore(sess,tf.train.latest_checkpoint('./'))print("模型恢复成功")x_predict,y_predict= load_svmlight_file('/MachineLearning/to_predict.csv')x_predict = x_valid.toarray()sess.run([] ,feed_dict ) #我不知道如何使用预测功能

结果如下:

$python predict.py模型恢复成功回溯(最近一次调用最后一次): 中的文件predict.py",第 23 行sess.run([] ,feed_dict )NameError:未定义名称feed_dict"

解决方案

大功告成.Tensorflow 只是一个数学库.您的图形是具有关联依赖项的数学运算的集合(例如图形,特别是 DAG).

当您加载图表和相关变量(权重)时,您加载了所有定义.现在您需要让 tensorflow 计算图中的某个值.它可以计算很多值,您想要的值通常被命名为 logits(神经网络输出层的典型名称).但请注意,它可以命名为任何名称(特别是如果这不是神经网络模型),您需要了解该模型.您可能还想计算一个名为 accuracy 的操作,该操作被定义为计算特定批次输入的准确性(同样取决于您的模型).

请注意,您需要向 tensorflow 提供执行这些计算所需的任何内容.通常有一个 placeholder 您可以在其中传递数据(并且在为您的标签训练一个 placeholder 期间,您不需要进行预测,因为您不会询问任何操作tensorflow 计算依赖于它).

但是您需要获取对这些不同操作(logitsaccuracy)和占位符(x 是典型名称)的引用.由于您从磁盘加载图形,因此您没有引用(请注意,加载模型的另一种方法是重新运行构建模型的代码,这样您就可以轻松访问所需的引用).

为了获得正确的参考,您可以按名称查找.以下是获取所有操作列表的方法:

Tensorflow 图中的张量名称列表

然后通过名称获取特定的 OP(操作):

如何通过名称获取 tensorflow op?

所以你会有这样的东西:

logits = tf.get_default_graph().get_operation_by_name("logits:0")x = tf.get_default_graph().get_operation_by_name("x:0")准确度 = tf.get_default_graph().get_operation_by_name("accuracy:0")

请注意,:0 是为 tensorflow 中的所有名称添加的索引,以避免重复名称.现在你有了你需要的所有引用,你可以使用 sess.run 来执行特定的计算,提供输入数据和你想要计算的 OP:

sess.run([logits,accuracy], feed_dict={x:your_input_data_in_numpy_format})

这些元素的名称在您的实现中会有所不同,我使用了最常见的名称.如果它们没有漂亮的名字,就很难识别它们,您需要查看生成图表的原始代码.事实上,如果它们没有正确命名,按名称查找它们会非常痛苦,以至于重新运行生成原始图的代码而不是导入元图可能更好.请注意,saver.restore 仅恢复实际数据,import_meta_graph 是可选部分,可以通过简单地以编程方式重新构建图形来替换.

I use this code to restore my model, but I don't know how to predict after restoring it, which function can I use? I'm a beginner in tensorflow, I have no idea to which parameters or function will be saved.

In the meta model:

sess = tf.Session()
saver = tf.train.import_meta_graph("/home/MachineLearning/model.ckpt.meta")
saver.restore(sess,tf.train.latest_checkpoint('./'))
print("Model restored with success ")
x_predict,y_predict= load_svmlight_file('/MachineLearning/to_predict.csv')
x_predict = x_valid.toarray()
sess.run([] ,feed_dict ) #i don't know how to use predict function

These are the results:

$python predict.py
Model restored with success 
Traceback (most recent call last):
  File "predict.py", line 23, in <module>
    sess.run([] ,feed_dict )
NameError: name 'feed_dict' is not defined

解决方案

You're almost there. Tensorflow is simply a math library. Your graph is a collection of math operations with the associated dependencies (e.g. a graph, DAG specifically).

When you loaded the graph and associated variables (weights) you loaded all the definitions. Now you need to ask tensorflow to compute some value in the graph. There are lots of values it could compute, the one you want is often named logits (a typical name for the output layer of a neural network). But note that it could be named anything (especially if this isn't a neural network model), you need to understand the model. You might also want to compute an operation named accuracy which is defined to compute the accuracy of a particular batch of inputs (again depends on your model).

Note that you will need to provide tensorflow with whatever it needs to perform these computations. There is generally a placeholder where you pass in your data (and during training a placeholder for your labels which you don't need for prediction because none of the operations you will ask tensorflow to compute depend on it).

But you will need to get references to these various operations (logits, and accuracy) and placeholders (x is a typical name). Since you loaded your graph from disk you don't have the references (note that an alternative way of loading the model is to re-run the code that builds the model, which gives you easy access to the references you need).

In order to get the right references you can look them up by name. Here's how you would get a list of all the operations:

List of tensor names in graph in Tensorflow

Then to get a specific OP (operation) by name:

How to get a tensorflow op by name?

So you'll have something like this:

logits = tf.get_default_graph().get_operation_by_name("logits:0")
x = tf.get_default_graph().get_operation_by_name("x:0")
accuracy = tf.get_default_graph().get_operation_by_name("accuracy:0")

Note that the :0 is an index added to all names in tensorflow to avoid duplicate names. Now you have all the references you need and you can use sess.run to perform a specific computation, providing the input data, and OPs you'd like to have computed:

sess.run([logits, accuracy], feed_dict={x:your_input_data_in_numpy_format})

The names of these elements will vary in your implementation, I've used the most common names. If they weren't given pretty names it'll be hard to identify them and you'll need to look through the original code that produced the graph. In fact if they weren't named properly looking them up by name is so painful that it's probably better to just re-run the code that produced the original graph rather than import the meta graph. Notice that saver.restore only restores the actual data, import_meta_graph is the optional piece which can be replaced by simply re-building the graph programmatically.

这篇关于在 tensorflow 中使用保存的模型进行预测的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆