首页 > 解决方案 > 使用 Python 避免在散点图中垂直重叠

问题描述

我正在模拟具有 3 臂和伯诺利回报的强盗问题中的 epsilon-greedy 算法。做完实验后,我想画出每条arm的return,也就是说,如果每次选择一个arm,对应时间所取的值就是它的return,剩下的2个arm,这个值就是设置为-1。现在我想根据时间段绘制一只手臂的回归。(该值将采用 -1 或 1 或 0)

import matplotlib.pyplot as plt
import random
from scipy import stats
class greedy():
    def __init__(self,epsilon,n):
        self.epsilon=epsilon
        self.n=n
        self.value=[0,0,0]#estimator
        self.count=[0,0,0]
        self.prob=[0.4,0.6,0.8]
        self.greedy_reward=[[0 for x in range(10000)] for y in range(3)]
    def exploration(self,i):
        max_index=np.random.choice([0,1,2])
        r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))#do experiment, return r
        self.count[max_index]+=1
        for time in range(3):
            self.greedy_reward[time][i]=-1
        self.greedy_reward[max_index][i]=r
        self.value[max_index]=self.value[max_index]+(1/self.count[max_index])*(r-self.value[max_index])
    def exploitation(self,i):
        max_index=self.value.index(max(self.value))
        r=np.random.choice([0,1],p=(1-self.prob[max_index],self.prob[max_index]))
        self.count[max_index]+=1
        for time in range(3):
            self.greedy_reward[time][i]=-1
        self.greedy_reward[max_index][i]=r
        self.value[max_index]=self.value[max_index]+(1/self.count[max_index])*(r-self.value[max_index])
    def EE_choice(self,i):
        output=np.random.choice(# o is exploitation,1 is exploration
        [0,1], 
        p=[1-self.epsilon,self.epsilon]
        )
        if output==1:
            self.exploration(i);
        else:
            self.exploitation(i);
    def exp(self):
        for i in range(0,self.n):

然后,我们取出一个手臂的回报,例如 arm3。

import matplotlib.pyplot as plt
x=[i for i in range(1,10001)]
arm_3_y=[0 for i in range(10000)]
for j in range(10000):
    arm_3_y[j]=greedy_1.greedy_reward[2][j]
plt.scatter(x,arm_3_y,marker='o')
plt.ylim([-1,1])
plt.show()

在此处输入图像描述

如我们所见,一条垂直线上的所有点都重叠在一起,有什么办法可以避免这种情况吗?

标签: pythonmatplotlibvisualization

解决方案


根据您想要可视化的内容,可以有很多方法来解决它。如果您想查看分布但不需要单个点,请使用箱线图。它会显示平均值、四分位数和范围。

如果您确实需要散点图并查看点,则为数据中的每个点添加一些随机性(仅用于可视化过程),它将减少数据重叠的机会,您可以看到它们更聚集的位置。

def randomize(arr):
    stdev = .01*min(arr) #use any small value, small enough to not change the distribution
    return arr + np.random.randn(len(arr)) * stdev
plt.scatter(x,randomize(arm_3_y),marker='o')

它应该有助于可视化。尝试使用系数(此处为 0.01)进行啮合以获得更多抖动。


推荐阅读