torch.flatten() 和 nn.Flatten() 的区别 [英] Difference between torch.flatten() and nn.Flatten()
问题描述
torch.flatten()
和 torch.nn.Flatten()
有什么区别?
推荐答案
Flattening 在 PyTorch 中以三种形式提供
Flattening is available in three forms in PyTorch
作为张量方法(oop 风格)
torch.Tensor.flatten
直接应用于张量:x.flatten()
.
作为函数(函数形式)torch.flatten
应用为:torch.flatten(x)
.
作为一个模块(层 nn.Module
)nn.Flatten()
.通常用于模型定义.
As a module (layer nn.Module
) nn.Flatten()
. Generally used in a model definition.
所有三个都是相同的并且共享相同的实现,唯一的区别是 nn.Flatten
默认将 start_dim
设置为 1
以避免展平第一个轴(通常是批处理轴).而另外两个从 axis=0
展平到 axis=-1
- ie 整个张量 - 如果没有给出参数.
All three are identical and share the same implementation, the only difference being nn.Flatten
has start_dim
set to 1
by default to avoid flattening the first axis (usually the batch axis). While the other two flatten from axis=0
to axis=-1
- i.e. the entire tensor - if no arguments are given.
这篇关于torch.flatten() 和 nn.Flatten() 的区别的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!