Using a created TensorFlow model for predicting

Question

I'm looking at the source code from this TensorFlow article, which explains how to create a wide-and-deep learning model: https://www.tensorflow.org/versions/r1.3/tutorials/wide_and_deep

Here is the link to the python source code: https://github.com/tensorflow/tensorflow/blob/r1.3/tensorflow/examples/learn/wide_n_deep_tutorial.py

The goal is to train a model that predicts whether someone makes more or less than $50k a year, given the data in the census information.

As instructed, I'm running this command to execute:

python wide_n_deep_tutorial.py --model_type=wide_n_deep

The results I get are the following:

$ python wide_n_deep.py --model_type=wide_n_deep
Training data is downloaded to /tmp/tmp_pwqo2h8
Test data is downloaded to /tmp/tmph6jcimik
2018-01-03 05:34:12.236038: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
WARNING:tensorflow:enqueue_data was called with num_epochs and num_threads > 1. num_epochs is applied per thread, so this will produce more epochs than you probably intend. If you want to limit epochs, use one thread.
WARNING:tensorflow:enqueue_data was called with shuffle=False and num_threads > 1. This will create multiple threads, all reading the array/dataframe in order. If you want examples read in order, use one thread; if you want multiple threads, enable shuffling.
WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.
WARNING:tensorflow:Casting <dtype: 'float32'> labels to bool.
model directory = /tmp/tmp_ab6cfsf
accuracy: 0.808673
accuracy_baseline: 0.763774
auc: 0.841373
auc_precision_recall: 0.66043
average_loss: 0.418642
global_step: 2000
label/mean: 0.236226
loss: 41.8154
prediction/mean: 0.251593

The various articles that I've seen online talk about loading a .ckpt file. When I look in my model directory I see these files:

$ ls /tmp/tmp_ab6cfsf
checkpoint  eval  events.out.tfevents.1514957651.ml-1  graph.pbtxt  model.ckpt-1.data-00000-of-00001  model.ckpt-1.index  model.ckpt-1.meta  model.ckpt-2000.data-00000-of-00001  model.ckpt-2000.index  model.ckpt-2000.meta

I'm guessing the one that I would be using is model.ckpt-1.meta, is that correct?

But I'm also confused about how to use this model and feed it data. I've looked at this article on TensorFlow's website: https://www.tensorflow.org/versions/r1.3/programmers_guide/saved_model

It says "Note that Estimators automatically saves and restores variables (in the model_dir)." I'm not sure what that means in this context.

How can I generate input in the format of the census data, minus the salary, since that is what we are supposed to predict? It's not obvious to me how to combine the two TensorFlow articles so that I can use the trained model to make predictions.

Recommended answer

You can look at the official blog posts (part 1 and part 3) from the TensorFlow team, which explain well how to use an estimator.

In particular they explain how to make predictions using a custom input. This uses the built-in predict method of Estimators:

estimator = tf.estimator.Estimator(model_fn, ...)

predict_input_fn = ...  # define this using tf.data

predict_results = estimator.predict(predict_input_fn)
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))
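
On the checkpoint question: `Estimator.predict` restores the most recent checkpoint found in `model_dir` unless a `checkpoint_path` argument is passed, so none of the `.ckpt` files needs to be loaded by hand. The sketch below only illustrates which checkpoint "most recent" refers to for the files listed in the question. `latest_checkpoint_prefix` is a hypothetical helper written for this illustration, not a TensorFlow function; TensorFlow's own lookup is `tf.train.latest_checkpoint`, which reads the `checkpoint` file rather than parsing file names.

```python
import re

# A subset of the file names listed in the question's model directory.
FILES = [
    "checkpoint", "eval", "graph.pbtxt",
    "model.ckpt-1.index", "model.ckpt-1.meta",
    "model.ckpt-2000.index", "model.ckpt-2000.meta",
]


def latest_checkpoint_prefix(file_names):
    """Return the checkpoint prefix with the highest global step, or None.

    Hypothetical helper: it only parses names of the form
    model.ckpt-<step>.index and picks the largest step.
    """
    steps = set()
    for name in file_names:
        match = re.match(r"model\.ckpt-(\d+)\.index$", name)
        if match:
            steps.add(int(match.group(1)))
    return "model.ckpt-{}".format(max(steps)) if steps else None


print(latest_checkpoint_prefix(FILES))
```

For the directory in the question this points at the step-2000 checkpoint (matching `global_step: 2000` in the training output), not `model.ckpt-1.meta`.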

For your example, we can create a predict input function using an additional csv file. Let's suppose we have a csv file called "predict.csv" containing three examples (for instance, the first three data lines of "test.csv" without the labels). This would give:

predict.csv:

...skip this line...
25, Private, 226802, 11th, 7, Never-married, Machine-op-inspct, Own-child, Black, Male, 0, 0, 40, United-States
38, Private, 89814, HS-grad, 9, Married-civ-spouse, Farming-fishing, Husband, White, Male, 0, 0, 50, United-States
28, Local-gov, 336951, Assoc-acdm, 12, Married-civ-spouse, Protective-serv, Husband, White, Male, 0, 0, 40, United-States
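
One way to produce such a file from the test data is to drop the last (label) column from a few rows. This is a minimal sketch using only the standard library; `strip_labels` is a hypothetical helper name, and the sample row below reuses a row from the answer with a placeholder `LABEL` field appended, since the real label value is not shown here.

```python
import csv
import io


def strip_labels(test_csv_text, n_rows=3):
    """Keep the first line as-is, then copy up to n_rows data rows
    with their last column (assumed to be the label) removed."""
    lines = test_csv_text.splitlines()
    out = io.StringIO()
    out.write(lines[0] + "\n")  # first line is skipped by the reader anyway
    reader = csv.reader(lines[1:], skipinitialspace=True)
    writer = csv.writer(out, lineterminator="\n")
    kept = 0
    for row in reader:
        if not row:  # ignore blank lines
            continue
        writer.writerow(row[:-1])  # drop the label column
        kept += 1
        if kept == n_rows:
            break
    return out.getvalue()


# Placeholder input: one row from the answer plus a made-up LABEL field.
sample = (
    "...skip this line...\n"
    "25, Private, 226802, 11th, 7, Never-married, Machine-op-inspct, "
    "Own-child, Black, Male, 0, 0, 40, United-States, LABEL\n"
)
print(strip_labels(sample))
```

The output is written without spaces after the commas, which the reader in the answer handles the same way because it passes `skipinitialspace=True`.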

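# build_estimator, FLAGS and CSV_COLUMNS are defined in wide_n_deep_tutorial.py;
# pd and tf are that script's pandas and tensorflow imports.
estimator = build_estimator(FLAGS.model_dir, FLAGS.model_type)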

def predict_input_fn(data_file):
    """Input builder function."""
    df_data = pd.read_csv(
        tf.gfile.Open(data_file),
        names=CSV_COLUMNS[:-1],  # remove the last name "income_bracket" that corresponds to the label
        skipinitialspace=True,
        engine="python",
        skiprows=1)
    # remove NaN elements
    df_data = df_data.dropna(how="any", axis=0)
    return tf.estimator.inputs.pandas_input_fn(x=df_data, y=None, shuffle=False)

predict_file_name = "wide_n_deep/predict.csv"
predict_results = estimator.predict(input_fn=predict_input_fn(predict_file_name))
for idx, prediction in enumerate(predict_results):
    print(idx)
    for key in prediction:
        print("...{}: {}".format(key, prediction[key]))
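
Each `prediction` yielded by the loop above is a dictionary of arrays. The sketch below shows one way to turn such a dictionary into a readable answer. It assumes the dictionary has `class_ids` and `probabilities` keys, which is what the canned binary classifiers are expected to emit but should be checked against the keys your own loop prints. It also assumes class 1 means ">50K", following how the tutorial script builds its labels. `describe_prediction` is a hypothetical helper, and the example values are made up.

```python
def describe_prediction(prediction, labels=("<=50K", ">50K")):
    """Map one prediction dict to (label, probability of that label).

    Assumes prediction["class_ids"] holds a single class index and
    prediction["probabilities"] holds one probability per class.
    """
    class_id = int(prediction["class_ids"][0])
    probability = float(prediction["probabilities"][class_id])
    return labels[class_id], probability


# Made-up values, standing in for one item from estimator.predict(...).
example = {"class_ids": [0], "probabilities": [0.9, 0.1]}
label, probability = describe_prediction(example)
print("{} (p={:.2f})".format(label, probability))
```

Plain lists are used in the example so it runs without TensorFlow; indexing works the same way on the NumPy arrays that `predict` returns.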
