如何从 libtorch 输出中删除乘数并显示最终结果? [英] How to remove the multiplier from the libtorch output and display the final result?

查看:36
本文介绍了如何从 libtorch 输出中删除乘数并显示最终结果?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

当我尝试在屏幕上显示/打印一些张量时,我面临类似以下的情况,其中似乎 libtorch 显示带有乘数的张量(即 0.01*),而不是获得最终结果以及你可以在下面看到的喜欢):

When I try to display/print some tensors to the screen, I face something like the following where instead of getting the final result, it seems libtorch displays the tensor with a multiplier (i.e. 0.01* and the likes as you can see below) :

offsets.shape: [1, 4, 46, 85]
probs.shape: [46, 85]
offsets: (1,1,.,.) =
 0.01 *
  0.1006  1.2322
  -2.9587 -2.2280

(1,2,.,.) =
 0.01 *
  1.3772  1.3971
  -1.2813 -0.8563

(1,3,.,.) =
 0.01 *
  6.2367  9.2561
   3.5719  5.4744

(1,4,.,.) =
  0.2901  0.2963
  0.2618  0.2771
[ CPUFloatType{1,4,2,2} ]
probs: 0.0001 *
 1.4593  1.0351
  6.6782  4.9104
[ CPUFloatType{2,2} ]

如何禁用此行为并获得最终输出?我试图将其显式转换为浮点数,希望这将导致存储/显示最终输出,但这也不起作用.

How can I disable this behavior and get the final output? I tried to explicitly convert this into float hoping this will lead to the finalized output to be stored/displayed but that doesn't work either.

推荐答案

基于 libtorch 的输出张量的源代码,经过搜索;*"存储库中的字符串,结果是这个漂亮的打印"在 aten/src/ATen/core/Formatting.cpp 翻译单元中完成.比例和星号在此处添加:

Basing on libtorch's source code for outputting the tensors, after searching for " *" string within the repository, it turns out that this "pretty-print" is done in aten/src/ATen/core/Formatting.cpp translation unit. The scale and asterisk is prepended here:

static void printScale(std::ostream & stream, double scale) {
  FormatGuard guard(stream);
  stream << defaultfloat << scale << " *" << std::endl;
}

然后张量的所有坐标都除以scale:

And later on all coordinates of the Tensor are divided by the scale:

if(scale != 1) {
  printScale(stream, scale);
}
double* tensor_p = tensor.data_ptr<double>();
for(int64_t i = 0; i < tensor.size(0); i++) {
  stream << std::setw(sz) << tensor_p[i]/scale << std::endl;
}

基于这个翻译单元,这根本是不可配置的.

Basing on this translation unit, this is not configurable at all.

我猜你有两个选择:

  1. 调整函数并最小化编辑现有函数以满足您的要求.
  2. 在 Formatting.cpp 中删除(或添加 #ifdef)张量的 << 运算符重载并提供您自己的实现.但是,在构建 libtorch 时,您必须将其链接到包含该方法实现的目标.
  1. Tweak around with the functions and edit already existing functions minimally to meet your requirements.
  2. Remove (or add #ifdef) the << operator overload for Tensor in Formatting.cpp and provide your own implementation. When building libtorch, however, you'd have to link it to your target containing the method's implementation.

但是,这两个选项都需要您更改第 3 方代码,我认为这很糟糕.

Both options, however, require your to change 3rd party code, which is quite bad, I believe.

这篇关于如何从 libtorch 输出中删除乘数并显示最终结果?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆