首页 > 解决方案 > 在 Matplotlib 中循环更新 Poly3DCollection 的颜色

问题描述

我想在循环中尽可能高效地更新 patchcollection = Poly3DCollection(patches, facecolor=colors) 的颜色。在我的示例中，patches 包含许多元素，我认为避免在每次迭代中都调用 Poly3DCollection 可以节省时间。下面是一个能得到期望结果、但实现方式低效的最小可运行示例（MWE）：

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np 


def point():
    """Return one random 3D point as an (x, y, z) tuple.

    The three coordinates come from a single ``np.random.rand(3)`` draw.
    """
    x, y, z = np.random.rand(3)
    return (x, y, z)

n = 4                   # <-------- This is usually >1000
# Each patch is a polygon made of five random 3D vertices; build n of them.
patches = [[point() for _vertex in range(5)] for _patch in range(n)]

def myplot1(patches,colors):
    """Draw ``patches`` in a new 3D figure, colored patch-by-patch.

    Parameters
    ----------
    patches : sequence of polygons, each a sequence of (x, y, z) vertices.
    colors : one Matplotlib color spec per patch, used as the face colors.

    A new figure and a new Poly3DCollection are created on every call, and
    ``plt.show()`` is called at the end.
    """
    fig = plt.figure(figsize=plt.figaspect(1)*0.7, constrained_layout=False)
    # Figure.gca() no longer accepts keyword arguments (removed in
    # Matplotlib 3.6), so the 3D axes are created with add_subplot, which
    # works on both old and new Matplotlib versions.
    ax = fig.add_subplot(111, projection='3d')
    patchcollection = Poly3DCollection(patches, linewidth=0.1, edgecolor="k", facecolor=colors, rasterized=True)
    ax.add_collection3d(patchcollection)
    plt.show()

for _ in range(2):   # <------- This is usually >100
    # MWE colors: one independent random draw per patch picks its color.
    colors = [
        "tab:orange" if np.random.rand() >= 0.5 else "tab:blue"
        for _patch in patches
    ]
    myplot1(patches, colors)

接下来，我只是想检查能否向绘图函数传入 Poly3DCollection 而不是 patches。这样可行，但还没有节省任何时间。

# Show an already-built Poly3DCollection in a brand-new 3D figure.
# `colors` is accepted but never used inside the function body.
# NOTE(review): fig.gca(projection='3d') relies on keyword support that was
# removed in Matplotlib 3.6; on newer versions use
# fig.add_subplot(projection='3d'). The code is left unchanged here because
# the traceback further down quotes these lines verbatim.
def myplot2(patchcollection,colors):
    fig = plt.figure(figsize=plt.figaspect(1)*0.7,constrained_layout=False) 
    ax = fig.gca(projection='3d')
    ax.add_collection3d(patchcollection)  # attaches the collection to this new figure's axes
    plt.show()


for _ in range(2):
    colors = [
        "tab:orange" if np.random.rand() >= 0.5 else "tab:blue"
        for _patch in patches
    ]
    # The collection is still rebuilt on every pass, so no time is saved yet.
    patchcollection = Poly3DCollection(
        patches,
        linewidth=0.1,
        edgecolor="k",
        facecolor=colors,
        rasterized=True,
    )
    myplot2(patchcollection, colors)

但是，将 Poly3DCollection 的创建移到循环之外，会在第一次迭代之后产生错误。

colors = [
    "tab:orange" if np.random.rand() >= 0.5 else "tab:blue"
    for _patch in patches
]
# Built once, outside the loop.
patchcollection = Poly3DCollection(
    patches,
    linewidth=0.1,
    edgecolor="k",
    facecolor=colors,
    rasterized=True,
)

# The second myplot2 call tries to add the same collection to a second
# figure, which raises the RuntimeError shown in the traceback.
for _ in range(2):
    myplot2(patchcollection, colors)

错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-3df3a5c9df6b> in <module>
      3 
      4 for _ in range(2):
----> 5     myplot2(patchcollection,colors)
      6 

<ipython-input-5-37a10e307ad3> in myplot2(patchcollection, colors)
      2     fig = plt.figure(figsize=plt.figaspect(1)*0.7,constrained_layout=False)
      3     ax = fig.gca(projection='3d')
----> 4     ax.add_collection3d(patchcollection)
      5     plt.show()

~/Software/anaconda3/lib/python3.7/site-packages/mpl_toolkits/mplot3d/axes3d.py in add_collection3d(self, col, zs, zdir)
   2182             col.set_sort_zpos(zsortval)
   2183 
-> 2184         super().add_collection(col)
   2185 
   2186     def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True,

~/Software/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in add_collection(self, collection, autolim)
   1810         self.collections.append(collection)
   1811         collection._remove_method = self.collections.remove
-> 1812         self._set_artist_props(collection)
   1813 
   1814         if collection.get_clip_path() is None:

~/Software/anaconda3/lib/python3.7/site-packages/matplotlib/axes/_base.py in _set_artist_props(self, a)
    903     def _set_artist_props(self, a):
    904         """set the boilerplate props for artists added to axes"""
--> 905         a.set_figure(self.figure)
    906         if not a.is_transform_set():
    907             a.set_transform(self.transData)

~/Software/anaconda3/lib/python3.7/site-packages/matplotlib/artist.py in set_figure(self, fig)
    710         # to more than one Axes
    711         if self.figure is not None:
--> 712             raise RuntimeError("Can not put single artist in "
    713                                "more than one figure")
    714         self.figure = fig

RuntimeError: Can not put single artist in more than one figure

期望的结果是下面的代码能够无错误地运行：

# Desired usage (pseudo-code): build the collection once, then only swap its
# colors between redraws.
colors = ["tab:blue" if np.random.rand()<0.5 else "tab:orange" for patch in patches]   
patchcollection = Poly3DCollection(patches,linewidth=0.1,edgecolor="k",facecolor = colors,rasterized=True)

for _ in range(2):
    myplot2(patchcollection,colors)
    # Draw a new random color per patch for the next pass.
    colors = ["tab:blue" if np.random.rand()<0.5 else "tab:orange" for patch in patches]   
    patchcollection.updatecolor(colors) # `updatecolor` is a placeholder name: the update method being asked for

标签: python-3.x matplotlib

解决方案


更改颜色的最小示例是重用同一个图形并就地更新。我使用了交互模式（plt.ion）并配合 plt.pause 以允许重绘，您也可以选择加入 input 调用来阻塞程序、等待用户操作。

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import numpy as np 


def point():
    """Sample one random point in 3D and return its coordinates as a tuple."""
    return (*np.random.rand(3),)

n = 4                   # <-------- This is usually >1000
# n polygons, each described by five random 3D vertices.
patches = [[point() for _vertex in range(5)] for _patch in range(n)]

# Build the figure, the 3D axes and the collection exactly once; only the
# face colors change inside the loop below.
fig = plt.figure(figsize=plt.figaspect(1)*0.7, constrained_layout=False)
# Figure.gca() no longer accepts keyword arguments (removed in Matplotlib 3.6),
# so the 3D axes are created with add_subplot, which works on old and new
# Matplotlib versions alike.
ax = fig.add_subplot(111, projection='3d')
colors = ["tab:blue" if np.random.rand()<0.5 else "tab:orange" for patch in patches]  # MWE colors
patchcollection = Poly3DCollection(patches, linewidth=0.1, edgecolor="k", facecolor=colors, rasterized=True)
ax.add_collection3d(patchcollection)
plt.ion()  # interactive mode, so plt.show() does not block the loop below
plt.show()
for i in range(10):
    print(i)
    colors = ["tab:blue" if np.random.rand()<0.5 else "tab:orange" for patch in patches]  # MWE colors
    # set_color() would overwrite the edge color too; set_facecolor() recolors
    # only the faces and keeps the black ("k") outlines set at construction.
    patchcollection.set_facecolor(colors)
    # Optionally block until the user presses Enter before each redraw:
    #input("Press Enter to redraw")
    plt.pause(0.01)  # gives the GUI event loop time to redraw the figure

您遇到的错误主要是因为您在每次循环中都重新创建了整个图形，而这样做总是会慢得多。另外，不确定是否可行或相关，但可以了解一下 Matplotlib 的 blitting 技术。


推荐阅读