python - n-batch > 1 的多元正态分布
问题描述
我试图将如何在最新版本的 Tensorflow中使用 MultiVariateNormal 分布中给出的示例概括为二维但不止一批的正态分布。当我运行以下命令时:
from tensorflow_probability import distributions as tfd
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
mu = [[1, 2],
[-1,-2]]
cov = [[1, 3./5],
[3./5, 2]]
cov = [cov, cov] # for demonstration purpose, use same cov for both batches
mvn = tfd.MultivariateNormalFullCovariance(
loc=mu,
covariance_matrix=cov)
# generate the pdf
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
idx = tf.concat([tf.reshape(X, [-1, 1]), tf.reshape(Y,[-1,1])], axis =1)
prob = tf.reshape(mvn.prob(idx), tf.shape(X))
我收到不兼容的形状错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [3600,2] vs. [2,2] [Op:Sub] name: MultivariateNormalFullCovariance/log_prob/affine_linear_operator/inverse/sub/
我对文档(https://www.tensorflow.org/api_docs/python/tf/contrib/distributions/MultivariateNormalFullCovariance)的理解是,要计算 pdf,需要一个 [n_observation, n_dimensions] 张量(在这个例子:idx.shape
= TensorShape([Dimension(3600), Dimension(2)])
)。我数学错了吗?
解决方案
您需要idx
在倒数第二个位置的张量中添加一个批处理轴,因为 60x60 不能针对mvn.batch_shape
of 进行广播(2,)
。
# TF/TFP Imports
!pip install --quiet tfp-nightly tf-nightly
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import tensorflow_probability as tfp
tfd = tfp.distributions
mu = [[1, 2],
[-1, -2]]
cov = [[1, 3./5],
[3./5, 2]]
cov = [cov, cov] # for demonstration purpose, use same cov for both batches
mvn = tfd.MultivariateNormalFullCovariance(
loc=mu, covariance_matrix=cov)
print(mvn.batch_shape, mvn.event_shape)
# generate the pdf
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
print(X.shape)
idx = tf.stack([X, Y], axis=-1)[..., tf.newaxis, :]
print(idx.shape)
probs = mvn.prob(idx)
print(probs.shape)
输出:
(2,) (2,) # mvn.batch_shape, mvn.event_shape
(60, 60) # X.shape
(60, 60, 1, 2) # idx.shape == X.shape + (1 "broadcast against batch", 2 "event")
(60, 60, 2) # probs.shape == X.shape + (2 "mvn batch shape")
推荐阅读
- pandas - 使用列表理解迭代熊猫数据框中的列
- xslt - 我们如何将一个 xsl 中的值导入 XSLT1.0 中的另一个 xsl
- javascript - 为什么发射器@current-items 不适用于 v-data-table?
- token - HID 活动身份钥匙串令牌已锁定
- javascript - 尝试从 Bootstrap 中的克隆的下拉选项中获取“保存”并将输入保存到本地存储
- pytorch - CNN 的输出应该是图像
- gatsby - 使用 gatsby-image 的 tracedSVG 作为实际图片
- python-3.x - 用lxml读取CDATA,行尾问题
- shell - 跳过节点时的 shell 迭代
- vue.js - vue组件返回值