将pyplot图形转换为数组 [英] Converting pyplot figure to array

查看:115
本文介绍了将pyplot图形转换为数组的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试将使用pyplot绘制的图形转换为数组,但是在执行此操作之前,我想消除图形外部的任何空间.在我目前的方法中,我将图形保存到一个临时文件中(使用 plt.savefig 的功能来消除图形外的任何空间,即使用 bbox_inches ='tight' pad_inches = 0 ),然后从临时文件加载图像.这是MWE:

I am trying to convert a figure drawn using pyplot to an array, but I would like to eliminate any space outside of the plot before doing so. In my current approach, I am saving the figure to a temporary file (using the functionality of plt.savefig to eliminate any space outside the plot, i.e. using bbox_inches='tight' and pad_inches = 0), and then loading the image from the temporary file. Here's an MWE:

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot([0,1], color='black', linewidth=4)
plt.xlim([0,1])
plt.ylim([0,1])
ax.set_aspect('equal', adjustable='box')
plt.axis('off')
plt.savefig('./tmp.png', bbox_inches='tight', pad_inches = 0)
plt.close()
img_size = 128
img = Image.open('./tmp.png')
X = np.array(img)

这种方法是不受欢迎的,因为需要写入和读取文件的时间.我知道以下用于直接从像素缓冲区转到数组的方法:

This approach is undesirable, because of the time required to write the file and read it. I'm aware of the following method for going directly from the pixel buffer to an array:

from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvas
import numpy as np

fig, ax = plt.subplots()
canvas = FigureCanvas(fig)
ax.plot([0,1], color='black', linewidth=4)
plt.xlim([0,1])
plt.ylim([0,1])
ax.set_aspect('equal', adjustable='box')
plt.axis('off')
canvas.draw()
X = np.array(canvas.renderer.buffer_rgba())

但是,使用这种方法,我不确定如何在转换为数组之前消除绘图周围的空间.是否有等同于 bbox_inches ='tight' pad_inches = 0 的等价物,而无需使用 plt.savefig()?

However, with this approach, I'm not sure how to eliminate the space around the plot before converting to an array. Is there an equivalent to bbox_inches='tight' and pad_inches = 0 that doesn't involve using plt.savefig()?

推荐答案

改进的答案

这似乎适用于您的情况,应该很快.可能会有更好的方法-如果有人知道更好,我很乐意将其删除:

This seems to work for your case and should be fast. There may be better ways - I am happy to delete it if anyone knows something better:

#!/usr/bin/env python3

from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvas
import numpy as np

fig, ax = plt.subplots()
canvas = FigureCanvas(fig)
ax.plot([0,1], color='red', linewidth=4)
plt.xlim([0,1])
plt.ylim([0,1])
ax.set_aspect('equal', adjustable='box')
plt.axis('off')
canvas.draw()
X = np.array(canvas.renderer.buffer_rgba())

上面的代码是你的,下面的代码是我的:

The code above is yours, the code below is mine:

# Get width and height of cnvas for reshaping
w, h = canvas.get_width_height()
Y = np.frombuffer(X,dtype=np.uint8).reshape((h,w,4))[...,0:3]

# Work out extent of image by inverting and looking for black - ASSUMES CANVAS IS WHITE
extent = np.nonzero(~Y)

top    = extent[0].min()
bottom = extent[0].max()
left   = extent[1].min()
right  = extent[1].max()

tight_img = Y[top:bottom,left:right,:]

# Save as image just to test - you don't want this bit    
Image.fromarray(tight_img).save('tight.png')

原始答案

可能有更好的方法,但您可以通过写入基于内存的 BytesIO 来避免写入磁盘:

There may be a better way, but you could avoid writing to disk by writing to a memory-based BytesIO instead:

from io import BytesIO

buffer = BytesIO()
plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches = 0)

然后做:

x = np.array(Image.open(buffer))

<小时>

实际上,如果您使用:


In fact, if you use:

plt.savefig(buffer, format='rgba', bbox_inches='tight', pad_inches = 0)

缓冲区已经拥有您的数组,您可以避免PNG编码/解码以及磁盘I/O.唯一的问题是,由于它是原始图像,因此我们不知道要 reshape()缓冲区的图像尺寸.它实际上是在我的机器上,但我通过编写 PNG 并检查其宽度和高度来获得尺寸:

the buffer already has your array and you can avoid the PNG encoding/decoding as well as the disk I/O. The only issue is that, because it is raw, we don't know the dimensions of the image to reshape() the buffer. It is actually this on my machine but I got the dimensions by writing a PNG and checking its width and height:

arr = buffer.getvalue()
x = np.frombuffer(arr, dtype=np.uint8).reshape((398,412,4))

如果有人想出更好的东西,我会删除它.

If someone comes up with something better, I'll delete this.

这篇关于将pyplot图形转换为数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆