python - 如何获得张量/矩阵中列正元素的中位数?
问题描述
具体给定一个二维矩阵,如何找到每列正元素的中位数?
从数学上讲:返回 B,其中B[i] = median({A[j, i] | A[j, i] > 0})
我知道中位数可以通过计算 tf.contrib.distributions.percentile
tf.boolean_mask(A, tf.greater(A, 0))
输出一维列表而不是矩阵。
解决方案
tf.boolean_mask()
确实返回一维张量,否则保留尺寸的结果张量将是稀疏的(cf 列具有不同数量的正元素)。
由于我不知道稀疏矩阵的任何中值函数,因此想到的唯一替代方法是遍历列,例如使用tf.map_fn()
:
import tensorflow as tf
A = tf.convert_to_tensor([[ 1, 0, 20, 5],
[-1, 1, 10, 0],
[-2, 1, -10, 2],
[ 0, 2, 20, 1]])
positive_median_fn = lambda x: tf.contrib.distributions.percentile(tf.boolean_mask(x, tf.greater(x, 0)), q=50)
A_t = tf.matrix_transpose(A) # tf.map_fn is applied along 1st dim, so we need to transpose A
res = tf.map_fn(fn=positive_median_fn, elems=A_t)
with tf.Session() as sess:
print(sess.run(res))
# [ 1 1 20 2]
注意:此代码段不涵盖列不包含正元素的情况。tf.contrib.distributions.percentile()
如果其输入张量为空,将返回错误。tf.boolean_mask(x, tf.greater(x, 0))
例如,可以使用关于形状的条件(例如 with tf.where()
)
推荐阅读
- python - 监听简单表上的插入/更新事件
- airflow - 气流 UI 未显示所有 DAG
- java - 返回不同的值取决于倍值
- java - JVM在使用synchronized时如何保证被引用对象中成员变量修改的可见性?
- xml - 是否可以使用 ssis 将 XML 文件上传到 SQL 服务器
- firebase - 添加新应用时,如何在 Firebase 控制台中解决此错误?“处理请求时出现未知错误。再试一次。”
- android - 图层列表内的可绘制矢量正在拉伸
- c# - 加入并在外来对象/键上获得不同 Linq
- java - 如果消息不包含消息选择器条件,为什么消费者会停止?
- angular - 没有导出的成员“国家”。- 角