How can I make the inception-v3 model pre-trained on ImageNet (classify_image.py) in the TensorFlow tutorial importable as a module?

Problem description

I wonder how I can modify classify_image.py (from this tutorial) so that I can import it from another Python script. I would basically like it to keep the functionality it already has, but instead of providing the image path and getting the response printed in the terminal, I would like to pass the image path to a function and have that function return the top 5 results with their probabilities.

I haven't found a direct solution to this problem yet, but I realize my ability to solve it and to search for previous answers is limited, since unfortunately I haven't learned the basics of TensorFlow yet.

Of course, if there is another pre-trained TensorFlow model that is just as good and meets my requirements, I would happily use that instead.

Regards, Pontus

Update: maybe I should clarify:

I don't want to train a model, just use a pre-trained one for image recognition; in this case, I want an image recognition script that I can import as a module in another Python application.

I have also tried the code from this tutorial, but I got stuck there too; in that case it involves a lot of manual installation, and I might have failed at some step. The good thing about the classify_image.py example is that I got it to work as intended in the tutorial, so I thought the step from there to using it as a pluggable module shouldn't be that big.

What I have tried (with classify_image.py) is moving the lines beneath if __name__ == '__main__' into main(_) so that they get executed when I call it from another script, but I am still having problems. Mainly, the main(_) function wants me to pass it an argument, and from searching around I gather that _ is some kind of placeholder used when getting input from the CLI. All the FLAGS stuff seems to be CLI-related too, which is what I want to move away from. I am also unsure whether the model weights etc. get saved correctly for me to be able to use them from another script. Again, at this point I just want to play around with the image classifier, and later on hopefully learn more about the machine learning behind it. Sorry for my lack of knowledge of the basics!
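A minimal sketch of one way around the CLI, assuming classify_image.py itself stays unchanged: FLAGS is a module-level global that the functions read at call time, so another script can assign its own argparse.Namespace to classify_image.FLAGS before calling them. The _ in main(_) is simply the argv list that tf.app.run passes in and main ignores, so any value (for example None) works once FLAGS is set. Nothing needs to be saved for the weights either: they are baked into classify_image_graph_def.pb, which maybe_download_and_extract() downloads into model_dir. The helper name build_flags is hypothetical; it only mirrors the three arguments defined in the script's __main__ block.

```python
import argparse


def build_flags(argv=None, **overrides):
    """Hypothetical helper: build the FLAGS namespace classify_image.py expects.

    Mirrors the three add_argument() calls under `if __name__ == '__main__':`.
    argv=None parses an empty list, so sys.argv is never consulted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_dir', type=str, default='/tmp/imagenet')
    parser.add_argument('--image_file', type=str, default='')
    parser.add_argument('--num_top_predictions', type=int, default=5)
    flags = parser.parse_args([] if argv is None else argv)
    # Keyword overrides let callers skip string-based CLI syntax entirely.
    for name, value in overrides.items():
        if not hasattr(flags, name):
            raise AttributeError('unknown flag: %s' % name)
        setattr(flags, name, value)
    return flags
```

From another script this might look like: import classify_image, set classify_image.FLAGS = build_flags(image_file='/path/to/photo.jpg') (a placeholder path), then call classify_image.main(None). Note that this still prints the results; returning them instead requires changing run_inference_on_image itself.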

classify_image.py:

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Simple image classification with Inception.
Run image classification with Inception trained on ImageNet 2012 Challenge data
set.
This program creates a graph from a saved GraphDef protocol buffer,
and runs inference on an input JPEG image. It outputs human readable
strings of the top 5 predictions along with their probabilities.
Change the --image_file argument to any jpg image to compute a
classification of that image.
Please see the tutorial and website for a detailed description of how
to use this script to perform image recognition.
https://tensorflow.org/tutorials/image_recognition/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import re
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf

FLAGS = None

# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
# pylint: enable=line-too-long


class NodeLookup(object):
  """Converts integer node ID's to human readable labels."""

  def __init__(self,
               label_lookup_path=None,
               uid_lookup_path=None):
    if not label_lookup_path:
      label_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
    if not uid_lookup_path:
      uid_lookup_path = os.path.join(
          FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
    self.node_lookup = self.load(label_lookup_path, uid_lookup_path)

  def load(self, label_lookup_path, uid_lookup_path):
    """Loads a human readable English name for each softmax node.
    Args:
      label_lookup_path: string UID to integer node ID.
      uid_lookup_path: string UID to human-readable string.
    Returns:
      dict from integer node ID to human-readable string.
    """
    if not tf.gfile.Exists(uid_lookup_path):
      tf.logging.fatal('File does not exist %s', uid_lookup_path)
    if not tf.gfile.Exists(label_lookup_path):
      tf.logging.fatal('File does not exist %s', label_lookup_path)

    # Loads mapping from string UID to human-readable string
    proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
      parsed_items = p.findall(line)
      uid = parsed_items[0]
      human_string = parsed_items[2]
      uid_to_human[uid] = human_string

    # Loads mapping from string UID to integer node ID.
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
    for line in proto_as_ascii:
      if line.startswith('  target_class:'):
        target_class = int(line.split(': ')[1])
      if line.startswith('  target_class_string:'):
        target_class_string = line.split(': ')[1]
        node_id_to_uid[target_class] = target_class_string[1:-2]

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
      if val not in uid_to_human:
        tf.logging.fatal('Failed to locate: %s', val)
      name = uid_to_human[val]
      node_id_to_name[key] = name

    return node_id_to_name

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


def create_graph():
  """Creates a graph from saved GraphDef file and returns a saver."""
  # Creates graph from saved graph_def.pb.
  with tf.gfile.FastGFile(os.path.join(
      FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image(image):
  """Runs inference on an image.
  Args:
    image: Image file name.
  Returns:
    Nothing
  """
  if not tf.gfile.Exists(image):
    tf.logging.fatal('File does not exist %s', image)
  image_data = tf.gfile.FastGFile(image, 'rb').read()

  # Creates graph from saved GraphDef.
  create_graph()

  with tf.Session() as sess:
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup()

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))


def maybe_download_and_extract():
  """Download and extract model tar file."""
  dest_directory = FLAGS.model_dir
  if not os.path.exists(dest_directory):
    os.makedirs(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
          filename, float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def main(_):
  maybe_download_and_extract()
  image = (FLAGS.image_file if FLAGS.image_file else
           os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
  run_inference_on_image(image)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # classify_image_graph_def.pb:
  #   Binary representation of the GraphDef protocol buffer.
  # imagenet_synset_to_human_label_map.txt:
  #   Map from synset ID to a human readable string.
  # imagenet_2012_challenge_label_map_proto.pbtxt:
  #   Text representation of a protocol buffer mapping a label to synset ID.
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/imagenet',
      help="""\
      Path to classify_image_graph_def.pb,
      imagenet_synset_to_human_label_map.txt, and
      imagenet_2012_challenge_label_map_proto.pbtxt.\
      """
  )
  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )
  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
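For reference, a sketch of the kind of importable module the question describes, written against the same model files but without FLAGS. It assumes TensorFlow 1.x (the API used in the listing above), Python 3, and a model_dir that already contains what maybe_download_and_extract() unpacks. InceptionClassifier, top_k and classify are hypothetical names. One design difference from the listing: run_inference_on_image calls create_graph() on every call, importing the graph into the default graph again each time, whereas this sketch imports it once into a private tf.Graph and reuses one session.

```python
import os


class InceptionClassifier(object):
    """Hypothetical importable wrapper around the classify_image.py model files.

    Assumes TensorFlow 1.x and a model_dir that already holds the files
    maybe_download_and_extract() unpacks (classify_image_graph_def.pb etc.).
    """

    GRAPH_FILE = 'classify_image_graph_def.pb'

    def __init__(self, model_dir='/tmp/imagenet', node_lookup=None):
        graph_path = os.path.join(model_dir, self.GRAPH_FILE)
        if not os.path.exists(graph_path):
            raise FileNotFoundError(graph_path)
        import tensorflow as tf  # lazy import; TF 1.x API assumed
        self._tf = tf
        # Import the frozen GraphDef once into a private graph, instead of
        # re-importing it into the default graph on every call.
        self._graph = tf.Graph()
        with self._graph.as_default():
            graph_def = tf.GraphDef()
            with tf.gfile.GFile(graph_path, 'rb') as f:
                graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        self._sess = tf.Session(graph=self._graph)
        self._softmax = self._graph.get_tensor_by_name('softmax:0')
        # Any object with id_to_string(node_id), e.g. classify_image.NodeLookup.
        self._lookup = node_lookup

    @staticmethod
    def top_k(scores, k=5):
        """Indices of the k largest scores, highest first."""
        return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]

    def classify(self, image_path, k=5):
        """Return [(label, probability), ...] for the top k classes of a JPEG."""
        with self._tf.gfile.GFile(image_path, 'rb') as f:
            image_data = f.read()
        predictions = self._sess.run(
            self._softmax, {'DecodeJpeg/contents:0': image_data})
        scores = predictions.reshape(-1).tolist()
        results = []
        for node_id in self.top_k(scores, k):
            label = (self._lookup.id_to_string(node_id)
                     if self._lookup is not None else str(node_id))
            results.append((label, scores[node_id]))
        return results

    def close(self):
        self._sess.close()
```

Because NodeLookup only falls back to FLAGS.model_dir when its paths are omitted, it can be constructed with explicit label_lookup_path and uid_lookup_path and passed in as node_lookup; classify('/path/to/photo.jpg') then returns the top 5 pairs instead of printing them.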

Recommended answer

In the end I managed to use the code from the SO question referred to in the update to my original question. I modified it with the additional im = 2*(im/255.0)-1.0 from the answer to that SO question, a line to fix PIL on my computer, and a function to convert classes to human-readable labels (found on GitHub; link to that file below). I made it a callable function that takes a list of images as input and outputs a list of labels and prediction values. If you'd like to use it, this is what you have to do:

  1. Install the latest TensorFlow version (1.0 at the time of writing, which is required).
  2. git clone https://github.com/tensorflow/models/ wherever you want the models.
  3. Put this checkpoint file from the SO question I referred to earlier (needs to be extracted, of course) in the directory of your project.
  4. Put this text file (the human readable labels) in the directory of your project.
  5. Use this code from the SO question with some modifications from my side, put it in a .py file in your project:

import ast

import numpy as np
import tensorflow as tf
from PIL import Image

# inception_resnet_v2.py comes from the cloned tensorflow/models repository (slim nets)
from inception_resnet_v2 import inception_resnet_v2, inception_resnet_v2_arg_scope

slim = tf.contrib.slim

# The label file is a Python dict literal {class_id: 'label', ...};
# ast.literal_eval parses it without executing arbitrary code, unlike eval().
with open('imagenet1000_clsid_to_human.txt', 'r') as inf:
    imagenet_classes = ast.literal_eval(inf.read())


def get_human_readable(class_id):
    # The checkpoint predicts 1001 classes (index 0 = background), while the
    # label file has 1000 entries starting at 0, hence the shift by one.
    return imagenet_classes[class_id - 1]


checkpoint_file = './inception_resnet_v2_2016_08_30.ckpt'

# Build the network once at import time and restore the pre-trained weights.
sess = tf.Session()
arg_scope = inception_resnet_v2_arg_scope()
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
    logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)


def classify_image(sample_images):
    """Return one {"label", "predict_value"} dict (top-1) per image path."""
    classifications = []
    for image in sample_images:
        # convert('RGB') guards against grayscale/RGBA files, which would
        # otherwise fail to reshape to 3 channels.
        im = Image.open(image).convert('RGB').resize((299, 299))
        im = np.array(im).reshape(-1, 299, 299, 3)
        im = 2 * (im / 255.0) - 1.0  # scale pixels from [0, 255] to [-1, 1]
        predict_values, logit_values = sess.run(
            [end_points['Predictions'], logits], feed_dict={input_tensor: im})
        label = get_human_readable(np.argmax(predict_values))
        predict_value = np.max(predict_values)
        classifications.append({"label": label, "predict_value": predict_value})

    return classifications
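Two details of this code, sketched without NumPy or TensorFlow so they can be checked in isolation. First, im = 2*(im/255.0)-1.0 maps pixel values from [0, 255] to [-1, 1] (0 becomes -1, 255 becomes 1), the input range slim's Inception preprocessing produces. Second, classify_image() returns only the top-1 label, while the question asked for the top 5. top_k_labels is a hypothetical helper; it assumes the 1001-entry softmax row with background at index 0, which is the same reason get_human_readable() subtracts 1.

```python
import heapq


def scale_to_unit_range(pixels):
    """Same arithmetic as `im = 2*(im/255.0)-1.0`, for a flat list of 0-255 values."""
    return [2 * (p / 255.0) - 1.0 for p in pixels]


def top_k_labels(predict_values, labels, k=5):
    """Return the k most probable (label, probability) pairs, best first.

    Assumes predict_values is one flat row of the 1001-way softmax (index 0 =
    background) and labels is the 0-based 1000-class dict from
    imagenet1000_clsid_to_human.txt, hence the `- 1` shift.
    """
    best = heapq.nlargest(k, range(len(predict_values)),
                          key=predict_values.__getitem__)
    return [(labels.get(i - 1, 'background'), predict_values[i]) for i in best]
```

Inside classify_image() one could then append top_k_labels(predict_values[0], imagenet_classes) instead of the single argmax label; predict_values[0] is the first (and only) row of the batch.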