python - TensorFlow Probability MCMC with Bernoulli distribution
Problem description
I need to use TensorFlow Probability to implement Markov Chain Monte Carlo with sampling from a Bernoulli distribution. However, my attempts are showing results not consistent with what I would expect from a Bernoulli distribution.
I modified the example given in the documentation of tfp.mcmc.sample_chain (sampling from a diagonal-variance Gaussian) to draw from a Bernoulli distribution. Since the Bernoulli distribution is discrete, I used the RandomWalkMetropolis transition kernel instead of the Hamiltonian Monte Carlo kernel, which I expect would not work since it computes a gradient.
Here is the code:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
def make_likelihood(event_prob):
    return tfd.Bernoulli(probs=event_prob, dtype=tf.float32)
dims=1
event_prob = 0.3
num_results = 30000
likelihood = make_likelihood(event_prob)
states, kernel_results = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=tf.zeros(dims),
    kernel=tfp.mcmc.RandomWalkMetropolis(
        target_log_prob_fn=likelihood.log_prob,
        new_state_fn=tfp.mcmc.random_walk_normal_fn(scale=1.0),
        seed=124
    ),
    num_burnin_steps=5000)
chain_vals = states
# Compute sample stats.
sample_mean = tf.reduce_mean(states, axis=0)
sample_var = tf.reduce_mean(
    tf.squared_difference(states, sample_mean),
    axis=0)
#initialize the variable
init_op = tf.global_variables_initializer()
#run the graph
with tf.Session() as sess:
    sess.run(init_op)
    [sample_mean_, sample_var_, chain_vals_] = sess.run([sample_mean, sample_var, chain_vals])
chain_samples = (chain_vals_[:] )
print ('Sample mean = {}'.format(sample_mean_))
print ('Sample var = {}'.format(sample_var_))
fig, axes = plt.subplots(2, 1)
fig.set_size_inches(12, 10)
axes[0].plot(chain_samples[:])
axes[0].title.set_text("values sample chain tfd.Bernoulli")
sns.kdeplot(chain_samples[:,0], ax=axes[1], shade=True)
axes[1].title.set_text("chain tfd.Bernoulli distribution")
fig.tight_layout()
plt.show()
I expected to see values for the states of the Markov chain in the interval [0,1].
The result for the Markov chain values does not look like what is expected for a Bernoulli distribution, nor does the KDE plot, as shown in this figure:
Do I have a conceptual flaw in my example, or a mistake in how I am using the TensorFlow Probability API?
Or is there possibly an issue with the TensorFlow Probability implementation of Markov Chain Monte Carlo for a discrete distribution such as the Bernoulli distribution?
Solution
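I think the source of your confusing experience is that you are still using a continuous proposal distribution in your RandomWalkMetropolis transition. The convention for integer distributions in TensorFlow Probability (including Bernoulli) is to implement a continuous relaxation by default. IIRC, for Bernoulli that is pdf(x) ~ p ** x * (1 - p) ** (1 - x); once x becomes negative, this steadily drives your random-walk Markov chain toward -inf, as you observed.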
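To see why the chain drifts, here is a minimal NumPy sketch of the relaxed density quoted above. It assumes the formula as recalled in this answer, not TFP's actual internals, and the helper name relaxed_bernoulli_log_prob is hypothetical. With p = 0.3 the log-density is linear in x with slope log(p) - log(1 - p) < 0, so it keeps growing as x decreases and a Metropolis chain has no reason to come back.

```python
import numpy as np


def relaxed_bernoulli_log_prob(x, p):
    """Unnormalized log of p ** x * (1 - p) ** (1 - x), evaluated at any real x.

    Hypothetical helper for illustration; it mirrors the formula quoted in
    the answer and is not a call into TensorFlow Probability.
    """
    return x * np.log(p) + (1.0 - x) * np.log1p(-p)


p = 0.3
# Points a Gaussian random walk can easily reach, listed from 1 downwards.
xs = np.array([1.0, 0.0, -1.0, -10.0, -100.0])
log_probs = relaxed_bernoulli_log_prob(xs, p)

for x, lp in zip(xs, log_probs):
    print("x = {:7.1f}  unnormalized log-density = {:8.3f}".format(x, lp))
```

Because every step toward more negative x raises the target density, such proposals are always accepted, which matches the runaway trace described in the question.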
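You can do a few things about this:
- Pass validate_args=True to the Bernoulli constructor. This will crash if x is not 0 or 1, helping you detect the problem (but do not do this if you want non-integer results in the interval [0, 1]).
- Use a different proposal function, such as an independent uniform between 0 and 1. Writing your own proposal is not hard; here is the code for the Gaussian drift proposal function you used: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/mcmc/random_walk_metropolis.py#L97-L107. Note that the proposal needs to be symmetric to work with RandomWalkMetropolis.
- Use a different MCMC transition operator entirely.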
I also filed a ticket about making a TransitionKernel for independent proposals (like the one I think you might need): https://github.com/tensorflow/probability/issues/218
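As a rough illustration of the "different proposal function" suggestion, the sketch below is a hand-rolled Metropolis sampler in plain NumPy, deliberately not the TFP API. It swaps the Gaussian drift for a symmetric discrete proposal that flips the state between 0 and 1, so the chain can never leave {0, 1}. The names bernoulli_log_prob and flip_metropolis are hypothetical, and the settings (p = 0.3, 30000 results, 5000 burn-in steps, seed 124) are simply copied from the question.

```python
import numpy as np


def bernoulli_log_prob(x, p):
    """Log-probability of a Bernoulli(p) outcome x in {0, 1}."""
    return x * np.log(p) + (1 - x) * np.log1p(-p)


def flip_metropolis(p, num_results, num_burnin_steps=0, seed=0):
    """Metropolis sampler on {0, 1} with a symmetric 'flip' proposal.

    Hypothetical helper: proposes x' = 1 - x and accepts with probability
    min(1, p(x') / p(x)). The proposal is symmetric, so no Hastings
    correction term is needed.
    """
    rng = np.random.default_rng(seed)
    state = 0
    samples = np.empty(num_results, dtype=np.int64)
    for i in range(num_burnin_steps + num_results):
        proposal = 1 - state
        log_accept_ratio = (bernoulli_log_prob(proposal, p)
                            - bernoulli_log_prob(state, p))
        # Accept or reject in log space.
        if np.log(rng.uniform()) < log_accept_ratio:
            state = proposal
        if i >= num_burnin_steps:
            samples[i - num_burnin_steps] = state
    return samples


samples = flip_metropolis(p=0.3, num_results=30000,
                          num_burnin_steps=5000, seed=124)
print("Sample mean = {}".format(samples.mean()))
print("Sample var = {}".format(samples.var()))
```

The sample mean should land near p and the variance near p * (1 - p). To stay inside TFP, the same idea would presumably be expressed as a custom new_state_fn passed to RandomWalkMetropolis, modeled on the Gaussian drift function linked above.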