Very poor performance in my matplotlib script
Question
My code performs very badly: I barely get more than 10 fps when moving the sliders. Granted, I am not very well-versed in matplotlib, but can someone point out what I am doing wrong and how to fix it?
Note: I am handling a lot of data, around 3 × 100000 points in the worst case (three curves of up to 100000 points each). Also, I am not sure whether this matters, but I am running on the 'TkAgg' backend.
Here is my code (it runs and plots an SIR epidemiological model):
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import matplotlib.patches as patches
p = 1 #population
i = 0.01*p #infected
s = p-i #susceptible
r = 0 #recovered/removed
a = 3.2 #transmission parameter
b = 0.23 #recovery parameter
initialTime = 0
deltaTime = 0.001 #smaller the delta, better the approximation to a real derivative
maxTime = 10000 #more number of points, better is the curve generated
# differential equations expressed as functions that calculate the rate of change between
# time intervals of the different quantities, i.e. susceptible, infected and recovered/removed
def sPrime(oldS, oldI, transmissionRate):
    return -1*((transmissionRate*oldS*oldI)/p)
def iPrime(oldS, oldI, transmissionRate, recoveryRate):
    return (((transmissionRate*oldS)/p)-recoveryRate)*oldI
def rPrime(oldI, recoveryRate):
    return recoveryRate*oldI
maxTimeInitial = maxTime
def genData(transRate, recovRate, maxT):
    global a, b, maxTimeInitial
    a = transRate
    b = recovRate
    maxTimeInitial = maxT
    sInitial = s
    iInitial = i
    rInitial = r
    time = []
    sVals = []
    iVals = []
    rVals = []
    for t in range(initialTime, maxTimeInitial+1): # generating the data through a loop
        time.append(t)
        sVals.append(sInitial)
        iVals.append(iInitial)
        rVals.append(rInitial)
        newDeltas = (sPrime(sInitial, iInitial, transmissionRate=a), iPrime(sInitial, iInitial, transmissionRate=a, recoveryRate=b), rPrime(iInitial, recoveryRate=b))
        sInitial += newDeltas[0]*deltaTime
        iInitial += newDeltas[1]*deltaTime
        rInitial += newDeltas[2]*deltaTime
        if sInitial < 0 or iInitial < 0 or rInitial < 0: # as soon as any of these values becomes negative, the generated data is invalid
            break # according to the SIR model, we assume all values of S, I and R are always positive
    return (time, sVals, iVals, rVals)
fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.4, top=0.94)
plt.title('SIR epidemiology curves for a disease')
plt.xlim(0, maxTime+1)
plt.ylim(0, p*1.4)
plt.xlabel('Time (t)')
plt.ylabel('Population (p)')
initialData = genData(a, b, maxTimeInitial)
susceptible, = ax.plot(initialData[0], initialData[1], label='Susceptible', color='b')
infected, = ax.plot(initialData[0], initialData[2], label='Infected', color='r')
recovered, = ax.plot(initialData[0], initialData[3], label='Recovered/Removed', color='g')
plt.legend()
transmissionAxes = plt.axes([0.125, 0.25, 0.775, 0.03], facecolor='white')
recoveryAxes = plt.axes([0.125, 0.2, 0.775, 0.03], facecolor='white')
timeAxes = plt.axes([0.125, 0.15, 0.775, 0.03], facecolor='white')
transmissionSlider = Slider(transmissionAxes, 'Transmission parameter', 0, 10, valinit=a, valstep=0.01)
recoverySlider = Slider(recoveryAxes, 'Recovery parameter', 0, 10, valinit=b, valstep=0.01)
timeSlider = Slider(timeAxes, 'Max time', 0, 100000, valinit=maxTime, valstep=1, valfmt="%i")
def updateTransmission(newVal):
    newData = genData(newVal, b, maxTimeInitial)
    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])
    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))
    fig.canvas.draw_idle()
def updateRecovery(newVal):
    newData = genData(a, newVal, maxTimeInitial)
    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])
    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))
    fig.canvas.draw_idle()
def updateMaxTime(newVal):
    global susceptible, infected, recovered
    newData = genData(a, b, int(newVal))  # int() works whether the slider passes a NumPy or a Python number
    for line in ax.lines[:3]:  # remove the three old curves; `del ax.lines[:3]` fails on recent matplotlib
        line.remove()
    susceptible, = ax.plot(newData[0], newData[1], label='Susceptible', color='b')
    infected, = ax.plot(newData[0], newData[2], label='Infected', color='r')
    recovered, = ax.plot(newData[0], newData[3], label='Recovered/Removed', color='g')
transmissionSlider.on_changed(updateTransmission)
recoverySlider.on_changed(updateRecovery)
timeSlider.on_changed(updateMaxTime)
resetAxes = plt.axes([0.8, 0.025, 0.1, 0.05])
resetButton = Button(resetAxes, 'Reset', color='white')
r_o = plt.text(0.1, 1.5, r'$R_O$={:.2f}'.format(a/b), fontsize=12)
def reset(event):
    transmissionSlider.reset()
    recoverySlider.reset()
    timeSlider.reset()
resetButton.on_clicked(reset)
plt.show()
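For reference, the loop in genData is the explicit (forward) Euler method applied to the SIR equations encoded in sPrime, iPrime and rPrime. Here p is the population, a the transmission parameter and b the recovery parameter:

$$
\frac{dS}{dt} = -\frac{aSI}{p}, \qquad
\frac{dI}{dt} = \left(\frac{aS}{p} - b\right) I, \qquad
\frac{dR}{dt} = bI
$$

$$
X_{n+1} = X_n + \Delta t \, \frac{dX}{dt}\Big|_{(S_n,\, I_n)}, \quad X \in \{S, I, R\}, \qquad t_n = n\,\Delta t
$$

With Δt = deltaTime = 0.001, the default maxTime = 10000 steps covers only t = 10 in model time. At the slider's maximum, every slider event runs up to 100001 loop iterations, each with three Python function calls and four list appends, before anything is drawn.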
Answer
Use a proper ODE solver like scipy.integrate.odeint for speed. Then you can also use larger time steps for the output. With an implicit solver like odeint or solve_ivp with method="Radau", the coordinate planes that are boundaries in the exact solution will also be boundaries in the numerical solution, so the values never become negative.
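The following is a minimal sketch of that suggestion. It assumes SciPy is available and reuses the question's parameter values; the names sir_rhs, P, A and B exist only for this sketch. It integrates over the interval that the question's default settings cover (10000 Euler steps of 0.001, i.e. t from 0 to 10) and samples the dense solution at 1001 output points:

```python
import numpy as np
from scipy.integrate import solve_ivp

P = 1.0            # total population (the question's p)
A, B = 3.2, 0.23   # transmission and recovery parameters (the question's a and b)

def sir_rhs(t, y, trans, recov):
    """SIR right-hand side; t is unused but solve_ivp always passes it."""
    S, I, R = y
    return [-trans * S * I / P, (trans * S / P - recov) * I, recov * I]

y0 = [0.99 * P, 0.01 * P, 0.0]   # same initial state as the question (1% infected)
t_end = 10000 * 0.001            # question defaults: 10000 Euler steps of deltaTime = 0.001

# Radau is implicit: it chooses its own step sizes, and dense_output=True
# returns an interpolant that can be sampled at arbitrary output times.
sol = solve_ivp(sir_rhs, [0, t_end], y0, args=(A, B), method="Radau", dense_output=True)

t_out = np.linspace(0, t_end, 1001)   # only as many output points as the plot needs
S, I, R = sol.sol(t_out)

print("success:", sol.success)
print("internal steps:", len(sol.t) - 1)
print("right-hand-side evaluations:", sol.nfev)
print("smallest value:", min(S.min(), I.min(), R.min()))
```

On a smooth problem like this one, the adaptive solver typically needs far fewer right-hand-side evaluations than the 10000 fixed Euler steps. sol.nfev reports the actual count on your machine.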
Reduce the plotted data set to match the actual resolution of the plot image. Going from 300 to 1000 points may still make a visible difference. Going from 1000 to 5000 points will make no visible difference, and probably no difference in the rendered pixels at all.
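Here is a minimal sketch of such a reduction, assuming NumPy arrays. downsample is a hypothetical helper, not a matplotlib function. It keeps evenly spaced samples, including the first and last point:

```python
import numpy as np

def downsample(x, *ys, max_points=1000):
    """Return x and every array in ys reduced to at most max_points evenly spaced samples.

    The first and last samples are always kept, so the curve still spans the full range.
    """
    n = len(x)
    if n <= max_points:
        return (np.asarray(x), *(np.asarray(y) for y in ys))
    idx = np.linspace(0, n - 1, max_points).round().astype(int)
    return (np.asarray(x)[idx], *(np.asarray(y)[idx] for y in ys))

# 100001 samples per curve, like the question's worst case
t = np.arange(100001)
s = np.exp(-t / 20000.0)
t_small, s_small = downsample(t, s)
print(len(t), "->", len(t_small))
```

Plain stride-based decimation can skip over narrow spikes. For smooth curves like the SIR solutions, that is harmless.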
matplotlib draws its images by walking a scene tree of objects in slow Python loops. This makes it very slow when there are more than a few tens of thousands of objects to draw, so it is best to keep the level of detail below that.
To solve the ODE I used solve_ivp; using odeint instead makes no difference:
import numpy as np
from scipy.integrate import solve_ivp
# p, s, i, r are the globals defined in the question's code
def SIR_prime(t, SIR, trans, recov): # the solver expects a t argument, even if unused
    S, I, R = SIR
    dS = (-trans*I/p) * S
    dI = (trans*S/p - recov) * I
    dR = recov*I
    return [dS, dI, dR]
def genData(transRate, recovRate, maxT):
    SIR = solve_ivp(SIR_prime, [0, maxT], [s, i, r], args=(transRate, recovRate), method="Radau", dense_output=True)
    time = np.linspace(0, SIR.t[-1], 1001)  # 1001 output points are plenty for the plot
    sVals, iVals, rVals = SIR.sol(time)
    return (time, sVals, iVals, rVals)
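One thing to watch when dropping this genData into the question's script: the question's x-axis counts Euler steps, with model time t = step × deltaTime, whereas this version integrates to t = maxT directly. If the aim is to keep the slider and axis in step units, a variant could convert in both directions. genDataSteps below is hypothetical and assumes the question's p, s, i, r and deltaTime values:

```python
import numpy as np
from scipy.integrate import solve_ivp

p = 1              # population, as in the question
i = 0.01 * p       # initial infected
s = p - i          # initial susceptible
r = 0              # initial recovered/removed
deltaTime = 0.001  # the question's Euler step size

def SIR_prime(t, SIR, trans, recov):
    S, I, R = SIR
    return [(-trans * I / p) * S, (trans * S / p - recov) * I, recov * I]

def genDataSteps(transRate, recovRate, maxSteps, n_out=1001):
    """Hypothetical genData variant whose x-axis counts the question's Euler steps.

    Integrates to model time maxSteps * deltaTime, then converts the output
    times back to step units so the question's axis limits still apply.
    """
    t_end = maxSteps * deltaTime
    sol = solve_ivp(SIR_prime, [0, t_end], [s, i, r], args=(transRate, recovRate),
                    method="Radau", dense_output=True)
    t_model = np.linspace(0, sol.t[-1], n_out)
    sVals, iVals, rVals = sol.sol(t_model)
    return (t_model / deltaTime, sVals, iVals, rVals)

steps, S, I, R = genDataSteps(3.2, 0.23, 10000)
print(steps[0], steps[-1], S[-1], I[-1], R[-1])
```

Alternatively, keep maxT in model time and change the slider range and label to match. Which of the two is intended is a design decision for the script.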
Streamlined code for the plot update procedures
Much of the duplicated code can be removed. I also added a line that updates the x-axis limits with the maxTime variable, so that one can actually zoom in:
trans_rate, recov_rate = a, b  # initial values (the sliders' valinit); must exist before the first callback
def updateTransmission(newVal):
    global trans_rate
    trans_rate = newVal
    updatePlot()
def updateRecovery(newVal):
    global recov_rate
    recov_rate = newVal
    updatePlot()
def updateMaxTime(newVal):
    global maxTime
    maxTime = newVal
    updatePlot()
def updatePlot():
    newData = genData(trans_rate, recov_rate, maxTime)
    susceptible.set_data(newData[0], newData[1])
    infected.set_data(newData[0], newData[2])
    recovered.set_data(newData[0], newData[3])
    ax.set_xlim(0, maxTime+1)
    r_o.set_text(r'$R_O$={:.2f}'.format(trans_rate/recov_rate))
    fig.canvas.draw_idle()
The code in-between and around remains the same.
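The same single-update-path idea can be exercised without a GUI. The sketch below assumes the non-interactive Agg backend and replaces genData with a hypothetical gen_curve. It keeps the slider values in a dict instead of module globals and builds each callback with a small factory. It also relies on set_data, which replaces x and y together. In contrast, set_ydata alone (as in the question) leaves the old x array in place, so the two can disagree in length whenever the Euler loop stops early:

```python
import matplotlib
matplotlib.use("Agg")            # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Slider

def gen_curve(rate, t_max, n=1001):
    """Hypothetical stand-in for genData: returns an (x, y) pair with n points."""
    x = np.linspace(0, t_max, n)
    return x, np.exp(-rate * x)

params = {"rate": 0.5, "t_max": 10.0}   # shared state instead of module-level globals

fig, ax = plt.subplots()
fig.subplots_adjust(bottom=0.3)
line, = ax.plot(*gen_curve(params["rate"], params["t_max"]))

rate_slider = Slider(fig.add_axes([0.15, 0.15, 0.7, 0.03]), "rate", 0.0, 2.0,
                     valinit=params["rate"])
tmax_slider = Slider(fig.add_axes([0.15, 0.08, 0.7, 0.03]), "t max", 1.0, 100.0,
                     valinit=params["t_max"])

def update_plot():
    """Single redraw path shared by every slider callback."""
    x, y = gen_curve(params["rate"], params["t_max"])
    line.set_data(x, y)                  # x and y replaced together, so lengths always match
    ax.set_xlim(0, params["t_max"])
    fig.canvas.draw_idle()

def make_callback(key):
    """Return a slider callback that stores the new value under key, then redraws."""
    def callback(val):
        params[key] = float(val)
        update_plot()
    return callback

rate_slider.on_changed(make_callback("rate"))
tmax_slider.on_changed(make_callback("t_max"))

# Simulate a user moving the slider; with an interactive backend this happens on drag.
tmax_slider.set_val(50.0)
print(ax.get_xlim(), len(line.get_xdata()))
```

Calling set_val here stands in for dragging the slider. With TkAgg, the registered callbacks fire the same way.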