Tensorflow 2.0 &Java API [英] Tensorflow 2.0 & Java API

查看:43
本文介绍了Tensorflow 2.0 &Java API的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

(注意,我已经解决了我的问题并将代码贴在底部)

(note, I've resolved my problem and posted the code at the bottom)

我正在使用 TensorFlow,后端处理必须在 Java 中进行.我采用了 https://developers.google.com/中的模型之一machine-learning/crash-course 并用 tf.saved_model.save(my_model,"house_price_median_income") 保存(使用 docker 容器).我复制了模型并将其加载到 Java 中(使用从源代码构建的 2.0 东西,因为我在 Windows 上).我可以加载模型并运行它:

I'm playing around with TensorFlow and the backend processing must take place in Java. I've taken one of the models from the https://developers.google.com/machine-learning/crash-course and saved it with tf.saved_model.save(my_model,"house_price_median_income") (using a docker container). I copied the model off and loaded it into Java (using the 2.0 stuff built from source because I'm on windows). I can load the model and run it:

   try (SavedModelBundle model = SavedModelBundle.load("./house_price_median_income", "serve")) {
    try (Session session = model.session()) {
        Session.Runner runner = session.runner();
        float[][] in = new float[][]{ {2.1518f} } ;

        Tensor<?> jack = Tensor.create(in);
        runner.feed("serving_default_layer1_input", jack);

        float[][] probabilities = runner.fetch("StatefulPartitionedCall").run().get(0).copyTo(new float[1][1]);

        for (int i = 0; i < probabilities.length; ++i) {
            System.out.println(String.format("-- Input #%d", i));
            for (int j = 0; j < probabilities[i].length; ++j) {
              System.out.println(String.format("Class %d - %f", i, probabilities[i][j]));
            }
          }
    }
 }

以上是硬编码到输入和输出,但我希望能够读取模型并提供一些信息,以便最终用户可以选择输入和输出等.

The above is hardcoded to an input and output but I want to be able to read the model and provide some information so the end-user can select the input and output, etc.

我可以使用 python 命令获取输入和输出:saved_model_cli show --dir ./house_price_median_income --all

I can get the inputs and outputs with the python command: saved_model_cli show --dir ./house_price_median_income --all

我想做的是通过 Java 获取输入和输出,所以我的代码不需要执行 python 脚本来获取它们.我可以通过以下方式进行操作:

What I want to do it get the inputs and outputs via Java so my code doesn't need to execute python script to get them. I can get operations via:

 Graph graph = model.graph();
    Iterator<Operation> itr = graph.operations();
    while (itr.hasNext()) {
        GraphOperation e = (GraphOperation)itr.next();
        System.out.println(e);

这将输入和输出都输出为操作"但是我怎么知道它是输入和/或输出?python 工具使用 SignatureDef 但它似乎根本没有出现在 TensorFlow 2.0 java 东西中.我是否遗漏了一些明显的东西,还是只是在 TensforFlow 2.0 Java 库中遗漏了?

And this outputs both the inputs and outputs as "operations" BUT how do I know that it is an input and\or an output? The python tool uses the SignatureDef but that doesn't seem to appear in the TensorFlow 2.0 java stuff at all. Am I missing something obvious or is it just missing from TensforFlow 2.0 Java library?

注意,我已经使用下面的答案帮助对我的问题进行了排序.这是我的完整代码,以防将来有人会喜欢它.请注意,这是 TF 2.0 并使用下面提到的 SNAPSHOT.我做了一些假设,但它展示了如何提取输入和输出,然后使用它们来运行模型

NOTE, I've sorted my issue with the answer help below. Here is my full bit of code in case somebody would like it in the future. Note this is TF 2.0 and uses the SNAPSHOT mentioned below. I make a few assumptions but it shows how to pull the input and output and then use them to run a model

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.Session.Run;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.GraphOperation;
import org.tensorflow.proto.framework.SignatureDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.proto.framework.MetaGraphDef;
import java.util.Map;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.tools.Shape;
import java.nio.FloatBuffer;
import org.tensorflow.tools.buffer.DataBuffers;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.StdArrays;
import org.tensorflow.proto.framework.TensorInfo;

public class v2tensor {
    public static void main(String[] args) {
     try (SavedModelBundle savedModel = SavedModelBundle.load("./house_price_median_income", "serve")) {
        SignatureDef modelInfo = savedModel.metaGraphDef().getSignatureDefMap().get("serving_default");
        TensorInfo input1 = null;
        TensorInfo output1 = null;
        Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
        for(Map.Entry<String, TensorInfo> input : inputs.entrySet()) {
            if (input1 == null) {
                input1 = input.getValue();
                System.out.println(input1.getName());
            }
            System.out.println(input);
        }
        Map<String, TensorInfo> outputs = modelInfo.getOutputsMap();
        for(Map.Entry<String, TensorInfo> output : outputs.entrySet()) {
            if (output1 == null) {
                output1=output.getValue();
            }
            System.out.println(output);
        }

        try (Session session = savedModel.session()) {
            Session.Runner runner = session.runner();
            FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{ { 2.1518f } } );

            try (Tensor<TFloat32> jack = TFloat32.tensorOf(matrix) ) {
                runner.feed(input1.getName(), jack);
                try ( Tensor<TFloat32> rezz = runner.fetch(output1.getName()).run().get(0).expect(TFloat32.DTYPE) ) { 
                    TFloat32 data = rezz.data();
                    data.scalars().forEachIndexed((i, s) -> {
                        System.out.println(s.getFloat());
                    }   );
                }
            }
        }
    } catch (TensorFlowException ex) {
        ex.printStackTrace();   
    }
    }
}

推荐答案

您需要做的是将 SavedModelBundle 元数据读取为 MetaGraphDef,从那里您可以从 SignatureDef 中检索输入和输出名称,就像在 Python 中一样.

What you need to do is to read the SavedModelBundle metadata as a MetaGraphDef, from there you can retrieve input and output names from the SignatureDef, like in Python.

在 TF Java 1.*(即您在示例中使用的客户端)中,proto 定义无法从 tensorflow 工件中开箱即用,您需要添加对 org.tensorflow:proto 的依赖以及对 SavedModelBundle.metaGraphDef()MetaGraphDef proto.

In TF Java 1.* (i.e. the client you are using in your example), the proto definitions are not available out-of-the-box from the tensorflow artifact, you need to add a dependency to org.tensorflow:proto as well and deserialize the result of SavedModelBundle.metaGraphDef() into a MetaGraphDef proto.

在 TF Java 2.* 中(新客户端实际上仅作为来自此处的快照提供),protos 立即出现,因此您只需调用此行即可检索正确的 SignatureDef:

In TF Java 2.* (the new client actually only available as snapshots from here), the protos are present right away so you can simply call this line to retrieve the right SignatureDef:

savedModel.metaGraphDef().signatureDefMap.getValue("serving_default")

这篇关于Tensorflow 2.0 &amp;Java API的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆