torch.flatten() 和 nn.Flatten() 的区别 [英] Difference between torch.flatten() and nn.Flatten()

查看:52
本文介绍了torch.flatten() 和 nn.Flatten() 的区别的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

torch.flatten()torch.nn.Flatten() 有什么区别?

推荐答案

Flattening 在 PyTorch 中以三种形式提供

Flattening is available in three forms in PyTorch

作为函数(函数形式)torch.flatten 应用为:torch.flatten(x).

作为一个模块(nn.Module)nn.Flatten().通常用于模型定义.

As a module (layer nn.Module) nn.Flatten(). Generally used in a model definition.

所有三个都是相同的并且共享相同的实现,唯一的区别是 nn.Flatten 默认将 start_dim 设置为 1 以避免展平第一个轴(通常是批处理轴).而另外两个从 axis=0 展平到 axis=-1 - ie 整个张量 - 如果没有给出参数.

All three are identical and share the same implementation, the only difference being nn.Flatten has start_dim set to 1 by default to avoid flattening the first axis (usually the batch axis). While the other two flatten from axis=0 to axis=-1 - i.e. the entire tensor - if no arguments are given.

这篇关于torch.flatten() 和 nn.Flatten() 的区别的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆