TensorFlow:如何从SavedModel进行预测? [英] TensorFlow: How to predict from a SavedModel?

查看:889
本文介绍了TensorFlow:如何从SavedModel进行预测?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经导出的现在我与加载它放回并作出预测.它被训练,具有以下特性和标签:

I have exported a SavedModel and now I with to load it back in and make a prediction. It was trained with the following features and labels:

F1 : FLOAT32
F2 : FLOAT32
F3 : FLOAT32
L1 : FLOAT32

所以说我想输入值20.9, 1.8, 0.9以获得单个FLOAT32预测.我该如何完成?我已经设法成功加载模式,但我不知道如何访问这些信息,使预测的呼叫.

So say I want to feed in the values 20.9, 1.8, 0.9 get a single FLOAT32 prediction. How do I accomplish this? I have managed to successfully load the model, but I am not sure how to access it to make the prediction call.

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    # How can I predict from here?
    # I want to do something like prediction = model.predict([20.9, 1.8, 0.9])

此问题不是此处发布的问题的重复.这个问题的重点是关于执行推断的最小示例的任何模型类的(不只是限定于)和指定的输入和输出节点名称的语法.

This question is not a duplicate of the question posted here. This question focuses on a minimal example of performing inference on a SavedModel of any model class (not just limited to tf.estimator) and the syntax of specifying input and output node names.

推荐答案

在图形被加载,这是在当前上下文中可用的并且可以通过其馈送的输入数据以获得预测.每个用例是相当不同的,但除了你的代码看起来是这样的:

Once the graph is loaded, it is available in the current context and you can feed input data through it to obtain predictions. Each use-case is rather different, but the addition to your code will look something like this:

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        "/job/export/Servo/1503723455"
    )

    prediction = sess.run(
        'prefix/predictions/Identity:0',
        feed_dict={
            'Placeholder:0': [20.9],
            'Placeholder_1:0': [1.8],
            'Placeholder_2:0': [0.9]
        }
    )

    print(prediction)

在这里,你需要知道你的预测输入将是什么名字.如果你没有给他们在教堂中殿的,然后他们默认为<6>,其中是第n个功能.

Here, you need to know the names of what your prediction inputs will be. If you did not give them a nave in your serving_fn, then they default to Placeholder_n, where n is the nth feature.

的第一个字符串参数是预测对象的名称.这将根据您的用例而有所不同.

The first string argument of sess.run is the name of the prediction target. This will vary based on your use case.

这篇关于TensorFlow:如何从SavedModel进行预测?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆