首页 > 解决方案 > 从张量流中的分布中采样非大小张量

问题描述

以下代码:

import tensorflow as tf
tfd = tf.contrib.distributions

mean = [0.0, 0.0]
scale = [1.0, 1.0]

dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=scale)
samp = dist.sample([None])

给出错误:

TypeError: Expected int32, got None of type '_Message' instead.

但是如果将 None 替换为整数 n,则从分布中生成 n 个样本。有没有办法从分布中获取未知数量的样本?

编辑:最初的问题可能措辞不当;我想对形状 (None, ...) 的张量进行采样,以与这种形状的其他张量结合。很明显,需要在某个地方输入来在运行时修复大小。

标签: tensorflow

解决方案


你可以做

num_samples = tf.placeholder(dtype=tf.int32, shape=())
sampl = dist.sample(num_samples)

然后输入样本数。同样,如果您有一个表示样本数量的标量张量,您可以将其传入。


推荐阅读