Pytorch:无法在需要 grad 的变量上调用 numpy().使用 var.detach().numpy() 代替 [英] Pytorch: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead
问题描述
我的代码中有一个错误,无论我尝试哪种方式都没有得到修复.
I have an error in my code which is not getting fixed any which way I try.
错误很简单,我返回一个值:
The Error is simple, I return a value:
torch.exp(-LL_total/T_total)
然后在管道中得到错误:
and get the error later in the pipeline:
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
诸如 cpu().detach().numpy()
之类的解决方案会给出相同的错误.
Solutions such as cpu().detach().numpy()
give the same error.
我该如何解决?谢谢.
推荐答案
错误重现
import torch
tensor1 = torch.tensor([1.0,2.0],requires_grad=True)
print(tensor1)
print(type(tensor1))
tensor1 = tensor1.numpy()
print(tensor1)
print(type(tensor1))
这会导致 tensor1 = tensor1.numpy()
行出现完全相同的错误:
which leads to the exact same error for the line tensor1 = tensor1.numpy()
:
tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
Traceback (most recent call last):
File "/home/badScript.py", line 8, in <module>
tensor1 = tensor1.numpy()
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.
Process finished with exit code 1
通用解决方案
这是在您的错误消息中向您建议的,只需将 var
替换为您的变量名称
import torch
tensor1 = torch.tensor([1.0,2.0],requires_grad=True)
print(tensor1)
print(type(tensor1))
tensor1 = tensor1.detach().numpy()
print(tensor1)
print(type(tensor1))
按预期返回
tensor([1., 2.], requires_grad=True)
<class 'torch.Tensor'>
[1. 2.]
<class 'numpy.ndarray'>
Process finished with exit code 0
一些解释
您需要将张量转换为另一个除了实际值定义之外不需要梯度的张量.这个其他张量可以转换为一个 numpy 数组.参见这个discuss.pytorch帖子一>.(我认为,更准确地说,需要这样做才能从其 pytorch Variable
包装器中获取实际张量,参见 另一个讨论.pytorch 帖子).
Some explanation
You need to convert your tensor to another tensor that isn't requiring a gradient in addition to its actual value definition. This other tensor can be converted to a numpy array. Cf. this discuss.pytorch post. (I think, more precisely, that one needs to do that in order to get the actual tensor out of its pytorch Variable
wrapper, cf. this other discuss.pytorch post).
这篇关于Pytorch:无法在需要 grad 的变量上调用 numpy().使用 var.detach().numpy() 代替的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!