首页 > 技术文章 > 画出8个高斯分布散点图

gaona666 2020-03-09 09:11 原文

import os

import matplotlib.pyplot as plt
import numpy as np

# Ring-shaped Gaussian mixture: `num_mixtures` equally weighted components
# whose centres are evenly spaced on a circle of radius `radius`.
num_mixtures = 8
radius = 2.0
# NOTE: gmm_sample puts these values on the diagonal of each component's
# covariance matrix (np.diag), so despite its name this acts as a per-axis
# *variance*; the per-axis standard deviation is sqrt(0.02) ~= 0.14.
std = 0.02

# num_mixtures + 1 points from 0 to 2*pi; the last one duplicates angle 0,
# so it is dropped.
thetas = np.linspace(0, 2 * np.pi, num_mixtures + 1)[:-1]
# x uses sin and y uses cos, so the first centre sits at (0, radius).
xs = radius * np.sin(thetas)
ys = radius * np.cos(thetas)

mix_coeffs = (1 / num_mixtures,) * num_mixtures
mean = tuple((cx, cy) for cx, cy in zip(xs, ys))
cov = ((std, std),) * num_mixtures

# Module-level placeholders; `epoch` is used to name the saved image.
fig = None
ax = None
epoch = 0
        

def gmm_sample(num_samples, mix_coeffs, mean, cov):
    """Draw points from a Gaussian mixture whose components have diagonal covariance.

    Args:
        num_samples: total number of points to draw.
        mix_coeffs: mixing weight of each component. It is passed to
            ``np.random.multinomial`` to decide how many of the
            ``num_samples`` points each component contributes.
        mean: one centre per component (row ``k`` is the centre of
            component ``k``).
        cov: one row per component holding the *diagonal* of that component's
            covariance matrix (i.e. per-axis variances, not standard deviations).

    Returns:
        Array of shape ``(num_samples, len(mean[0]))``. Rows are grouped by
        component (all points of component 0 first, then component 1, ...);
        they are not shuffled.
    """
    counts = np.random.multinomial(num_samples, mix_coeffs)
    samples = np.zeros(shape=[num_samples, len(mean[0])])

    # Row range [lo, hi) owned by each component, from the running total.
    ends = np.cumsum(counts)
    begins = ends - counts
    for k, (lo, hi) in enumerate(zip(begins, ends)):
        centre = np.array(mean)[k, :]
        full_cov = np.diag(np.array(cov)[k, :])
        samples[lo:hi, :] = np.random.multivariate_normal(
            mean=centre,
            cov=full_cov,
            size=hi - lo,
        )
    return samples

def disp_scatter(x, fig=None, ax=None):
    """Scatter-plot 2-D points as red '+' markers labelled 'real data'.

    Args:
        x: array of points; column 0 is plotted on the horizontal axis and
            column 1 on the vertical axis.
        fig: figure that owns ``ax`` (optional). If ``ax`` is given without
            ``fig``, the axes' own figure is returned instead of ``None``.
        ax: axes to draw on. When ``None``, a new figure and axes are created
            with ``plt.subplots()``.

    Returns:
        ``(fig, ax)``: the figure and the axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots()
    elif fig is None:
        # Derive the figure from the axes so callers always get a usable
        # figure back; returning None here breaks e.g. fig.tight_layout().
        fig = ax.figure
    ax.scatter(x[:, 0], x[:, 1], s=10, marker='+', color='r', alpha=0.8, label='real data')

    ax.legend()
    return fig, ax
# Draw samples from the mixture and save the scatter plot as output/<epoch>.png.
num_samples = 1000

x = gmm_sample(num_samples, mix_coeffs, mean, cov)

fig, ax = disp_scatter(x, fig=None, ax=None)
fig.tight_layout()
# The old path "output\{}.png" relied on an invalid escape sequence and only
# landed inside output/ on Windows (elsewhere it created a file literally named
# "output\0.png"). Build the path portably and make sure the directory exists
# so savefig does not fail.
os.makedirs("output", exist_ok=True)
fig.savefig(os.path.join("output", "{}.png".format(epoch)))

# To try other layouts, change num_mixtures at the top of the script (8 gives
# the ring plotted here, 1 gives a single Gaussian) and rerun; reassigning it
# down here, after the samples are drawn, has no effect on the plot.

 

推荐阅读