首页 > 解决方案 > Python Matplotlib Funcanimation 问题:难以在函数内对等高线图(contourf)进行动画处理

问题描述

我在 matplotlib.animation 中与 FuncAnimation 作斗争,我找不到任何与我的问题相似的示例或帖子(我的意思是,是的,有关于 funcAnimation 中使用的 contourf 的帖子,但在那些帖子中,他们成功删除了 PathCollection 对象但在我的情况下,有些东西不起作用)。

语境:

在一个关于 One-vs-All 概念(多个二元分类器)的学校项目中,我想实现一个函数来为具有 3 个 Axes 并包含多个 Line2D 对象来自 scatter方法的 PathCollection 对象和来自 contourf 方法的 QuadContourSet 设置动画。

这是它的外观屏幕(当我在 One-vs-All 训练结束时绘制数据时获得):

静态图的表示

传奇:

方法:

我正在尝试使用来自 matplotlib.Animation 模块的 FuncAnimation 制作情节的动画版本。情节的动画版是我项目的一个额外功能,然后动画部分/核心是在函数中制作的,你可以在下面看到一个简化(一个简单的骨骼表示):

def anim_visu(models, data):
    # initialization of the figure and object representing the data
    ...

    def f_anim():
        # Function which update the data at each frames

    visu = FuncAnimation(fig, f_anim, ...)

    return fig

[...]

if __name__ == "__main__":
    [...]
    if bool_dynamic: # activation of the dynamic visualization
        anim_visu(models, data)

这是一个最小的工作示例:

# =========================================================================== #
#                       |Importation des lib/packages|                        #
# =========================================================================== #
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec

dct_palet = {"C1":"dodgerblue",
             "C2":"red",
             "C3":"green",
             "C4":"goldenrod"}
fps = 15

# =========================================================================== #
#                        | Definition des fonctions |                         #
# =========================================================================== #

def one_vs_all_prediction(classifiers:list, X:np.array) -> np.array:
    """
    ... Docstring ...
    """
    preds = np.zeros((X.shape[0],1))

    for clf in classifiers:
        tmp = clf.predict(X)
        mask = preds == 0
        preds[mask] = tmp[mask]
    
    return preds


def one_vs_all_class_onehot(class_pred:np.array):
    """
    ... Docstring ...
    """
    house = {"C1":1., "C2":2., "C3":3., "C4":4.}
    onehot_pred = np.chararray((class_pred.shape[0],1), itemsize=2)
    
    for key, item in house.items():
        mask = class_pred == item
        onehot_pred[mask] = key
    
    return onehot_pred


def do_animation(clfs:list, data:np.ndarray):
    """ Core function for the animated vizualisation.
    The function defines all the x/y_labels, the titles.
    """

    global idx, cost_clf1, cost_clf2, cost_clf3, cost_clf4, \
        met1_clf1, met1_clf2, met1_clf3, met1_clf4, \
        met2_clf1, met2_clf2, met2_clf3, met2_clf4, \
        boundary, axes, \
        l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
        l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
        l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4
    
    plt.style.use('seaborn-pastel')
    
    # -- Declaring the figure and the axes -- #
    fig = plt.figure(figsize=(15,9.5))
    gs = GridSpec(2, 2, figure=fig)
    axes = [fig.add_subplot(gs[:, 0]), fig.add_subplot(gs[0, 1]), fig.add_subplot(gs[1, 1])]

    # --formatting the different axes -- #
    axes[0].set_xlabel("X_1")
    axes[0].set_ylabel("X_2")
    axes[0].set_title("Decision boundary")
    axes[1].set_xlabel("i: iteration")
    axes[1].set_xlim(-10, 1000)
    axes[1].set_ylim(-10, 350)
    axes[1].set_ylabel(r"$\mathcal{L}_{\theta_0,\theta_1}$")
    axes[1].grid()
    axes[2].set_xlabel("i: iteration")
    axes[2].set_ylabel("Scores (metric_1 & metric_2)")
    axes[2].set_xlim(-10, 1000)
    axes[2].set_ylim(0.0,1.01)
    axes[2].grid()

    # -- Reading min and max values along X dimensions-- #
    X = data[:,0:2]
    X = X.astype(np.float64)
    Y = data[:,2].reshape(-1,1)

    idx = np.array([0])
    X_min, X_max = X[:,:2].min(axis=0), X[:,:2].max(axis=0)

    # -- Generate a grid of points with distance h between them -- #
    h = 0.01
    XX_1, XX_2 = np.meshgrid(np.arange(X_min[0], X_max[0], h),
                              np.arange(X_min[1], X_max[1], h))
    zeros_arr = np.zeros((XX_1.shape[0] * XX_1.shape[1], 1))
    XX = np.c_[XX_1.ravel(), XX_2.ravel(),
               zeros_arr.ravel(), zeros_arr.ravel(), zeros_arr.ravel()]

    # -- Predict the function value for the whole grid -- #
    preds = one_vs_all_prediction(clfs, XX)
    Z = preds.reshape(XX_1.shape)

    ## Initialisation of the PathCollection for the Axes[0] objects
    boundary = axes[0].contourf(XX_1, XX_2, Z, 3,
                                colors=["red", "green", "goldenrod", "dodgerblue"], alpha=0.5)

    lst_colors = np.array([dct_palet[house] for house in data[:,2]])
    raw_data = axes[0].scatter(X[:,0], X[:,1], c=lst_colors, edgecolor="k")

    ## Initialisation of the Line2D object for the Axes[1] objects
    cost_clf1 = clfs[0].cost()
    cost_clf2 = clfs[1].cost()
    cost_clf3 = clfs[2].cost()
    cost_clf4 = clfs[3].cost()
    l_cost_clf1, = axes[1].plot(idx, cost_clf1,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
    l_cost_clf2, = axes[1].plot(idx, cost_clf2,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
    l_cost_clf3, = axes[1].plot(idx, cost_clf3,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
    l_cost_clf4, = axes[1].plot(idx, cost_clf4,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])

    ## Initialisation of the Line2D object for the Axes[2] objects
    met1_clf1 = clfs[0].dummy_metric1()
    met1_clf2 = clfs[1].dummy_metric1()
    met1_clf3 = clfs[2].dummy_metric1()
    met1_clf4 = clfs[3].dummy_metric1()
    met2_clf1 = clfs[0].dummy_metric2()
    met2_clf2 = clfs[1].dummy_metric2()
    met2_clf3 = clfs[2].dummy_metric2()
    met2_clf4 = clfs[3].dummy_metric2()
    l_met1_clf1, = axes[2].plot(idx, met1_clf1,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
    l_met1_clf2, = axes[2].plot(idx, met1_clf2,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
    l_met1_clf3, = axes[2].plot(idx, met1_clf3,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
    l_met1_clf4, = axes[2].plot(idx, met1_clf4,
                                ls='-', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])
    l_met2_clf1, = axes[2].plot(idx, met2_clf1,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[0].house])
    l_met2_clf2, = axes[2].plot(idx, met2_clf2,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[1].house])
    l_met2_clf3, = axes[2].plot(idx, met2_clf3,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[2].house])
    l_met2_clf4, = axes[2].plot(idx, met2_clf4,
                                ls='--', marker='o', ms=2, lw=1.2, color=dct_palet[clfs[3].house])

    fig.canvas.mpl_connect('close_event', f_close)
    anim_fig = FuncAnimation(fig, f_animate, fargs=(XX_1, XX_2, XX,), frames=int(1000/fps), repeat=False, cache_frame_data = False, blit=False)
    plt.waitforbuttonpress()

    return fig


def f_animate(i, XX_1, XX_2, XX):
    """
    ... Docstring ...
    """
    global clfs, idx, \
        cost_clf1, cost_clf2, cost_clf3, cost_clf4, \
        met1_clf1, met1_clf2, met1_clf3, met1_clf4, \
        met2_clf1, met2_clf2, met2_clf3, met2_clf4, \
        boundary, axes, l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
        l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
        l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4

    n_cycle = 100
    clfs[0].fit(n_cycle)
    clfs[1].fit(n_cycle)
    clfs[2].fit(n_cycle)
    clfs[3].fit(n_cycle)
    
    idx = np.concatenate((idx, np.array([i * n_cycle])))

    preds = one_vs_all_prediction(clfs, XX)
    Z = preds.reshape(XX_1.shape)

    cost_clf1 = np.concatenate((cost_clf1, clfs[0].cost()))
    cost_clf2 = np.concatenate((cost_clf2, clfs[1].cost()))
    cost_clf3 = np.concatenate((cost_clf3, clfs[2].cost()))
    cost_clf4 = np.concatenate((cost_clf4, clfs[3].cost()))

    tmp_met1_clf1 = clfs[0].dummy_metric1()
    tmp_met1_clf2 = clfs[1].dummy_metric1()
    tmp_met1_clf3 = clfs[2].dummy_metric1()
    tmp_met1_clf4 = clfs[3].dummy_metric1()
    tmp_met2_clf1 = clfs[0].dummy_metric2()
    tmp_met2_clf2 = clfs[1].dummy_metric2()
    tmp_met2_clf3 = clfs[2].dummy_metric2()
    tmp_met2_clf4 = clfs[3].dummy_metric2()

    met1_clf1 = np.concatenate((met1_clf1, tmp_met1_clf1))
    met1_clf2 = np.concatenate((met1_clf2, tmp_met1_clf2))
    met1_clf3 = np.concatenate((met1_clf3, tmp_met1_clf3))
    met1_clf4 = np.concatenate((met1_clf4, tmp_met1_clf4))
    met2_clf1 = np.concatenate((met2_clf1, tmp_met2_clf1))
    met2_clf2 = np.concatenate((met2_clf2, tmp_met2_clf2))
    met2_clf3 = np.concatenate((met2_clf3, tmp_met2_clf3))
    met2_clf4 = np.concatenate((met2_clf4, tmp_met2_clf4))

    # -- Plot the contour and training examples -- #

    # Update the plot objects: remove the previous collections to save memory.
    #l = len(boundary.collections)
    for coll in boundary.collections:
    # Remove the existing contours
        boundary.collections.remove(coll)

    boundary = axes[0].contourf(XX_1, XX_2, Z, 3, colors=["red", "green", "goldenrod", "dodgerblue"], alpha=0.5)

    l_cost_clf1.set_data(idx, cost_clf1)
    l_cost_clf2.set_data(idx, cost_clf2)
    l_cost_clf3.set_data(idx, cost_clf3)
    l_cost_clf4.set_data(idx, cost_clf4)

    l_met1_clf1.set_data(idx, met1_clf1)
    l_met1_clf2.set_data(idx, met1_clf2)
    l_met1_clf3.set_data(idx, met1_clf3)
    l_met1_clf4.set_data(idx, met1_clf4)
    l_met2_clf1.set_data(idx, met2_clf1)
    l_met2_clf2.set_data(idx, met2_clf2)
    l_met2_clf3.set_data(idx, met2_clf3)
    l_met2_clf4.set_data(idx, met2_clf4)

    return boundary.collections, l_cost_clf1, l_cost_clf2, l_cost_clf3, l_cost_clf4, \
        l_met1_clf1, l_met1_clf2, l_met1_clf3, l_met1_clf4, \
            l_met2_clf1, l_met2_clf2, l_met2_clf3, l_met2_clf4


def f_close(event):
    """ Functions called when the graphical window is closed.
    It prints the last value of the theta vector and the last value of the
    cost function.
    """
    plt.close()

class DummyBinary():
    def __init__(self, house, theta0, theta1, alpha=1e-3):
        self.house = house
        self.theta0 = theta0
        self.theta1 = theta1
        self.alpha = alpha
        if self.house == "C1":
            self.border_x = 6
            self.border_y = 6
        if self.house == "C2":
            self.border_x = 6
            self.border_y = 13
        if self.house == "C3":
            self.border_x = 13
            self.border_y = 6
        if self.house == "C4":
            self.border_x = 13
            self.border_y = 13
    

    def fit(self, n_cycle:int):
        for _ in range(n_cycle):
            self.theta0 = self.theta0 + self.alpha * (self.border_x - self.theta0)
            self.theta1 = self.theta1 + self.alpha * (self.border_y - self.theta1)
    

    def cost(self) -> float:
        cost = (self.theta0 - self.border_x)**2 + (self.theta1 - self.border_y)**2
        return cost
    

    def predict(self, X:np.array) -> np.array:
        if self.house == 'C1':
            mask = (X[:,0] < self.theta0) & (X[:,1] < self.theta1)
        if self.house == 'C2':
            mask = (X[:,0] < self.theta0) & (X[:,1] > self.theta1)
        if self.house == 'C3':
            mask = (X[:,0] > self.theta0) & (X[:,1] < self.theta1)
        if self.house == 'C4':
            mask = (X[:,0] > self.theta0) & (X[:,1] > self.theta1)
        pred =np.zeros((X.shape[0], 1))
        pred[mask] = int(self.house[1])
        return pred


    def dummy_metric1(self):
        return np.array([0.5 * (self.theta0 / self.border_x + self.theta1 / self.border_y)])


    def dummy_metric2(self):
        return np.array([0.5 * ((self.theta0 / self.border_x)**2 + (self.theta1 / self.border_y)**2)])

# =========================================================================== #
# _________________________________  MAIN  __________________________________ #
# =========================================================================== #

if __name__ == "__main__":
    # -- Dummy data -- #
    x1 = np.random.randn(60,1) * 2.5 + 3.5
    x2 = np.random.randn(60,1) * 2.5 + 3.5
    x3 = np.random.randn(60,1) * 2.5 + 15.5
    x4 = np.random.randn(60,1) * 2.5 + 15.5
    stud_house = 60 * ['C1'] + 60 * ['C2'] + 60 * ['C3'] + 60 * ['C4']
    c_house = [dct_palet[house] for house in stud_house]

    y1 = np.random.randn(60,1) * 2.5 + 3.5
    y2 = np.random.randn(60,1) * 2.5 + 15.5
    y3 = np.random.randn(60,1) * 2.5 + 3.5
    y4 = np.random.randn(60,1) * 2.5 + 15.5
    
    X = np.concatenate((x1, x2, x3, x4)) # shape: (240,1)
    Y = np.concatenate((y1, y2, y3, y4)) # shape: (240,1)
    data = np.concatenate((X, Y, np.array(stud_house).reshape(-1,1)), axis=1)  # shape: (240,3)

    clf1 = DummyBinary("C1", np.random.rand(1), np.random.rand(1))
    clf2 = DummyBinary("C2", np.random.rand(1), np.random.rand(1))
    clf3 = DummyBinary("C3", np.random.rand(1), np.random.rand(1))
    clf4 = DummyBinary("C4", np.random.rand(1), np.random.rand(1))
    clfs = [clf1, clf2, clf3, clf4]

    ## Visualize the raw dummy data.
    #plt.scatter(X, Y, c=c_house, s=5)
    #plt.show()

    do_animation(clfs, data)

DummyBinary 类以简化的方式模仿我的 One-vs-All 类可以做什么。您可以在anim_visuf_anim中看到一堆全局代码,这样代码“有效”,但我知道有一些非常错误的地方。

尝试:

  1. 没有全局变量,所有内容都传递给f_animvia fargs,但是当从 返回时f_anim,范围内变量的所有修改f_anim都丢失了(显然是正常行为),
  2. 移动 inside 的定义f_animanim_visu创建f_anim一个inner_function。我没有足够的经验,所以我没有成功使它以这种方式工作,我注意到它可能看起来不可能修改anim_visu内部函数范围内声明的变量。
  3. 将我需要的所有变量声明为全局变量,它以某种方式工作,但正如您通过运行代码(在轴 [0] 中)所看到的,PathCollections 不会被清除/删除(尽管使用 循环boundary.collections.remove(coll))和数量轴 [0] 中的 PathCollection 似乎增加了,导致帧更新速度下降。

期待您的建议(以及我希望的解决方案+解释)。感谢您的时间和神经元。

标签: pythonmatplotlibmatplotlib-animation

解决方案


推荐阅读