Pytorch: how to get all the tensors in a graph
Problem description
I want to get access to all the tensor instances of a graph. For example, I can check whether a tensor is detached or check its size. This can be done in tensorflow.
I don't want visualization of the graph.
You can get access to the entirety of the computation graph at runtime. To do so, you can use hooks. These are functions plugged onto nn.Modules, both for inference and when backpropagating.
At inference you can plug a hook with register_forward_hook. For backpropagation you can use register_backward_hook (note: as of version 1.8.0 this function is deprecated in favor of register_full_backward_hook).
With these two functions, you will basically have access to any tensor on the computation graph. It's entirely up to you whether you want to print all tensors, print the shapes, or even insert breakpoints to investigate.
Here is a possible implementation:
def forward_hook(module, input, output):
    # ...
The input argument is passed by PyTorch as a tuple and will contain all arguments passed to the forward function of the hooked module.
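As a quick sketch of what that tuple looks like (the nn.Linear layer and the sizes here are illustrative assumptions, not part of the original question), a forward hook can record what PyTorch hands it:

```python
import torch
import torch.nn as nn

captured = {}

def forward_hook(module, input, output):
    # `input` is a tuple of all positional arguments given to forward()
    captured['n_args'] = len(input)
    captured['in_shape'] = tuple(input[0].shape)
    # `output` is whatever forward() returned (here, a single tensor)
    captured['out_shape'] = tuple(output.shape)

layer = nn.Linear(4, 2)      # toy module, just for illustration
layer.register_forward_hook(forward_hook)
layer(torch.randn(1, 4))     # triggers the hook
```

Even though this module's forward takes a single tensor, the hook still receives it wrapped in a one-element tuple.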
def backward_hook(module, grad_input, grad_output):
    # ...
For the backward hook, both grad_input and grad_output will be tuples and will have varying shapes depending on your model's layers.
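A minimal sketch of those tuples, using the non-deprecated register_full_backward_hook on a toy nn.Linear (the layer and sizes are assumptions for illustration):

```python
import torch
import torch.nn as nn

grads = {}

def backward_hook(module, grad_input, grad_output):
    # Both arguments are tuples of tensors (entries can be None)
    grads['in_shapes'] = [tuple(g.shape) for g in grad_input if g is not None]
    grads['out_shapes'] = [tuple(g.shape) for g in grad_output if g is not None]

layer = nn.Linear(3, 2)
layer.register_full_backward_hook(backward_hook)
x = torch.randn(1, 3, requires_grad=True)
layer(x).sum().backward()    # triggers the hook
```

Here grad_output holds the gradient flowing into the module's output, and grad_input the gradient with respect to its forward inputs.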
Then you can hook these callbacks on any existing nn.Module. For example, you could loop over all child modules from your model:
for module in model.children():
    module.register_forward_hook(forward_hook)
    module.register_backward_hook(backward_hook)
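A detail worth knowing: each register_*_hook call returns a handle you can keep in order to detach the hook later. A minimal sketch, with the toy Sequential model here being an assumption for illustration:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))  # toy model
calls = []

def forward_hook(module, input, output):
    calls.append(type(module).__name__)

# Each register call returns a RemovableHandle
handles = [m.register_forward_hook(forward_hook) for m in model.children()]
model(torch.randn(1, 8))     # hooks fire once per child module

for h in handles:            # detach all hooks when done
    h.remove()
model(torch.randn(1, 8))     # no further hook calls
```

Removing hooks when you are done avoids paying their cost on every subsequent forward pass.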
To get the names of the modules, you can wrap the hook in a closure that captures the name, and loop over your model's named_children:
def forward_hook(name):
    def hook(module, x, y):
        print('%s: %s -> %s' % (name, list(x[0].size()), list(y.size())))
    return hook

for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))
Which could print the following on inference:
fc1: [1, 100] -> [1, 10]
fc2: [1, 10] -> [1, 5]
fc3: [1, 5] -> [1, 1]
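For reference, one model that would produce output along these lines (the layer sizes are inferred from the printed shapes; the class itself is hypothetical) could look like this, collecting the lines into a list instead of printing:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(100, 10)
        self.fc2 = nn.Linear(10, 5)
        self.fc3 = nn.Linear(5, 1)

    def forward(self, x):
        return self.fc3(self.fc2(self.fc1(x)))

lines = []

def forward_hook(name):
    def hook(module, x, y):
        lines.append('%s: %s -> %s' % (name, list(x[0].size()), list(y.size())))
    return hook

model = Net()
for name, module in model.named_children():
    module.register_forward_hook(forward_hook(name))
model(torch.randn(1, 100))   # fills `lines` with one entry per child
```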
As I said, it's a bit more complicated on the backward pass. I can only recommend you explore and experiment with pdb:
import pdb

def backward_hook(module, grad_input, grad_output):
    pdb.set_trace()
As for the model's parameters, you can easily access the parameters of a given module in both hooks by calling module.parameters(). This will return a generator.
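A short sketch of that, materializing the generator inside a forward hook (the nn.Linear module and its sizes are assumptions for illustration):

```python
import torch
import torch.nn as nn

param_shapes = {}

def forward_hook(module, input, output):
    # module.parameters() returns a generator; materialize it to inspect
    param_shapes[type(module).__name__] = [tuple(p.shape) for p in module.parameters()]

layer = nn.Linear(3, 2)      # toy module for illustration
layer.register_forward_hook(forward_hook)
layer(torch.randn(1, 3))
```

For a Linear layer this yields the weight matrix followed by the bias vector.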
I can only wish you good luck exploring your model!