Very poor performance in my matplotlib script

Question

My code here is performing very badly. I barely get more than 10 fps when changing things on the sliders. Granted, I am not very well-versed in matplotlib, but can someone point out what I am doing wrong and how to fix it?

Note: I am handling a lot of data, around 3 × 100,000 points in the worst case. I'm also not sure whether this matters, but I am running on the 'TkAgg' backend.

Here is my code (it is a code to plot and run an SIR epidemiology mathematical model):

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import matplotlib.patches as patches

p = 1                                                       #population
i = 0.01*p                                                  #infected
s = p-i                                                     #susceptible
r = 0                                                       #recovered/removed

a = 3.2                                                     #transmission parameter
b = 0.23                                                    #recovery parameter

initialTime = 0
deltaTime = 0.001                                           #smaller the delta, better the approximation to a real derivative
maxTime = 10000                                             #more number of points, better is the curve generated

def sPrime(oldS, oldI, transmissionRate):                   #differential equations being expressed as functions to
    return -1*((transmissionRate*oldS*oldI)/p)              #calculate rate of change between time intervals of the
                                                            #different quantities i.e susceptible, infected and recovered/removed
def iPrime(oldS, oldI, transmissionRate, recoveryRate):             
    return (((transmissionRate*oldS)/p)-recoveryRate)*oldI

def rPrime(oldI, recoveryRate):
    return recoveryRate*oldI

maxTimeInitial = maxTime

def genData(transRate, recovRate, maxT):
    global a, b, maxTimeInitial
    a = transRate
    b = recovRate
    maxTimeInitial = maxT

    sInitial = s
    iInitial = i
    rInitial = r

    time = []
    sVals = []
    iVals = []
    rVals = []

    for t in range(initialTime, maxTimeInitial+1):              #generating the data through a loop
        time.append(t)
        sVals.append(sInitial)
        iVals.append(iInitial)
        rVals.append(rInitial)

        newDeltas = (sPrime(sInitial, iInitial, transmissionRate=a), iPrime(sInitial, iInitial, transmissionRate=a, recoveryRate=b), rPrime(iInitial, recoveryRate=b))
        sInitial += newDeltas[0]*deltaTime
        iInitial += newDeltas[1]*deltaTime
        rInitial += newDeltas[2]*deltaTime

        if sInitial < 0 or iInitial < 0 or rInitial < 0:        #as soon as any of these value become negative, the data generated becomes invalid
            break                                               #according to the SIR model, we assume all values of S, I and R are always positive.

    return (time, sVals, iVals, rVals)

fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.4, top=0.94)

plt.title('SIR epidemiology curves for a disease')

plt.xlim(0, maxTime+1)
plt.ylim(0, p*1.4)

plt.xlabel('Time (t)')
plt.ylabel('Population (p)')

initialData = genData(a, b, maxTimeInitial)

susceptible, = ax.plot(initialData[0], initialData[1], label='Susceptible', color='b')
infected, = ax.plot(initialData[0], initialData[2], label='Infected', color='r')
recovered, = ax.plot(initialData[0], initialData[3], label='Recovered/Removed', color='g')

plt.legend()

transmissionAxes = plt.axes([0.125, 0.25, 0.775, 0.03], facecolor='white')
recoveryAxes = plt.axes([0.125, 0.2, 0.775, 0.03], facecolor='white')
timeAxes = plt.axes([0.125, 0.15, 0.775, 0.03], facecolor='white')

transmissionSlider = Slider(transmissionAxes, 'Transmission parameter', 0, 10, valinit=a, valstep=0.01)
recoverySlider = Slider(recoveryAxes, 'Recovery parameter', 0, 10, valinit=b, valstep=0.01)
timeSlider = Slider(timeAxes, 'Max time', 0, 100000, valinit=maxTime, valstep=1, valfmt="%i")

def updateTransmission(newVal):
    newData = genData(newVal, b, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateRecovery(newVal):
    newData = genData(a, newVal, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateMaxTime(newVal):
    global susceptible, infected, recovered

    newData = genData(a, b, int(newVal.item()))

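    for line in ax.lines[:3]:                                   #remove the three old curves; newer matplotlib versions
        line.remove()                                           #no longer allow modifying ax.lines in place (del ax.lines[:3])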

    susceptible, = ax.plot(newData[0], newData[1], label='Susceptible', color='b')
    infected, = ax.plot(newData[0], newData[2], label='Infected', color='r')
    recovered, = ax.plot(newData[0], newData[3], label='Recovered/Removed', color='g')

transmissionSlider.on_changed(updateTransmission)
recoverySlider.on_changed(updateRecovery)
timeSlider.on_changed(updateMaxTime)

resetAxes = plt.axes([0.8, 0.025, 0.1, 0.05])
resetButton = Button(resetAxes, 'Reset', color='white')

r_o = plt.text(0.1, 1.5, r'$R_O$={:.2f}'.format(a/b), fontsize=12)

def reset(event):
    transmissionSlider.reset()
    recoverySlider.reset()
    timeSlider.reset()

resetButton.on_clicked(reset)

plt.show()

Recommended answer

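Use a proper ODE solver such as scipy.integrate.odeint for speed: it chooses its step sizes adaptively, so it typically needs far fewer evaluations of the derivatives than a fixed-step Euler loop running in Python, and you can then request the output at larger time steps. With an implicit solver, such as solve_ivp with method="Radau" (odeint's LSODA method also switches to implicit BDF formulas when the problem is stiff), the coordinate planes that are boundaries of the exact solution also act as boundaries for the numerical solution, so in practice the values do not become negative.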
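
A minimal sketch to check this, assuming the question's parameters (a = 3.2, b = 0.23, 1% initially infected) and a horizon of t = 10, which is what the question's 10,000 Euler steps of deltaTime = 0.001 cover. It solves the same SIR system with solve_ivp and method="Radau" and prints how many right-hand-side evaluations the solver used (sol.nfev) next to the Euler step count. The tighter rtol/atol values are my own choice, and the exact counts will depend on the SciPy version and tolerances, so none are quoted here.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Parameters copied from the question's script
p = 1.0                 # population
i_init = 0.01 * p       # initially infected
s_init = p - i_init     # initially susceptible
r_init = 0.0            # initially recovered/removed
a, b = 3.2, 0.23        # transmission and recovery parameters

def sir_rhs(t, y, trans, recov):
    """Right-hand side of the SIR system; solve_ivp passes t first."""
    S, I, R = y
    return [-trans * S * I / p, (trans * S / p - recov) * I, recov * I]

t_end = 10.0                                # model time reached by the question's default loop
euler_steps = int(round(t_end / 0.001))     # fixed steps of deltaTime = 0.001

sol = solve_ivp(sir_rhs, [0.0, t_end], [s_init, i_init, r_init], args=(a, b),
                method="Radau", rtol=1e-6, atol=1e-9)

print(f"solver status: {sol.message}")
print(f"adaptive steps: {sol.t.size - 1}, RHS evaluations: {sol.nfev}")
print(f"fixed-step Euler needs {euler_steps} steps for the same interval")
print(f"smallest S, I or R value at the solver's steps: {sol.y.min():.3e}")
```

Because dS + dI + dR = 0, the three values should also keep summing to the population, which is a cheap sanity check on any solver you swap in.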

Reduce the plotted data set to match the actual resolution of the plot image. The difference between 300 and 1000 points may still be visible, but there will be no visible difference between 1000 and 5000 points, and probably not even an actual difference in the rendered pixels.
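
A minimal sketch of this idea, assuming you keep densely sampled data and only thin it out before plotting. target_points and downsample are hypothetical helper names; the budget of two samples per horizontal pixel of the axes is an arbitrary choice, and plain striding like this can skip very narrow spikes, which is harmless for smooth curves like the SIR ones.

```python
import numpy as np
from matplotlib.figure import Figure

def target_points(ax, per_pixel=2):
    """Rough point budget: a couple of samples per horizontal pixel of the axes."""
    return max(2, int(ax.bbox.width * per_pixel))

def downsample(x, *ys, n_points=1000):
    """Keep about n_points evenly spaced samples, always including both ends."""
    x = np.asarray(x)
    if x.size <= n_points:
        return (x, *(np.asarray(y) for y in ys))
    idx = np.unique(np.linspace(0, x.size - 1, n_points).round().astype(int))
    return (x[idx], *(np.asarray(y)[idx] for y in ys))

# Example: thin a 300,000-point curve before handing it to matplotlib
fig = Figure()
ax = fig.subplots()
t = np.arange(300_000)
y = np.exp(-t / 50_000)
t_small, y_small = downsample(t, y, n_points=target_points(ax))
ax.plot(t_small, y_small)
print(len(t), "->", len(t_small), "points")
```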

Matplotlib draws its images by walking a tree of artist objects using comparatively slow Python iteration. This becomes very slow once there are more than a few tens of thousands of objects to draw, so it is best to keep the amount of detail below that.
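
A rough way to check the per-object cost on your own machine, as a sketch: render the same total number of points once as a single Line2D and once split across many separate line artists, using the off-screen Agg canvas so no window is needed. The artist counts are arbitrary and timings vary by machine and matplotlib version, so no numbers are quoted here.

```python
import time
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg

def draw_seconds(n_artists, total_points=20_000):
    """Time one full render of total_points split across n_artists Line2D objects."""
    fig = Figure()
    canvas = FigureCanvasAgg(fig)   # off-screen canvas, no GUI backend involved
    ax = fig.subplots()
    x = np.linspace(0.0, 1.0, total_points)
    y = np.sin(20 * np.pi * x)
    for chunk_x, chunk_y in zip(np.array_split(x, n_artists), np.array_split(y, n_artists)):
        ax.plot(chunk_x, chunk_y, color="b")
    start = time.perf_counter()
    canvas.draw()
    return time.perf_counter() - start, len(ax.lines)

one_line, n1 = draw_seconds(1)
many_lines, n2 = draw_seconds(2_000)
print(f"{n1} artist: {one_line:.3f} s, {n2} artists: {many_lines:.3f} s")
```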

To solve the ODE I used solve_ivp, but using odeint would make no difference:

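import numpy as np
from scipy.integrate import solve_ivp

def SIR_prime(t, SIR, trans, recov):    # the solver passes t first, even though it is not used here
    S, I, R = SIR
    dS = (-trans*I/p) * S
    dI = (trans*S/p - recov) * I
    dR = recov*I
    return [dS, dI, dR]

def genData(transRate, recovRate, maxT):
    # p, s, i, r are the population and initial values from the question's script.
    # maxT is model time here: the original loop counted Euler steps of deltaTime,
    # so its default maxTime=10000 corresponds to maxT=10.
    SIR = solve_ivp(SIR_prime, [0, maxT], [s, i, r], args=(transRate, recovRate),
                    method="Radau", dense_output=True)
    time = np.linspace(0, SIR.t[-1], 1001)    # ~1000 points are enough for the plot
    sVals, iVals, rVals = SIR.sol(time)       # evaluate the continuous (dense) solution
    return (time, sVals, iVals, rVals)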
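
For comparison, a sketch of the same genData built on scipy.integrate.odeint. odeint calls its function as f(y, t, ...) by default; passing tfirst=True (available since SciPy 1.1) lets it reuse SIR_prime with the solve_ivp-style (t, y) signature. odeint returns the solution directly at the requested times, so no dense output is needed. The initial values are copied from the question's script, and genData_odeint is a hypothetical name.

```python
import numpy as np
from scipy.integrate import odeint

# Initial state, as in the question's script
p = 1                   # population
i = 0.01 * p            # infected
s = p - i               # susceptible
r = 0                   # recovered/removed

def SIR_prime(t, SIR, trans, recov):
    S, I, R = SIR
    return [(-trans * I / p) * S, (trans * S / p - recov) * I, recov * I]

def genData_odeint(transRate, recovRate, maxT, n_points=1001):
    time = np.linspace(0, maxT, n_points)
    # tfirst=True: call SIR_prime(t, y, ...) instead of odeint's default f(y, t, ...)
    sol = odeint(SIR_prime, [s, i, r], time, args=(transRate, recovRate), tfirst=True)
    sVals, iVals, rVals = sol.T     # sol has shape (n_points, 3)
    return (time, sVals, iVals, rVals)

time, S, I, R = genData_odeint(3.2, 0.23, 10)
print(f"S={S[-1]:.3e}, I={I[-1]:.3f}, R={R[-1]:.3f} at t={time[-1]}")
```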

Streamlined code for the plot update procedures

Much of the duplicated code can be removed. I also added a line so that the time axis follows the maxTime variable, which lets you actually zoom in:

trans_rate, recov_rate = a, b   # current parameter values; must exist before the first slider callback runs

def updateTransmission(newVal):
    global trans_rate
    trans_rate = newVal
    updatePlot()

def updateRecovery(newVal):
    global recov_rate
    recov_rate = newVal
    updatePlot()

def updateMaxTime(newVal):
    global maxTime
    maxTime = newVal
    updatePlot()

def updatePlot():
    # recompute all three curves with the current parameters
    newData = genData(trans_rate, recov_rate, maxTime)

    susceptible.set_data(newData[0], newData[1])
    infected.set_data(newData[0], newData[2])
    recovered.set_data(newData[0], newData[3])

    ax.set_xlim(0, maxTime+1)   # let the time axis follow the slider

    r_o.set_text(r'$R_O$={:.2f}'.format(trans_rate/recov_rate))

    fig.canvas.draw_idle()

The code in-between and around remains the same.
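
A variation on the streamlined update, given as a sketch: instead of copying each new value into a module-level variable, a single callback reads all three sliders' current .val, so there is no global state to initialise or keep in sync. It uses the off-screen Agg canvas so it also runs without a window; with pyplot and a GUI backend the callback logic is the same. The names sliders and update are hypothetical, and the ranges are my own choice: Max time is in model time units rather than Euler steps and starts at 0.1 to keep the integration interval non-empty.

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.widgets import Slider
from scipy.integrate import solve_ivp

p = 1                   # population
i, r = 0.01 * p, 0      # initially infected, recovered/removed
s = p - i               # initially susceptible

def SIR_prime(t, SIR, trans, recov):
    S, I, R = SIR
    return [(-trans * I / p) * S, (trans * S / p - recov) * I, recov * I]

def genData(transRate, recovRate, maxT):
    sol = solve_ivp(SIR_prime, [0, maxT], [s, i, r], args=(transRate, recovRate),
                    method="Radau", dense_output=True)
    time = np.linspace(0, sol.t[-1], 1001)
    return (time, *sol.sol(time))

fig = Figure()
FigureCanvasAgg(fig)    # off-screen canvas; with pyplot you would use plt.subplots()
ax = fig.add_axes([0.125, 0.4, 0.775, 0.54])
lines = [ax.plot([], [], color=c, label=name)[0]
         for c, name in zip("brg", ["Susceptible", "Infected", "Recovered/Removed"])]
ax.set_ylim(0, 1.4 * p)

sliders = {
    "trans": Slider(fig.add_axes([0.125, 0.25, 0.775, 0.03]), "Transmission parameter", 0, 10, valinit=3.2),
    "recov": Slider(fig.add_axes([0.125, 0.20, 0.775, 0.03]), "Recovery parameter", 0, 10, valinit=0.23),
    "maxT": Slider(fig.add_axes([0.125, 0.15, 0.775, 0.03]), "Max time", 0.1, 100, valinit=10),
}

def update(_val=None):
    # Read every parameter from its slider, whichever slider triggered the call
    time, *curves = genData(sliders["trans"].val, sliders["recov"].val, sliders["maxT"].val)
    for line, y in zip(lines, curves):
        line.set_data(time, y)
    ax.set_xlim(0, sliders["maxT"].val)
    fig.canvas.draw_idle()

for slider in sliders.values():
    slider.on_changed(update)
update()    # draw the initial curves
```

Because every slider is connected to the same callback, moving any one of them recomputes the curves with the current values of the other two.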
