ValueError:浮点图像RGB值必须在0..1范围内。在使用matplotlib时 [英] ValueError: Floating point image RGB values must be in the 0..1 range. while using matplotlib
问题描述
我想可视化神经网络层的权重。我正在使用pytorch。
I want to visualize weights of the layer of a neural network. I'm using pytorch.
import torch
import torchvision.models as models
from matplotlib import pyplot as plt
def plot_kernels(tensor, num_cols=6):
if not tensor.ndim==4:
raise Exception("assumes a 4D tensor")
if not tensor.shape[-1]==3:
raise Exception("last dim needs to be 3 to plot")
num_kernels = tensor.shape[0]
num_rows = 1+ num_kernels // num_cols
fig = plt.figure(figsize=(num_cols,num_rows))
for i in range(tensor.shape[0]):
ax1 = fig.add_subplot(num_rows,num_cols,i+1)
ax1.imshow(tensor[i])
ax1.axis('off')
ax1.set_xticklabels([])
ax1.set_yticklabels([])
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()
vgg = models.vgg16(pretrained=True)
mm = vgg.double()
filters = mm.modules
body_model = [i for i in mm.children()][0]
layer1 = body_model[0]
tensor = layer1.weight.data.numpy()
plot_kernels(tensor)
上面的错误给出了此错误 ValueError:浮点图像RGB值必须在0..1范围内。
The above gives this error ValueError: Floating point image RGB values must be in the 0..1 range.
我的问题是我应该归一化并采用权重的绝对值来克服此错误,还是还有其他方法?
如果我归一化并使用绝对值,我认为图形的含义会改变。
My question is should I normalize and take absolute value of the weights to overcome this error or is there anyother way ? If I normalize and use absolute value I think the meaning of the graphs change.
[[[[ 0.02240197 -1.22057354 -0.55051649]
[-0.50310904 0.00891289 0.15427093]
[ 0.42360783 -0.23392732 -0.56789106]]
[[ 1.12248898 0.99013627 1.6526649 ]
[ 1.09936976 2.39608836 1.83921957]
[ 1.64557672 1.4093554 0.76332706]]
[[ 0.26969245 -1.2997849 -0.64577204]
[-1.88377869 -2.0100112 -1.43068039]
[-0.44531786 -1.67845118 -1.33723605]]]
[[[ 0.71286005 1.45265901 0.64986968]
[ 0.75984162 1.8061738 1.06934202]
[-0.08650422 0.83452386 -0.04468433]]
[[-1.36591709 -2.01630116 -1.54488969]
[-1.46221244 -2.5365622 -1.91758668]
[-0.88827479 -1.59151018 -1.47308767]]
[[ 0.93600738 0.98174071 1.12213969]
[ 1.03908169 0.83749604 1.09565806]
[ 0.71188802 0.85773659 0.86840987]]]
[[[-0.48592842 0.2971966 1.3365227 ]
[ 0.47920835 -0.18186836 0.59673625]
[-0.81358945 1.23862112 0.13635623]]
[[-0.75361633 -1.074965 0.70477796]
[ 1.24439156 -1.53563368 -1.03012812]
[ 0.97597247 0.83084011 -1.81764793]]
[[-0.80762428 -0.62829626 1.37428832]
[ 1.01448071 -0.81775147 -0.41943246]
[ 1.02848887 1.39178836 -1.36779451]]]
...,
[[[ 1.28134537 -0.00482408 0.71610934]
[ 0.95264435 -0.09291686 -0.28001019]
[ 1.34494913 0.64477581 0.96984017]]
[[-0.34442815 -1.40002513 1.66856039]
[-2.21281362 -3.24513769 -1.17751861]
[-0.93520379 -1.99811196 0.72937071]]
[[ 0.63388056 -0.17022935 2.06905985]
[-0.7285465 -1.24722099 0.30488953]
[ 0.24900314 -0.19559766 1.45432627]]]
[[[-0.80684513 2.1764245 -0.73765725]
[-1.35886598 1.71875226 -1.73327696]
[-0.75233924 2.14700699 -0.71064663]]
[[-0.79627383 2.21598244 -0.57396138]
[-1.81044972 1.88310981 -1.63758397]
[-0.6589964 2.013237 -0.48532376]]
[[-0.3710472 1.4949851 -0.30245575]
[-1.25448656 1.20453358 -1.29454732]
[-0.56755757 1.30994892 -0.39370224]]]
[[[-0.67361742 -3.69201088 -1.23768616]
[ 3.12674141 1.70414758 -1.76272404]
[-0.22565465 1.66484773 1.38172317]]
[[ 0.28095332 -2.03035069 0.69989491]
[ 1.97936332 1.76992691 -1.09842575]
[-2.22433758 0.52577412 0.18292744]]
[[ 0.48471382 -1.1984663 1.57565165]
[ 1.09911084 1.31910467 -0.51982772]
[-2.76202297 -0.47073677 0.03936549]]]]
推荐答案
听起来好像您已经知道自己的值不在该范围内。是的,您必须将它们重新缩放到0.0-1.0的范围。我建议您要保留负面与正面的可见性,但让0.5为新的中立点。缩放,以使当前的0.0值映射到0.5,而最大的值(最大的幅度)映射为0.0(如果为负)或1.0(如果为正)。
It sounds as if you already know your values are not in that range. Yes, you must re-scale them to the range 0.0 - 1.0. I suggest that you want to retain visibility of negative vs positive, but that you let 0.5 be your new "neutral" point. Scale such that current 0.0 values map to 0.5, and your most extreme value (largest magnitude) scale to 0.0 (if negative) or 1.0 (if positive).
感谢矢量。您的值似乎在-2.25至+2.0的范围内。我建议重新缩放 new =(1 /(2 * 2.25))*旧+ 0.5
Thanks for the vectors. It looks like your values are in the range -2.25 to +2.0. I suggest a rescaling new = (1/(2*2.25)) * old + 0.5
这篇关于ValueError:浮点图像RGB值必须在0..1范围内。在使用matplotlib时的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!