TensorFlow: What are the input nodes for tf.Estimator models


Question

I trained a Wide & Deep model using the pre-made Estimator class (DNNLinearCombinedClassifier), essentially following the tutorial on tensorflow.org.

I wanted to do inference/serving, but without using tensorflow-serving. This basically comes down to feeding some test data to the correct input tensors and retrieving the output tensors.

However, I am not sure what the input nodes/layers should be. In the TensorFlow graph (graph.pbtxt), the following nodes seem relevant. But they are also tied to the input queue, which is mainly used during training and not necessarily for inference (I can just send one instance at a time).

  name: "enqueue_input/random_shuffle_queue"
  name: "enqueue_input/Placeholder"
  name: "enqueue_input/Placeholder_1"
  name: "enqueue_input/Placeholder_2"
  ...
  name: "enqueue_input/Placeholder_84"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_1"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_2"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_3"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany_4"
  name: "enqueue_input/random_shuffle_queue_EnqueueMany"
  name: "enqueue_input/sub/y"
  name: "enqueue_input/sub"
  name: "enqueue_input/Maximum/x"
  name: "enqueue_input/Maximum"
  name: "enqueue_input/Cast"
  name: "enqueue_input/mul/y"
  name: "enqueue_input/mul"

Does anyone know the answer?

Answer

If you want to do inference without tensorflow-serving, you can just use the tf.estimator.Estimator predict method.
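For completeness, a minimal sketch of that route, reusing the estimator and my_input_fn defined below (X_new is just a hypothetical batch of new inputs):

# Sketch: inference via the Estimator API, reusing my_input_fn from step 2.
X_new = {"x": np.random.random((5, 10))}  # hypothetical new inputs
for pred in estimator.predict(input_fn=my_input_fn(X=X_new, is_training=False)):
    print(pred["class"], pred["probabilities"])  # keys defined in my_model_fn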

But if you want to do it manually (so that it runs faster), you need a workaround. I am not sure if what I did was exactly the best approach, but it worked. Here's my solution.

1) Do the imports and create the variables and fake data:

import os
import numpy as np
from functools import partial
import pickle
import tensorflow as tf

N = 10000        # number of samples
EPOCHS = 1000    # used as the number of training steps below
BATCH_SIZE = 2

X_data = np.random.random((N, 10))
y_data = (np.random.random((N, 1)) >= 0.5).astype(int)

my_dir = os.getcwd() + "/"   # model/checkpoint directory

2) Define an input_fn that uses tf.data.Dataset. Save the tensor names in a dictionary (input_tensor_map) that maps each input key to its tensor name.

def my_input_fn(X, y=None, is_training=False):

    def internal_input_fn(X, y=None, is_training=False):

        if (not isinstance(X, dict)):
            X = {"x": X}

        if (y is None):
            dataset = tf.data.Dataset.from_tensor_slices(X)
        else:
            dataset = tf.data.Dataset.from_tensor_slices((X, y))

        if (is_training):
            dataset = dataset.repeat().shuffle(100)
            batch_size = BATCH_SIZE
        else:
            batch_size = 1

        dataset = dataset.batch(batch_size)

        dataset_iter = dataset.make_initializable_iterator()

        if (y is None):
            features = dataset_iter.get_next()
            labels = None
        else:
            features, labels = dataset_iter.get_next()

        # Record the name of each input tensor so it can be looked up
        # in the frozen graph later (step 5).
        input_tensor_map = dict()
        for input_name, tensor in features.items():
            input_tensor_map[input_name] = tensor.name

        with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'wb') as f:
            pickle.dump(input_tensor_map, f, protocol=pickle.HIGHEST_PROTOCOL)

        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, dataset_iter.initializer)

        return (features, labels) if (labels is not None) else features

    return partial(internal_input_fn, X=X, y=y, is_training=is_training)
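The tf.GraphKeys.TABLE_INITIALIZERS line is the trick that makes this work inside an Estimator: everything in that collection is run when the session is created, so the dataset iterator gets initialized automatically even though it is not actually a table.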

3) Define your model, to be used in your tf.estimator.Estimator. For example:

def my_model_fn(features, labels, mode):

    output = tf.layers.dense(inputs=features["x"], units=1, activation=None)
    logits = tf.identity(output, name="logits")
    prediction = tf.nn.sigmoid(logits, name="predictions")
    classes = tf.to_int64(tf.greater(logits, 0.0), name="classes")

    predictions_dict = {
                "class": classes,
                "probabilities": prediction
                }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions_dict)

    # logits has shape [batch, 1], so the labels fed to the loss must have
    # the same shape; use the binary labels directly rather than a
    # two-class one-hot encoding (which would have shape [batch, 2]).
    loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=labels, logits=logits)

    tf.summary.scalar("loss", loss)

    accuracy = tf.reduce_mean(tf.to_float(tf.equal(labels, classes)))
    tf.summary.scalar("accuracy", accuracy)

    # Configure the Training Op (for TRAIN mode)
    if (mode == tf.estimator.ModeKeys.TRAIN):
        train_op = tf.train.AdamOptimizer().minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss)
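Note the explicit tensor names ("logits", "predictions", "classes") given via tf.identity and the name arguments: these are the node names that freeze_graph needs in step 4, and the names the inference code looks up in step 5.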

4) Train and freeze your model. The freeze method is from TensorFlow: How to freeze a model and serve it with a python API, to which I added a tiny modification.

def freeze_graph(output_node_names):
    """Extract the sub-graph defined by the output nodes and convert
    all its variables into constants.
    Args:
        output_node_names: a string containing all the output node
                           names, comma separated
    """
    if (output_node_names is None):
        output_node_names = 'loss'

    if not tf.gfile.Exists(my_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % my_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Retrieve the full path of the checkpoint
    checkpoint = tf.train.get_checkpoint_state(my_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Build the full filename of the frozen graph
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph = absolute_model_dir + "/frozen_model.pb"

    # Clear devices to let TensorFlow decide which device to load operations on
    clear_devices = True

    # Start a session using a temporary fresh graph
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the meta graph into the current default graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # Restore the weights
        saver.restore(sess, input_checkpoint)

        # Use a built-in TF helper to export variables to constants
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names are used to select the useful nodes
        )

        # Finally, serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def

# *****************************************************************************

tf.logging.set_verbosity(tf.logging.INFO)

estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir=my_dir)

if (estimator.latest_checkpoint() is None):
    estimator.train(input_fn=my_input_fn(X=X_data, y=y_data, is_training=True), steps=EPOCHS)
    freeze_graph("predictions,classes")
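Note that freezing with output nodes "predictions,classes" also keeps the logits node: convert_variables_to_constants retains every node the listed outputs depend on, and logits feeds both of them. That is why step 5 can still fetch prefix + "/logits".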

5) Finally, you can use the frozen graph for inference; the input tensor names are in the dictionary that you saved. Again, the method to load the frozen model is from TensorFlow: How to freeze a model and serve it with a python API.

def load_frozen_graph(prefix="frozen_graph"):
    frozen_graph_filename = os.path.join(my_dir, "frozen_model.pb")

    # Load the protobuf file from disk and parse it to retrieve the
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph_def into a new Graph and return it; the prefix
    # is prepended to every op name in the imported graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name=prefix)

    return graph

# *****************************************************************************

X_test = {"x": np.random.random((int(N/2), 10))}

prefix = "frozen_graph"
graph = load_frozen_graph(prefix)

# Print all op names to verify which nodes survived the freezing
for op in graph.get_operations():
    print(op.name)

with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'rb') as f:
    input_tensor_map = pickle.load(f)

with tf.Session(graph=graph) as sess:
    # Build the feed dict from the saved input-tensor-name map
    input_feed = dict()
    for key, tensor_name in input_tensor_map.items():
        tensor = graph.get_tensor_by_name(prefix + "/" + tensor_name)
        input_feed[tensor] = X_test[key]

    # Output tensors, found by the names given in my_model_fn
    logits = graph.get_operation_by_name(prefix + "/logits").outputs[0]
    probabilities = graph.get_operation_by_name(prefix + "/predictions").outputs[0]
    classes = graph.get_operation_by_name(prefix + "/classes").outputs[0]

    logits_values, probabilities_values, classes_values = sess.run([logits, probabilities, classes], feed_dict=input_feed)
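Note that the feed_dict here feeds the IteratorGetNext output tensors directly, overriding the dataset pipeline; since their batch dimension is unconstrained, all test rows go through in one run call. A trivial check, using only the variables defined above:

# Each output has one row per test instance.
print(logits_values.shape, probabilities_values.shape, classes_values.shape)
print("first instance -> class:", classes_values[0], "probability:", probabilities_values[0])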
