调整 PyTorch 张量 [英] Resize PyTorch Tensor

查看:34
本文介绍了调整 PyTorch 张量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我目前正在使用 tensor.resize() 函数将张量调整为新形状 t = t.resize(1, 2, 3).

I am currently using the tensor.resize() function to resize a tensor to a new shape t = t.resize(1, 2, 3).

这给了我一个弃用警告:

This gives me a deprecation warning:

不推荐使用非就地调整大小

non-inplace resize is deprecated

因此,我想切换到 tensor.resize_() 函数,这似乎是适当的就地替换.然而,这给我留下了

Hence, I wanted to switch over to the tensor.resize_() function, which seems to be the appropriate in-place replacement. However, this leaves me with an

无法调整需要 grad 的变量的大小

cannot resize variables that require grad

错误.我可以退回到

from torch.autograd._functions import Resize
Resize.apply(t, (1, 2, 3))

这是 tensor.resize() 为了避免弃用警告所做的.这对我来说似乎不是一个合适的解决方案,而是一个黑客.在这种情况下,我如何正确使用 tensor.resize_() ?

which is what tensor.resize() does in order to avoid the deprecation warning. This doesn't seem like an appropriate solution but rather a hack to me. How do I correctly make use of tensor.resize_() in this case?

推荐答案

您可以改为选择 tensor.reshape(new_shape)torch.reshape(tensor, new_shape) 如:

You can instead choose to go with tensor.reshape(new_shape) or torch.reshape(tensor, new_shape) as in:

# a `Variable` tensor
In [15]: ten = torch.randn(6, requires_grad=True)

# this would throw RuntimeError error
In [16]: ten.resize_(2, 3)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-16-094491c46baa> in <module>()
----> 1 ten.resize_(2, 3)

RuntimeError: cannot resize variables that require grad

<小时>

上面的RuntimeError可以通过使用tensor.reshape(new_shape)

In [17]: ten.reshape(2, 3)
Out[17]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])

# yet another way of changing tensor shape
In [18]: torch.reshape(ten, (2, 3))
Out[18]: 
tensor([[-0.2185, -0.6335, -0.0041],
        [-1.0147, -1.6359,  0.6965]])

这篇关于调整 PyTorch 张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆