pytorch 如何从张量中删除 cuda() [英] pytorch how to remove cuda() from tensor

查看:91
本文介绍了pytorch 如何从张量中删除 cuda()的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我收到 TypeError: expected torch.LongTensor (got torch.cuda.FloatTensor).

如何将 torch.cuda.FloatTensor 转换为 torch.LongTensor?

  Traceback (most recent call last):
  File "train_v2.py", line 110, in <module>
    main()
  File "train_v2.py", line 81, in main
    model.update(batch)
  File "/home/Desktop/squad_vteam/src/model.py", line 131, in update
    loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y)
  File "/home/Desktop/squad_vteam/src/model.py", line 94, in adversarial_loss
    adv_embedding = torch.LongTensor(adv_embedding)
TypeError: expected torch.LongTensor (got torch.cuda.FloatTensor)

推荐答案

您有一个浮点张量 f 并且想要将其转换为 long,您执行 long_tensor = f.long()

You have a float tensor f and want to convert it to long, you do long_tensor = f.long()

您有 cuda 张量,即数据在 gpu 上并且想要将其移动到 cpu,您可以执行 cuda_tensor.cpu().

You have cuda tensor i.e data is on gpu and want to move it to cpu you can do cuda_tensor.cpu().

所以要将 torch.cuda.Float 张量 A 转换为 torch.long,请执行 A.long().cpu()

So to convert a torch.cuda.Float tensor A to torch.long do A.long().cpu()

这篇关于pytorch 如何从张量中删除 cuda()的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆