将Tensorflow模型导入Java [英] Tensorflow model import to Java

查看:91
本文介绍了将Tensorflow模型导入Java的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在尝试导入和使用Java中训练有素的模型(Tensorflow,Python).

I have been trying to import and make use of my trained model (Tensorflow, Python) in Java.

我能够用Python保存模型,但是当我尝试使用Java中的相同模型进行预测时遇到了问题.

I was able to save the model in Python, but encountered problems when I try to make predictions using the same model in Java.

此处,您会看到用于初始化,训练和保存模型的python代码.

Here, you can see the python code for initializing, training, saving the model.

此处,您可以看到用于导入和预测输入值的Java代码.

Here, you can see the Java code for importing and making predictions for input values.

我收到的错误消息是:

Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
     [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:285)
    at org.tensorflow.Session$Runner.run(Session.java:235)
    at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

我相信,问题出在python代码中,但是我找不到它.

I believe, the problem is somewhere in the python code, but I was not able to find it.

推荐答案

Java importGraphDef()函数仅导入计算图(由 tf.train.write_graph 在您的Python代码中),它没有加载(存储在检查点中的)训练后的变量的值,这就是为什么您在抱怨未初始化的变量时会出错的原因.

The Java importGraphDef() function is only importing the computational graph (written by tf.train.write_graph in your Python code), it isn't loading the values of trained variables (stored in the checkpoint), which is why you get an error complaining about uninitialized variables.

另一方面, TensorFlow SavedModel格式包括所有内容有关模型的信息(图形,检查点状态,其他元数据)以及要在Java中使用,请使用

The TensorFlow SavedModel format on the other hand includes all information about a model (graph, checkpoint state, other metadata) and to use in Java you'd want to use SavedModelBundle.load to create session initialized with the trained variable values.

要从Python以这种格式导出模型,您可能需要看一个相关的问题

To export a model in this format from Python, you might want to take a look at a related question Deploy retrained inception SavedModel to google cloud ml engine

在您的情况下,这相当于Python中的以下内容:

In your case, this should amount to something like the following in Python:

def save_model(session, input_tensor, output_tensor):
  signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
    outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
  )
  b = saved_model_builder.SavedModelBuilder('/tmp/model')
  b.add_meta_graph_and_variables(session,
                                 [tf.saved_model.tag_constants.SERVING],
                                 signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
  b.save() 

并通过 save_model(session,x,yhat)

然后在Java中使用以下方法加载模型:

And then in Java load the model using:

try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
  // b.session().run(...)
}

希望有帮助.

这篇关于将Tensorflow模型导入Java的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆