首页 > 解决方案 > 我的 matplotlib 脚本的性能很差

问题描述

我的代码在这里表现非常糟糕。在滑块上更改内容时,我几乎没有超过 10 fps。当然我不是很精通matplotlib,但是有人可以指出我做错了什么以及如何解决它吗?

注意:我正在处理大量数据,在最坏的情况下大约是 3*100000 点......也不确定是否需要这样做,但我在“TkAgg”后端运行。

这是我的代码(它是绘制和运行 SIR 流行病学数学模型的代码):

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import matplotlib.patches as patches

p = 1                                                       #population
i = 0.01*p                                                  #infected
s = p-i                                                     #susceptible
r = 0                                                       #recovered/removed

a = 3.2                                                     #transmission parameter
b = 0.23                                                    #recovery parameter

initialTime = 0
deltaTime = 0.001                                           #smaller the delta, better the approximation to a real derivative
maxTime = 10000                                             #more number of points, better is the curve generated

def sPrime(oldS, oldI, transmissionRate):                   #differential equations being expressed as functions to
    return -1*((transmissionRate*oldS*oldI)/p)              #calculate rate of change between time intervals of the
                                                            #different quantities i.e susceptible, infected and recovered/removed
def iPrime(oldS, oldI, transmissionRate, recoveryRate):             
    return (((transmissionRate*oldS)/p)-recoveryRate)*oldI

def rPrime(oldI, recoveryRate):
    return recoveryRate*oldI

maxTimeInitial = maxTime

def genData(transRate, recovRate, maxT):
    global a, b, maxTimeInitial
    a = transRate
    b = recovRate
    maxTimeInitial = maxT

    sInitial = s
    iInitial = i
    rInitial = r

    time = []
    sVals = []
    iVals = []
    rVals = []

    for t in range(initialTime, maxTimeInitial+1):              #generating the data through a loop
        time.append(t)
        sVals.append(sInitial)
        iVals.append(iInitial)
        rVals.append(rInitial)

        newDeltas = (sPrime(sInitial, iInitial, transmissionRate=a), iPrime(sInitial, iInitial, transmissionRate=a, recoveryRate=b), rPrime(iInitial, recoveryRate=b))
        sInitial += newDeltas[0]*deltaTime
        iInitial += newDeltas[1]*deltaTime
        rInitial += newDeltas[2]*deltaTime

        if sInitial < 0 or iInitial < 0 or rInitial < 0:        #as soon as any of these value become negative, the data generated becomes invalid
            break                                               #according to the SIR model, we assume all values of S, I and R are always positive.

    return (time, sVals, iVals, rVals)

fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.4, top=0.94)

plt.title('SIR epidemiology curves for a disease')

plt.xlim(0, maxTime+1)
plt.ylim(0, p*1.4)

plt.xlabel('Time (t)')
plt.ylabel('Population (p)')

initialData = genData(a, b, maxTimeInitial)

susceptible, = ax.plot(initialData[0], initialData[1], label='Susceptible', color='b')
infected, = ax.plot(initialData[0], initialData[2], label='Infected', color='r')
recovered, = ax.plot(initialData[0], initialData[3], label='Recovered/Removed', color='g')

plt.legend()

transmissionAxes = plt.axes([0.125, 0.25, 0.775, 0.03], facecolor='white')
recoveryAxes = plt.axes([0.125, 0.2, 0.775, 0.03], facecolor='white')
timeAxes = plt.axes([0.125, 0.15, 0.775, 0.03], facecolor='white')

transmissionSlider = Slider(transmissionAxes, 'Transmission parameter', 0, 10, valinit=a, valstep=0.01)
recoverySlider = Slider(recoveryAxes, 'Recovery parameter', 0, 10, valinit=b, valstep=0.01)
timeSlider = Slider(timeAxes, 'Max time', 0, 100000, valinit=maxTime, valstep=1, valfmt="%i")

def updateTransmission(newVal):
    newData = genData(newVal, b, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateRecovery(newVal):
    newData = genData(a, newVal, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateMaxTime(newVal):
    global susceptible, infected, recovered

    newData = genData(a, b, int(newVal.item()))

    del ax.lines[:3]

    susceptible, = ax.plot(newData[0], newData[1], label='Susceptible', color='b')
    infected, = ax.plot(newData[0], newData[2], label='Infected', color='r')
    recovered, = ax.plot(newData[0], newData[3], label='Recovered/Removed', color='g')

transmissionSlider.on_changed(updateTransmission)
recoverySlider.on_changed(updateRecovery)
timeSlider.on_changed(updateMaxTime)

resetAxes = plt.axes([0.8, 0.025, 0.1, 0.05])
resetButton = Button(resetAxes, 'Reset', color='white')

r_o = plt.text(0.1, 1.5, r'$R_O$={:.2f}'.format(a/b), fontsize=12)

def reset(event):
    transmissionSlider.reset()
    recoverySlider.reset()
    timeSlider.reset()

resetButton.on_clicked(reset)

plt.show()

标签: pythonmatplotlibplot

解决方案


使用适当的 ODE 求解器,例如scipy.integrate.odeint速度。然后您可以对输出使用更大的时间步长。使用隐式求解器或odeint坐标平面,精确解中的边界也将是数值解中的边界,因此值永远不会变为负数。solve_ivpmethod="Radau"

减少绘制的数据集以匹配绘图图像的实际分辨率。从 300 点到 1000 点的差异可能仍然可见,从 1000 点到 5000 点不会有明显的差异,甚至可能不是实际差异。

matplotlib 使用缓慢的 python 迭代通过场景树将其图像绘制为对象。如果要绘制超过 10000 个对象,这会导致速度非常慢,因此最好将细节的数量限制在这个数量之内。

ODE 求解器的代码

为了求解 ODE,我使用了 solve_ivp,但如果使用 odeint 则没有区别,

def SIR_prime(t,SIR,trans, recov): # solver expects t argument, even if not used
    S,I,R = SIR
    dS = (-trans*I/p) * S 
    dI = (trans*S/p-recov) * I
    dR = recov*I
    return [dS, dI, dR]

def genData(transRate, recovRate, maxT):
    SIR = solve_ivp(SIR_prime, [0,maxT], [s,i,r], args=(transRate, recovRate), method="Radau", dense_output=True)
    time = np.linspace(0,SIR.t[-1],1001)
    sVals, iVals, rVals = SIR.sol(time)
    return (time, sVals, iVals, rVals)

情节更新过程的简化代码

可以删除大部分重复的代码。我还添加了一条线,以便时间轴随 maxTime 变量而变化,这样就可以真正放大

def updateTransmission(newVal):
    global trans_rate
    trans_rate = newVal
    updatePlot()

def updateRecovery(newVal):
    global recov_rate
    recov_rate = newVal
    updatePlot()

def updateMaxTime(newVal):
    global maxTime
    maxTime = newVal
    updatePlot()

def updatePlot():
    newData = genData(trans_rate, recov_rate, maxTime)

    susceptible.set_data(newData[0],newData[1])
    infected.set_data(newData[0],newData[2])
    recovered.set_data(newData[0],newData[3])

    ax.set_xlim(0, maxTime+1)

    r_o.set_text(r'$R_O$={:.2f}'.format(trans_rate/recov_rate))

    fig.canvas.draw_idle()

中间和周围的代码保持不变。


推荐阅读