使用 Java API 的维度 0 的切片索引 0 越界 [英] slice index 0 of dimension 0 out of bounds using Java API

查看:18
本文介绍了使用 Java API 的维度 0 的切片索引 0 越界的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我生成了一个 SavedModel,我可以将其与以下 Python 代码一起使用

I have generated a SavedModel which I can use with the following Python code

import base64
import numpy as np
import tensorflow as tf
​
​
fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
filename='test.jpg'
with tf.Session() as sess:
    loaded = tf.saved_model.loader.load(sess, ['serve'], 'tools/base64_model/1')
    image = fn_load_image(filename)
    p = sess.run('predictions:0', feed_dict={"input:0": image})
    print(p)

这给了我期望的价值.

在同一型号上使用下面的 Java 代码时

When using the Java code below on the same model

    // load the model Bundle
    try (SavedModelBundle b = SavedModelBundle.load("tools/base64_model/1",
            "serve")) {

        // create the session from the Bundle
        Session sess = b.session();

        // base64 representation of JPG
        byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));

        String encodedString = Base64.getUrlEncoder().encodeToString(content);

        Tensor t = Tensors.create(encodedString);

        // run the model and get the classification
        final List<Tensor<?>> result = sess.runner().feed("input", 0, t).fetch("predictions", 0).run();

        // print out the result.
        System.out.println(result);
    }

这应该是等效的,即我将图像的 base64 表示发送给模型,但出现异常

which should be equivalent i.e I send the base64 representation of an image to a model, I am getting an exception

线程main"中的异常 java.lang.IllegalArgumentException: slice维度 0 的索引 0 越界.[[{{节点图/strided_slice}}]]在 org.tensorflow.Session.run(Native Method) 在org.tensorflow.Session.access$100(Session.java:48) 在org.tensorflow.Session$Runner.runHelper(Session.java:326) 在org.tensorflow.Session$Runner.run(Session.java:276) 在com.stolencamerafinder.storm.crawler.bolt.enrichments.HelloTensorFlow.main(HelloTensorFlow.java:35)

Exception in thread "main" java.lang.IllegalArgumentException: slice index 0 of dimension 0 out of bounds. [[{{node map/strided_slice}}]] at org.tensorflow.Session.run(Native Method) at org.tensorflow.Session.access$100(Session.java:48) at org.tensorflow.Session$Runner.runHelper(Session.java:326) at org.tensorflow.Session$Runner.run(Session.java:276) at com.stolencamerafinder.storm.crawler.bolt.enrichments.HelloTensorFlow.main(HelloTensorFlow.java:35)

张量应该有不同的内容吗?以下是 saved_model_cli 告诉我的关于我的模型的内容.

Should the Tensor have different content? Here is what saved_model_cli is telling me about my model.

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: predictions:0
  Method name is: tensorflow/serving/predict

推荐答案

当您提供一个等级为 0 的张量时,您的模型需要一个等级为 1 的输入张量.

You model is expecting an input tensor of rank-1 while you provide a tensor of rank-0.

这一行产生一个可变长度的标量张量(即一个 DT_STRING).

This line produces a scalar tensor of a variable length (i.e. a DT_STRING).

Tensor t = Tensors.create(encodedString);

然而,预期张量的秩为 1,正如您在此处的形状 (-1) 所见,这意味着它需要一个包含不同数量元素的向量.

However, the expected tensor is of rank-1, as you can see by the shape (-1) here, meaning that it expects a vector of a various number of elements.

The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: input:0

因此,您的问题可能会通过传递字符串数组来解决.仅当您将字符串作为字节数组传递时,才可以使用 Tensors 工厂,如下所示:

So probably your issue will be fixed by passing an array of strings. This is possible using the Tensors factories only if you keep you pass your string as an array of array of bytes, like this:

// base64 representation of JPG
byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));
byte[] encodedBytes = Base64.getUrlEncoder().encode(content);
Tensor t = Tensors.create(new byte[][]{ encodedBytes });
...

这篇关于使用 Java API 的维度 0 的切片索引 0 越界的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆