RuntimeError:只能计算浮动类型的平均值。取而代之的是字节。平均值+ = images_data.mean(2).sum(0) [英] RuntimeError: Can only calculate the mean of floating types. Got Byte instead. for mean += images_data.mean(2).sum(0)

查看:956
本文介绍了RuntimeError:只能计算浮动类型的平均值。取而代之的是字节。平均值+ = images_data.mean(2).sum(0)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有以下代码段:

 #设备配置
device = torch.device('cuda:0'如果torch.cuda.is_available()else'cpu')
种子= 42
np.random.seed(种子)
torch.manual_seed(种子)

#将数据集拆分为验证和测试集
len_valid_set = int(0.1 * len(dataset))
len_train_set = len(dataset)-len_valid_set

print(训练集为{}。format(len_train_set))
print(测试集的长度为{}。format(len_valid_set))

train_dataset,valid_dataset = torch.utils.data.random_split(dataset,[len_train_set,len_valid_set])

#随机整理和批处理数据集
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 8, shuffle = True,num_workers = 4)
test_loader = torch.utils.data.DataLoader(valid_dataset,batch_size = 8,shuffle = True,num_workers = 4)

print( LOADERS; ,
len(dataloader),
len(train_loa der),
len(test_loader))

Train set的长度为720


测试集的长度为80


加载程序267 90 10

 平均值= 0.0 
std = 0.0
nb_samples = 0.0
对于train_loader中的数据:
张图像,地标=数据[图像],数据[地标]
batch_samples = images.size(0)

images_data = images.view(batch_samples,images.size(1),-1)
均值+ = images_data.mean(2).sum( 0)
std + = images_data.std(2).sum(0)
nb_samples + = batch_samples

平均值/ = nb_samples
std / = nb_samples

我收到此错误:

 - -------------------------------------------------- ------------------------ 
RuntimeError跟踪(最近一次调用最近)
< ipython-input-23-9e47ddfeff5e>在< module>中
7
8 images_data = images.view(batch_samples,images.size(1),-1)
----> 9平均值+ = images_data.mean(2).sum(0)
10 std + = images_data.std(2).sum(0)
11 nb_samples + = batch_samples

RuntimeError:只能计算浮点类型的平均值。取而代之的是字节。

固定代码来自

解决方案

错误说,您的 images_data 是ByteTensor,即dtype uint8 。火炬拒绝计算整数的平均值。您可以使用以下命令将数据转换为 float

 (images_data * 1.0).mean( 2)

 火炬。 Tensor.float(images_data).mean(2)


I have the following pieces of code:

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# split the dataset into validation and test sets
len_valid_set = int(0.1*len(dataset))
len_train_set = len(dataset) - len_valid_set

print("The length of Train set is {}".format(len_train_set))
print("The length of Test set is {}".format(len_valid_set))

train_dataset , valid_dataset,  = torch.utils.data.random_split(dataset , [len_train_set, len_valid_set])

# shuffle and batch the datasets
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=8, shuffle=True, num_workers=4)

print("LOADERS",
    len(dataloader),
    len(train_loader),
    len(test_loader))

The length of Train set is 720

The length of Test set is 80

LOADERS 267 90 10

mean = 0.0
std = 0.0
nb_samples = 0.0
for data in train_loader:
    images, landmarks = data["image"], data["landmarks"]
    batch_samples = images.size(0)

    images_data = images.view(batch_samples, images.size(1), -1)
    mean += images_data.mean(2).sum(0)
    std += images_data.std(2).sum(0)
    nb_samples += batch_samples

mean /= nb_samples
std /= nb_samples

And I get this error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-23-9e47ddfeff5e> in <module>
      7 
      8     images_data = images.view(batch_samples, images.size(1), -1)
----> 9     mean += images_data.mean(2).sum(0)
     10     std += images_data.std(2).sum(0)
     11     nb_samples += batch_samples

RuntimeError: Can only calculate the mean of floating types. Got Byte instead.

The fixed code is taken from https://stackoverflow.com/a/64349380/2414957 it worked for dataloader but not train_loader

Also, these are the results of

print(type(images_data))
print(images_data)

We have:

<class 'torch.Tensor'>
tensor([[[74, 74, 74,  ..., 63, 63, 63],
         [73, 73, 73,  ..., 61, 61, 61],
         [75, 75, 75,  ..., 61, 61, 61],
         ...,
         [74, 74, 74,  ..., 38, 38, 38],
         [75, 75, 75,  ..., 39, 39, 39],
         [72, 72, 72,  ..., 38, 38, 38]],

        [[75, 75, 75,  ..., 65, 65, 65],
         [75, 75, 75,  ..., 62, 62, 62],
         [75, 75, 75,  ..., 63, 63, 63],
         ...,
         [71, 71, 71,  ..., 39, 39, 39],
         [74, 74, 74,  ..., 38, 38, 38],
         [73, 73, 73,  ..., 37, 37, 37]],

        [[72, 72, 72,  ..., 62, 62, 62],
         [74, 74, 74,  ..., 63, 63, 63],
         [75, 75, 75,  ..., 61, 61, 61],
         ...,
         [74, 74, 74,  ..., 38, 38, 38],
         [74, 74, 74,  ..., 39, 39, 39],
         [73, 73, 73,  ..., 37, 37, 37]],

        ...,

        [[75, 75, 75,  ..., 63, 63, 63],
         [73, 73, 73,  ..., 63, 63, 63],
         [74, 74, 74,  ..., 62, 62, 62],
         ...,
         [74, 74, 74,  ..., 38, 38, 38],
         [73, 73, 73,  ..., 39, 39, 39],
         [73, 73, 73,  ..., 37, 37, 37]],

        [[73, 73, 73,  ..., 62, 62, 62],
         [75, 75, 75,  ..., 62, 62, 62],
         [74, 74, 74,  ..., 63, 63, 63],
         ...,
         [73, 73, 73,  ..., 39, 39, 39],
         [74, 74, 74,  ..., 38, 38, 38],
         [74, 74, 74,  ..., 38, 38, 38]],

        [[74, 74, 74,  ..., 62, 62, 62],
         [74, 74, 74,  ..., 63, 63, 63],
         [74, 74, 74,  ..., 62, 62, 62],
         ...,
         [74, 74, 74,  ..., 38, 38, 38],
         [73, 73, 73,  ..., 38, 38, 38],
         [72, 72, 72,  ..., 36, 36, 36]]], dtype=torch.uint8)

When I tried

images_data = images_data.float()
mean += images_data.mean(2).sum(0)

I didn't get a tensor for 3 values for mean and 3 values for std like I expected but got a very large tensor (each torch.Size([600]))

解决方案

As the error says, your images_data is a ByteTensor, i.e. has dtype uint8. Torch refuses to compute the mean of integers. You can convert the data to float with:

(images_data * 1.0).mean(2)

Or

torch.Tensor.float(images_data).mean(2)

这篇关于RuntimeError:只能计算浮动类型的平均值。取而代之的是字节。平均值+ = images_data.mean(2).sum(0)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆