How to properly serve an object detection model from Tensorflow Object Detection API?

Question

I am using the Tensorflow Object Detection API (github.com/tensorflow/models/tree/master/object_detection) for an object detection task. Right now I am having trouble serving the detection model I trained with Tensorflow Serving (tensorflow.github.io/serving/).

1. The first issue I am encountering is exporting the model to servable files. The object detection API kindly includes an export script, so I am able to convert ckpt files to pb files with variables. However, the output has no content in the 'variables' folder. I thought this was a bug and reported it on GitHub, but it seems they intended to convert the variables to constants so that there would be no variables. The details can be found HERE.

The flags I used when exporting the saved model were as follows:

    CUDA_VISIBLE_DEVICES=0 python export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path configs/rfcn_resnet50_car_Jul_20.config \
        --checkpoint_path resnet_ckpt/model.ckpt-17586 \
        --inference_graph_path serving_model/1 \
        --export_as_saved_model True

It runs perfectly fine in Python when I switch --export_as_saved_model to False.
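For reference, this is roughly what that Python check looks like. It is a minimal sketch, assuming the exporter wrote a single frozen GraphDef protobuf at the path given by --inference_graph_path; the exact output layout may differ between exporter versions.

import tensorflow as tf

# Minimal sketch, assuming 'serving_model/1' is the frozen GraphDef file
# written by export_inference_graph.py with --export_as_saved_model False.
graph = tf.Graph()
with graph.as_default():
    graph_def = tf.GraphDef()
    with tf.gfile.GFile('serving_model/1', 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

# All weights are baked into the graph as constants, so no checkpoint
# restore is needed before running inference.
print(len(graph.as_graph_def().node), 'nodes imported')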

But I am still having issues serving the model.

When I try to run:

~/serving$ bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server --port=9000 --model_name=gan --model_base_path=<my_model_path>

I get:

2017-07-27 16:11:53.222439: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:155] Restoring SavedModel bundle.
2017-07-27 16:11:53.222497: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:165] The specified SavedModel has no variables; no checkpoints were restored.
2017-07-27 16:11:53.222502: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running LegacyInitOp on SavedModel bundle.
2017-07-27 16:11:53.229463: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:284] Loading SavedModel: success. Took 281805 microseconds.
2017-07-27 16:11:53.229508: I tensorflow_serving/core/loader_harness.cc:86] Successfully loaded servable version {name: gan version: 1}
2017-07-27 16:11:53.244716: I tensorflow_serving/model_servers/main.cc:290] Running ModelServer at 0.0.0.0:9000 ...

I think the model was not properly loaded, since it shows "The specified SavedModel has no variables; no checkpoints were restored."

But since we have converted all variables into constants, this seems reasonable. I am not sure here.
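One way to confirm what the model server actually sees is to load the SavedModel back in Python and print its signatures. A minimal sketch, assuming the export directory 'serving_model/1' from the flags above:

import tensorflow as tf

# Minimal sketch: load the SavedModel and list its signature defs,
# assuming 'serving_model/1' is the export directory used above.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], 'serving_model/1')
    for key, signature in meta_graph_def.signature_def.items():
        print('signature:', key)
        print('  inputs: ', list(signature.inputs.keys()))
        print('  outputs:', list(signature.outputs.keys()))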

2. I was not able to use the client to call the server and run detection on a sample image.

The client script is as follows:

from __future__ import print_function
from __future__ import absolute_import

# Communication to TensorFlow server via gRPC
from grpc.beta import implementations
import tensorflow as tf
import numpy as np
from PIL import Image
# TensorFlow serving stuff to send messages
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2


# Command line arguments
tf.app.flags.DEFINE_string('server', 'localhost:9000',
                           'PredictionService host:port')
tf.app.flags.DEFINE_string('image', '', 'path to image in JPEG format')
FLAGS = tf.app.flags.FLAGS


def load_image_into_numpy_array(image):
    # Convert a PIL image into a (height, width, 3) uint8 numpy array.
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)


def main(_):
    host, port = FLAGS.server.split(':')
    channel = implementations.insecure_channel(host, int(port))
    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

    # Build the request: a batch of one image under the 'inputs' key.
    request = predict_pb2.PredictRequest()
    image = Image.open(FLAGS.image)
    image_np = load_image_into_numpy_array(image)
    image_np_expanded = np.expand_dims(image_np, axis=0)
    # Call GAN model to make prediction on the image
    request.model_spec.name = 'gan'
    request.model_spec.signature_name = 'predict_images'
    request.inputs['inputs'].CopyFrom(
        tf.contrib.util.make_tensor_proto(image_np_expanded))

    result = stub.Predict(request, 60.0)  # 60 secs timeout
    print(result)


if __name__ == '__main__':
    tf.app.run()

To match request.model_spec.signature_name = 'predict_images', I modified the exporter.py script in the object detection API (github.com/tensorflow/models/blob/master/object_detection/exporter.py), starting at line 289, from:

signature_def_map={
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        detection_signature,
},

to:

signature_def_map={
    'predict_images': detection_signature,
    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        detection_signature,
},

since I had no idea how to call the default signature key.
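(For what it's worth, signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY is just the literal string 'serving_default', so the client can address the unmodified exporter's default signature directly, without patching exporter.py:)

from tensorflow.python.saved_model import signature_constants
from tensorflow_serving.apis import predict_pb2

# DEFAULT_SERVING_SIGNATURE_DEF_KEY is the string 'serving_default'.
request = predict_pb2.PredictRequest()
request.model_spec.name = 'gan'
request.model_spec.signature_name = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
assert request.model_spec.signature_name == 'serving_default'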

When I run the following command:

bazel-bin/tensorflow_serving/example/client --server=localhost:9000 --image=<my_image_file>

I get the following error message:

Traceback (most recent call last):
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 54, in <module>
    tf.app.run()
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/xinyao/serving/bazel-bin/tensorflow_serving/example/client.runfiles/tf_serving/tensorflow_serving/example/client.py", line 49, in main
    result = stub.Predict(request, 60.0)  # 60 secs timeout
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 324, in __call__
    self._request_serializer, self._response_deserializer)
  File "/usr/local/lib/python2.7/dist-packages/grpc/beta/_client_adaptations.py", line 210, in _blocking_unary_unary
    raise _abortion_error(rpc_error_call)
grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.NOT_FOUND, details="FeedInputs: unable to find feed output ToFloat:0")

I'm not quite sure what's going on here.

Initially I thought maybe my client script was not correct, but then I found that the AbortionError comes from github.com/tensorflow/tensorflow/blob/f488419cd6d9256b25ba25cbe736097dfeee79f9/tensorflow/core/graph/subgraph.cc. It seems I got this error while the graph was being built, so it might be caused by the first issue I have.
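One way to narrow this down is to check whether the ToFloat node the server is trying to feed actually exists in the exported graph. A minimal sketch, reusing the assumed 'serving_model/1' export directory:

import tensorflow as tf

# Minimal sketch: search the exported graph for the 'ToFloat' node that
# the server fails to feed. Assumes the SavedModel is in 'serving_model/1'.
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], 'serving_model/1')
    names = [n.name for n in sess.graph.as_graph_def().node]
    print([name for name in names if 'ToFloat' in name])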

I am new to this, so I am really confused. I think I might have gone wrong from the start. Is there any way to properly export and serve the detection model? Any suggestions would be a great help!

Answer

The current exporter code doesn't populate the signature field properly, so serving with the model server doesn't work. Apologies for that. A new version that better supports exporting the model is coming. It includes some important fixes and improvements needed for serving, especially serving on Cloud ML Engine. See the GitHub issue if you want to try an early version of it.

As for the "The specified SavedModel has no variables; no checkpoints were restored." message, it is expected for exactly the reason you said: all variables are converted into constants in the graph. For the "FeedInputs: unable to find feed output ToFloat:0" error, make sure you use TF 1.2 when building the model server.
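A quick way to check which TensorFlow version the Python side is on (the serving binary itself must be rebuilt against TF 1.2; the TF pin in your serving checkout is what matters for the server):

import tensorflow as tf

# Check the TensorFlow version used for exporting; the model server
# build should be pinned to a matching TF 1.2 release.
print(tf.__version__)  # e.g. '1.2.1'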
