首页 > 解决方案 > 如何在 google colab 中创建实时 matplotlib.pyplot 图?

问题描述

不幸的是,无法%matplotlib notebook像在我的 PC 上的离线 jupyter 笔记本中那样在 google colab 笔记本中创建实时绘图。

我发现了两个类似的问题,回答了如何为情节图(link_1link_2)实现这一点。但是我无法使其适应 matplotlib,或者根本不知道这是否可能。

我在这里遵循本教程中的代码:GitHub 链接。特别是我想运行这段代码,它会创建一个回调,在训练步骤上绘制每一步的奖励:

import matplotlib.pyplot as plt
import numpy as np
%matplotlib notebook


class PlottingCallback(BaseCallback):
    """
    Callback for plotting the performance in realtime.

    :param verbose: (int)
    """
    def __init__(self, verbose=1):
        super(PlottingCallback, self).__init__(verbose)
        self._plot = None

    def _on_step(self) -> bool:
        # get the monitor's data
        x, y = ts2xy(load_results(log_dir), 'timesteps')
      if self._plot is None: # make the plot
          plt.ion()
          fig = plt.figure(figsize=(6,3))
          ax = fig.add_subplot(111)
          line, = ax.plot(x, y)
          self._plot = (line, ax, fig)
          plt.show()
      else: # update and rescale the plot
          self._plot[0].set_data(x, y)
          self._plot[-2].relim()
          self._plot[-2].set_xlim([self.locals["total_timesteps"] * -0.02, 
                                   self.locals["total_timesteps"] * 1.02])
          self._plot[-2].autoscale_view(True,True,True)
          self._plot[-1].canvas.draw()

# Create log dir
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)

# Create and wrap the environment
env = make_vec_env('MountainCarContinuous-v0', n_envs=1, monitor_dir=log_dir)

plotting_callback = PlottingCallback()

model = PPO2('MlpPolicy', env, verbose=0)
model.learn(20000, callback=plotting_callback)

标签: pythonmatplotlibgoogle-colaboratoryopenai-gymstable-baselines

解决方案


您可以使用的 hack 是使用您在 jupyter 笔记本上使用的相同代码,创建一个按钮,然后使用 JavaScript 来单击该按钮,欺骗前端请求更新,以便它不断更新值。

这是一个使用 ipywidgets 的示例。

from IPython.display import display
import ipywidgets
progress = ipywidgets.FloatProgress(value=0.0, min=0.0, max=1.0)
import asyncio
async def work(progress):
    total = 100
    for i in range(total):
        await asyncio.sleep(0.2)
        progress.value = float(i+1)/total
display(progress)
asyncio.get_event_loop().create_task(work(progress))
button = ipywidgets.Button(description="This button does nothing... except send a\
 socket request to google servers to receive updated information since the \
 frontend wants to change..")

display(button,ipywidgets.HTML(
    value="""<script>
      var b=setInterval(a=>{
    //Hopefully this is the first button
    document.querySelector('#output-body button').click()},
    1000);
    setTimeout(c=>clearInterval(b),1000*60*1);
    //Stops clicking the button after 1 minute
    </script>"""
))

专门处理 matplotlib 有点复杂,我想我可以简单地在 asyncio 函数上调用 matplotlib plot ,但它确实滞后于更新,因为它似乎在没有人看到情节的背景中进行了不必要的渲染。所以另一种解决方法是更新按钮更新代码上的图。这段代码也受到了Add points to matlibplot scatter plot liveMatplotlib graphics to base64 的启发,因为没有必要为每个绘图创建绘图图形,您只需修改已有的图形即可。这当然意味着更多的代码。

from IPython.display import display
import ipywidgets
import matplotlib.pyplot as plt
import numpy as np
import io
import base64
def pltToImg(plt):
 s = io.BytesIO()
 plt.savefig(s, format='png', bbox_inches="tight")
 s = base64.b64encode(s.getvalue()).decode("utf-8").replace("\n", "")
 #plt.close()
 return '<img align="left" src="data:image/png;base64,%s">' % s
progress = ipywidgets.FloatProgress(value=0.0, min=0.0, max=1.0)
import asyncio
async def work(progress):
    total = 100
    for i in range(total):
        await asyncio.sleep(0.5)
        progress.value = float(i+1)/total
display(progress)
asyncio.get_event_loop().create_task(work(progress))
button = ipywidgets.Button(description="Update =D")
a=ipywidgets.HTML(
    value="image here"
)
output = ipywidgets.Output()
plt.ion()
fig, ax = plt.subplots()
plot = ax.scatter([], [])
point = np.random.normal(0, 1, 2)
array = plot.get_offsets()
array = np.append(array, [point], axis=0)
plot.set_offsets(array)
plt.close()
ii=0
def on_button_clicked(b):
       global ii
       ii+=1
       point=np.r_[ii,np.random.normal(0, 1, 1)]
       array = plot.get_offsets()
       array = np.append(array, [point], axis=0)
       plot.set_offsets(array)
       ax.set_xlim(array[:, 0].min() - 0.5, array[:,0].max() + 0.5)
       ax.set_ylim(array[:, 1].min() - 0.5, array[:, 1].max() + 0.5)
       a.value=(pltToImg(fig))
       a.value+=str(progress.value)
       a.value+=" </br>"
       a.value+=str(ii)

button.on_click(on_button_clicked)
display(output,button,ipywidgets.HTML(
    value="""<script>
      var b=setInterval(a=>{
    //Hopefully this is the first button
    document.querySelector('#output-body button')?.click()},
    500);
    setTimeout(c=>clearInterval(b),1000*60*1);
    //Stops clicking the button after 1 minute
    </script>"""
),a)

推荐阅读