Using a model trained with the Python API as input to the LabelImage module in the Java API?

Problem description

I have a problem with the Java TensorFlow API. I ran the training with the Python TensorFlow API, which generated the files output_graph.pb and output_labels.txt. Now, for various reasons, I want to use those files as input to the LabelImage module in the Java TensorFlow API. I thought everything would work fine, since that module expects exactly one .pb and one .txt. However, when I run the module, I get this error:

2017-04-26 10:12:56.711402: W tensorflow/core/framework/op_def_util.cc:332] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
Exception in thread "main" java.lang.IllegalArgumentException: No Operation named [input] in the Graph
at org.tensorflow.Session$Runner.operationByName(Session.java:343)
at org.tensorflow.Session$Runner.feed(Session.java:137)
at org.tensorflow.Session$Runner.feed(Session.java:126)
at it.zero11.LabelImage.executeInceptionGraph(LabelImage.java:115)
at it.zero11.LabelImage.main(LabelImage.java:68)

I would be very grateful if you could help me find the problem. I would also like to ask whether there is a way to run the training from the Java TensorFlow API, since that would make things easier.

More precisely:

In fact, I do not use self-written code, at least for the relevant steps. All I did was run the training with this module, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py, feeding it the directory that contains the images, divided into subdirectories according to their description. In particular, I think these are the lines that generate the outputs:

output_graph_def = graph_util.convert_variables_to_constants(
    sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
with gfile.FastGFile(FLAGS.output_graph, 'wb') as f:
  f.write(output_graph_def.SerializeToString())
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
  f.write('\n'.join(image_lists.keys()) + '\n')

Then I passed the outputs (some_graph.pb and some_labels.txt) as input to this Java module, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java, replacing the default inputs. The error I get is the one reported above.

Recommended answer

The model used by default in LabelImage.java is different from the model being retrained, so the names of the input and output nodes do not match. Note that TensorFlow models are graphs, and the arguments to feed() and fetch() are names of nodes in the graph. So you need to know the names that are appropriate for your model.

Looking at retrain.py, it seems to have a node that takes the raw contents of a JPEG file as input (the node DecodeJpeg/contents) and produces the per-label probabilities in the node final_result.
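
The exception in the question is thrown by Session.Runner.operationByName when the graph has no operation with the requested name: the stock LabelImage.java feeds "input" (as the error shows) and, as far as I can tell, fetches "output". A minimal sketch of checking both names before running the session, so a mismatch is reported up front. ModelNodes and missing() are hypothetical names, and the graph is abstracted as a predicate; with TensorFlow Java that predicate could be name -> g.operation(name) != null, since Graph.operation returns null for an unknown name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;

/** Hypothetical helper: the feed/fetch node names a graph is expected to expose. */
final class ModelNodes {
  final String input;
  final String output;

  ModelNodes(String input, String output) {
    this.input = input;
    this.output = output;
  }

  // Names used by the stock LabelImage.java (Inception graph).
  static final ModelNodes STOCK_INCEPTION = new ModelNodes("input", "output");
  // Names the retrained graph appears to use, according to the answer.
  static final ModelNodes RETRAINED = new ModelNodes("DecodeJpeg/contents", "final_result");

  /** Returns the node names the graph lacks; an empty list means both names resolve. */
  List<String> missing(Predicate<String> graphHasOperation) {
    List<String> absent = new ArrayList<>();
    if (!graphHasOperation.test(input)) {
      absent.add(input);
    }
    if (!graphHasOperation.test(output)) {
      absent.add(output);
    }
    return absent;
  }

  public static void main(String[] args) {
    // Stand-in for a retrained graph. With TensorFlow Java the predicate could be
    // name -> g.operation(name) != null instead of this set lookup.
    Set<String> ops = Set.of("DecodeJpeg/contents", "final_result");
    System.out.println(STOCK_INCEPTION.missing(ops::contains)); // [input, output]
    System.out.println(RETRAINED.missing(ops::contains));       // []
  }
}
```

Keeping the two names together, instead of as string literals inside feed() and fetch(), makes it easier to switch between the stock Inception graph and a retrained one.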

If those are indeed the node names, then in Java you would do something like the following. You also don't need the part that constructs a graph to normalize the image, since that appears to be part of the retrained model. So replace LabelImage.java:64 with something like:

try (Tensor image = Tensor.create(imageBytes);
     Graph g = new Graph()) {
  g.importGraphDef(graphDef);
  try (Session s = new Session(g);
    // Note the change to the name of the node and the fact
    // that it is being provided the raw imageBytes as input
    Tensor result = s.runner().feed("DecodeJpeg/contents", image).fetch("final_result").run().get(0)) {
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
      throw new RuntimeException(
          String.format(
              "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
              Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    float[] probabilities = result.copyTo(new float[1][nlabels])[0];
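    // At this point nlabels = number of classes in your retrained model.
    // Example use: find the most probable class. Its index should correspond to
    // the line at the same position in the labels file written by retrain.py.
    int best = 0;
    for (int i = 1; i < nlabels; ++i) {
      if (probabilities[i] > probabilities[best]) {
        best = i;
      }
    }
    System.out.println("Most probable class index: " + best + " (p = " + probabilities[best] + ")");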
  }
}
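
The probabilities returned by final_result are indexed by class, so reporting class names also requires the labels file. A minimal sketch, assuming output_labels.txt lists one label per line in the same order as the columns of final_result (retrain.py writes it from image_lists.keys(), as in the excerpt in the question). Because retrain.py ends the file with a newline, Files.readAllLines returns exactly one entry per class, with no trailing empty string. TopLabels and topK are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;

/** Hypothetical helper: map final_result probabilities to the k most likely labels. */
final class TopLabels {
  static List<String> topK(float[] probabilities, List<String> labels, int k) {
    // A mismatch usually means the labels file does not belong to this graph.
    if (labels.size() != probabilities.length) {
      throw new IllegalArgumentException(String.format(
          "labels file has %d entries but the model produced %d probabilities",
          labels.size(), probabilities.length));
    }
    // Sort class indices by descending probability.
    Integer[] order = new Integer[probabilities.length];
    for (int i = 0; i < order.length; i++) {
      order[i] = i;
    }
    Arrays.sort(order, Comparator.comparingDouble((Integer i) -> probabilities[i]).reversed());

    List<String> top = new ArrayList<>();
    for (int j = 0; j < Math.min(k, order.length); j++) {
      int idx = order[j];
      top.add(String.format(Locale.ROOT, "%s (%.2f)", labels.get(idx), probabilities[idx]));
    }
    return top;
  }

  public static void main(String[] args) {
    // Made-up values; in LabelImage these would be the probabilities array and
    // Files.readAllLines(Paths.get("output_labels.txt")).
    float[] probs = {0.1f, 0.7f, 0.2f};
    List<String> labels = List.of("cats", "dogs", "birds");
    System.out.println(topK(probs, labels, 2)); // [dogs (0.70), birds (0.20)]
  }
}
```

The size check is a cheap guard against pairing a .pb with a labels file from a different training run.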

Hope that helps.
