首页 > 解决方案 > 如何获得张量/矩阵中列正元素的中位数?

问题描述

具体给定一个二维矩阵,如何找到每列正元素的中位数?

从数学上讲:返回 B,其中B[i] = median({A[j, i] | A[j, i] > 0})

我知道中位数可以通过计算  tf.contrib.distributions.percentile

tf.boolean_mask(A, tf.greater(A, 0))输出一维列表而不是矩阵。

标签: pythontensorflowconditional-statementsdistributionmedian

解决方案


tf.boolean_mask()确实返回一维张量,否则保留尺寸的结果张量将是稀疏的(cf 列具有不同数量的正元素)。

由于我不知道稀疏矩阵的任何中值函数,因此想到的唯一替代方法是遍历列,例如使用tf.map_fn()

import tensorflow as tf

A = tf.convert_to_tensor([[ 1, 0,  20, 5],
                          [-1, 1,  10, 0],
                          [-2, 1, -10, 2],
                          [ 0, 2,  20, 1]])


positive_median_fn = lambda x: tf.contrib.distributions.percentile(tf.boolean_mask(x, tf.greater(x, 0)), q=50)
A_t = tf.matrix_transpose(A) # tf.map_fn is applied along 1st dim, so we need to transpose A
res = tf.map_fn(fn=positive_median_fn, elems=A_t)

with tf.Session() as sess:
    print(sess.run(res))
# [ 1  1 20  2]

注意:此代码段不涵盖列不包含正元素的情况。tf.contrib.distributions.percentile()如果其输入张量为空,将返回错误。tf.boolean_mask(x, tf.greater(x, 0))例如,可以使用关于形状的条件(例如 with tf.where()


推荐阅读