使用 Java API 的维度 0 的切片索引 0 越界 [英] slice index 0 of dimension 0 out of bounds using Java API
问题描述
我生成了一个 SavedModel,我可以将其与以下 Python 代码一起使用
I have generated a SavedModel which I can use with the following Python code
import base64
import numpy as np
import tensorflow as tf
fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
filename='test.jpg'
with tf.Session() as sess:
loaded = tf.saved_model.loader.load(sess, ['serve'], 'tools/base64_model/1')
image = fn_load_image(filename)
p = sess.run('predictions:0', feed_dict={"input:0": image})
print(p)
这给了我期望的价值.
在同一型号上使用下面的 Java 代码时
When using the Java code below on the same model
// load the model Bundle
try (SavedModelBundle b = SavedModelBundle.load("tools/base64_model/1",
"serve")) {
// create the session from the Bundle
Session sess = b.session();
// base64 representation of JPG
byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));
String encodedString = Base64.getUrlEncoder().encodeToString(content);
Tensor t = Tensors.create(encodedString);
// run the model and get the classification
final List<Tensor<?>> result = sess.runner().feed("input", 0, t).fetch("predictions", 0).run();
// print out the result.
System.out.println(result);
}
这应该是等效的,即我将图像的 base64 表示发送给模型,但出现异常
which should be equivalent i.e I send the base64 representation of an image to a model, I am getting an exception
线程main"中的异常 java.lang.IllegalArgumentException: slice维度 0 的索引 0 越界.[[{{节点图/strided_slice}}]]在 org.tensorflow.Session.run(Native Method) 在org.tensorflow.Session.access$100(Session.java:48) 在org.tensorflow.Session$Runner.runHelper(Session.java:326) 在org.tensorflow.Session$Runner.run(Session.java:276) 在com.stolencamerafinder.storm.crawler.bolt.enrichments.HelloTensorFlow.main(HelloTensorFlow.java:35)
Exception in thread "main" java.lang.IllegalArgumentException: slice index 0 of dimension 0 out of bounds. [[{{node map/strided_slice}}]] at org.tensorflow.Session.run(Native Method) at org.tensorflow.Session.access$100(Session.java:48) at org.tensorflow.Session$Runner.runHelper(Session.java:326) at org.tensorflow.Session$Runner.run(Session.java:276) at com.stolencamerafinder.storm.crawler.bolt.enrichments.HelloTensorFlow.main(HelloTensorFlow.java:35)
张量应该有不同的内容吗?以下是 saved_model_cli 告诉我的关于我的模型的内容.
Should the Tensor have different content? Here is what saved_model_cli is telling me about my model.
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: input:0
The given SavedModel SignatureDef contains the following output(s):
outputs['outputs'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 2)
name: predictions:0
Method name is: tensorflow/serving/predict
推荐答案
当您提供一个等级为 0 的张量时,您的模型需要一个等级为 1 的输入张量.
You model is expecting an input tensor of rank-1 while you provide a tensor of rank-0.
这一行产生一个可变长度的标量张量(即一个 DT_STRING
).
This line produces a scalar tensor of a variable length (i.e. a DT_STRING
).
Tensor t = Tensors.create(encodedString);
然而,预期张量的秩为 1,正如您在此处的形状 (-1)
所见,这意味着它需要一个包含不同数量元素的向量.
However, the expected tensor is of rank-1, as you can see by the shape (-1)
here, meaning that it expects a vector of a various number of elements.
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: input:0
因此,您的问题可能会通过传递字符串数组来解决.仅当您将字符串作为字节数组传递时,才可以使用 Tensors
工厂,如下所示:
So probably your issue will be fixed by passing an array of strings. This is possible using the Tensors
factories only if you keep you pass your string as an array of array of bytes, like this:
// base64 representation of JPG
byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));
byte[] encodedBytes = Base64.getUrlEncoder().encode(content);
Tensor t = Tensors.create(new byte[][]{ encodedBytes });
...
这篇关于使用 Java API 的维度 0 的切片索引 0 越界的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!