首页 > 解决方案 > Tensorflow:根据布尔掩码重塑张量

问题描述

我有一个一维张量值: a = tf.constant([0.1, 0.2, 0.3, 0.4])

和一个 nD 布尔掩码: b = tf.constant([[1, 1, 0], [0, 1, 1]])

b 中 1 的总数与 a 的长度相匹配。

如何从 a 和 b 获得 [[0.1, 0.2, 0.0], [0.0, 0.3, 0.4]]?

标签: tensorflow

解决方案


import tensorflow as tf

a = tf.constant([0.1, 0.2, 0.3, 0.4])

b = tf.constant([[1, 1, 0], [0, 1, 1]])

# reshape b to a 1D vector
b_res = tf.reshape(b, [-1])
# Get the indices to gather using cumsum
b_cum = tf.cumsum(b_res) - 1

# Gather the elements, multiply by b_res to zero out the unwanted values and reshape back
c = tf.reshape(tf.gather(a, b_cum) * tf.cast(b_res, 'float32'), [-1, 3])
print(c)

推荐阅读