如何使用 PyTorch 计算偏导数? [英] How to use PyTorch to calculate partial derivatives?

查看:22
本文介绍了如何使用 PyTorch 计算偏导数?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想使用 PyTorch 来获取输出和输入之间的偏导数.假设我有一个函数 Y = 5*x1^4 + 3*x2^3 + 7*x1^2 + 9*x2 - 5,我训练一个网络来代替这个函数,然后我使用 autograd 计算 dYdx1, dYdx2:

I want to use PyTorch to get the partial derivatives between output and input. Suppose I have a function Y = 5*x1^4 + 3*x2^3 + 7*x1^2 + 9*x2 - 5, and I train a network to replace this function, then I use autograd to calculate dYdx1, dYdx2:

net = torch.load('net_723.pkl')
x = torch.tensor([[1,-1]],requires_grad=True).type(torch.FloatTensor)
y = net(x)
grad_c = torch.autograd.grad(y,x,create_graph=True,retain_graph=True)[0] 

然后我得到一个错误的导数:

Then I get a wrong derivative as:

>>>tensor([[ 7.5583, -5.3173]])

但是当我使用函数来计算时,我得到了正确的答案:

but when I use function to calculate, I get the right answer:

Y = 5*x[0,0]**4 + 3*x[0,1]**3 + 7*x[0,0]**2 + 9*x[0,1] - 5
grad_c = torch.autograd.grad(Y,x,create_graph=True,retain_graph=True)[0]
>>>tensor([[ 34.,  18.]])

为什么会发生这种情况?

Why does this happen?

推荐答案

神经网络是一个通用函数近似器.这意味着,对于足够的计算资源、训练时间、节点等,您可以近似 任何函数.
如果没有关于您在第一个示例中如何训练网络的任何进一步信息,我会怀疑您的网络根本不适合底层功能,这意味着您的网络的内部表示实际上模拟了 强>不同的功能!

A neural network is a universal function approximator. What that means is, that, for enough computational resources, training time, nodes, etc., you can approximate any function.
Without any further information on how you trained your network in the first example, I would suspect that your network simply does not fit properly to the underlying function, meaning that the internal representation of your network actually models a different function!

对于第二个代码片段,自动微分确实为您提供了精确偏导数.它是通过不同的方法实现的,请参阅我关于 SO 的另一个答案,专门针对 AutoDiff/Autograd 的话题.

For the second code snippet, autmatic differentiation does give you the exact partial derivative. It does so via a different method, see another one of my answers on SO, on the topic of AutoDiff/Autograd specifically.

这篇关于如何使用 PyTorch 计算偏导数?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆