对 tf.Tensor.set_shape() 的澄清 [英] Clarification on tf.Tensor.set_shape()

查看:29
本文介绍了对 tf.Tensor.set_shape() 的澄清的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一张 478 x 717 x 3 = 1028178 像素的图像,等级为 1.我通过调用 tf.shape 和 tf.rank 来验证它.

I have an image that is 478 x 717 x 3 = 1028178 pixels, with a rank of 1. I verified it by calling tf.shape and tf.rank.

当我调用 image.set_shape([478, 717, 3]) 时,它抛出以下错误.

When I call image.set_shape([478, 717, 3]), it throws the following error.

"Shapes %s and %s must have the same rank" % (self, other)) 
ValueError: Shapes (?,) and (478, 717, 3) must have the same rank

我再次测试,首先转换为 1028178,但错误仍然存​​在.

I tested again by first casting to 1028178, but the error still exists.

ValueError: Shapes (1028178,) and (478, 717, 3) must have the same rank

嗯,这是有道理的,因为一个是 1 级,另一个是 3 级.但是,为什么要抛出错误,因为总像素数仍然匹配.

Well, that does make sense because one is of rank 1 and the other is of rank 3. However, why is it necessary to throw an error, as the total number of pixels still match.

我当然可以使用 tf.reshape 并且它有效,但我认为这不是最佳的.

I could of course use tf.reshape and it works, but I think that's not optimal.

如 TensorFlow 常见问题解答中所述

As stated on the TensorFlow FAQ

x.set_shape() 和 x = tf.reshape(x) 有什么区别?

What is the difference between x.set_shape() and x = tf.reshape(x)?

tf.Tensor.set_shape() 方法更新张量的静态形状对象,它通常用于提供额外的形状当不能直接推断时的信息.它不会改变张量的动态形状.

The tf.Tensor.set_shape() method updates the static shape of a Tensor object, and it is typically used to provide additional shape information when this cannot be inferred directly. It does not change the dynamic shape of the tensor.

tf.reshape() 操作创建一个具有不同动态形状的新张量.

The tf.reshape() operation creates a new tensor with a different dynamic shape.

创建新张量涉及内存分配,当涉及更多训练示例时,这可能会更加昂贵.这是设计使然,还是我在这里遗漏了什么?

Creating a new tensor involves memory allocation and that could potentially be more costly when more training examples are involved. Is this by design, or am I missing something here?

推荐答案

据我所知(以及我编写的代码),Tensor.set_shape().我认为误解源于该方法的混淆名称.

As far as I know (and I wrote that code), there isn't a bug in Tensor.set_shape(). I think the misunderstanding stems from the confusing name of that method.

详细说明您引用的常见问题解答, Tensor.set_shape() 是一个纯 Python 函数,它改进给定 tf.Tensor 对象的形状信息.通过改进",我的意思是使更具体".

To elaborate on the FAQ entry you quoted, Tensor.set_shape() is a pure-Python function that improves the shape information for a given tf.Tensor object. By "improves", I mean "makes more specific".

因此,当你有一个形状为 (?,)Tensor 对象 t 时,它是一个未知长度的一维张量.你可以调用t.set_shape((1028178,)),然后当你调用t就会有形状(1028178,)t.get_shape().这不会影响底层存储,或者后端的任何东西:它只是意味着使用 t 的后续形状推断可以依赖于它是长度为 1028178 的向量的断言.

Therefore, when you have a Tensor object t with shape (?,), that is a one-dimensional tensor of unknown length. You can call t.set_shape((1028178,)), and then t will have shape (1028178,) when you call t.get_shape(). This doesn't affect the underlying storage, or indeed anything on the backend: it merely means that subsequent shape inference using t can rely on the assertion that it is a vector of length 1028178.

如果 t 具有形状 (?,),则调用 t.set_shape((478, 717, 3)) 将失败,因为TensorFlow已经知道t是一个向量,所以它不能有形状(478, 717, 3).如果您想根据 t 的内容制作具有该形状的新张量,您可以使用 reshape_t = tf.reshape(t, (478, 717, 3)).这会在 Python 中创建一个新的 tf.Tensor 对象;实际的实施.recode(t>) 使用张量缓冲区的浅拷贝来做到这一点,因此在实践中成本低廉.

If t has shape (?,), a call to t.set_shape((478, 717, 3)) will fail, because TensorFlow already knows that t is a vector, so it cannot have shape (478, 717, 3). If you want to make a new Tensor with that shape from the contents of t, you can use reshaped_t = tf.reshape(t, (478, 717, 3)). This creates a new tf.Tensor object in Python; the actual implementation of tf.reshape() does this using a shallow copy of the tensor buffer, so it is inexpensive in practice.

一个类比是 Tensor.set_shape() 就像 Java 等面向对象语言中的运行时转换.例如,如果您有一个指向 Object 的指针,但知道它实际上是一个 String,则可以执行强制转换 (String) obj 以便将 obj 传递给需要 String 参数的方法.但是,如果您有一个 String s 并尝试将其转换为 java.util.Vector,编译器会给您一个错误,因为这两种类型是不相关的.

One analogy is that Tensor.set_shape() is like a run-time cast in an object-oriented language like Java. For example, if you have a pointer to an Object but know that, in fact, it is a String, you might do the cast (String) obj in order to pass obj to a method that expects a String argument. However, if you have a String s and try to cast it to a java.util.Vector, the compiler will give you an error, because these two types are unrelated.

这篇关于对 tf.Tensor.set_shape() 的澄清的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆