python - 使用 Numpy/Tensorflow 快速将索引变成掩码
问题描述
我正在尝试实现一个随机神经网络,其中我从我的输出层的 softmax 分布中采样一个掩码,以获取样本来训练一个无监督的目标函数。我希望我的采样蒙版仅在我采样的点上打开。(即如果输出为 [0.1, 0.1, 0.1, 0.7] 并且采样索引为 3,我希望我的输出为 [0, 0, 0, 1])。我将如何在批次中有效地做到这一点?
这是我目前正在做的事情,它很慢但效果很好。:
def Stochastic_softmax_selection(M, K, N_rf, N=1000):
def select(y_dist, G, lossfn, reg=user_constraint):
# y_dist: shape(Batchsize, M*K)
# initialize an empty mask
mask = np.zeros((y_dist.shape[0], N, M*K))
# sample from the distribution
sam = tf.random.categorical(y_dist[:, :], N) # has shape of (batchsize, N)
# iterate over a batch
for batch in range(y_dist.shape[0]):
# select N sample points
for n in range(N):
mask[batch, n, sam[batch, n]] = 1.0
mask = tf.constant(mask, dtype=tf.float32)
return mask
return select
解决方案
试试这个:
input = tf.random.uniform([10, 20])
x = tf.math.argmax(input, -1)
x = tf.one_hot(x, tf.shape(input)[-1])
推荐阅读
- flutter - Flutter - 临时错误 - IconButton 小部件需要 Material 小部件祖先
- c# - 如何在运行时在 Start() 或可在 Update() 中使用的唤醒方法上设置变量?
- flutter - 关于颤振中尖括号使用的困惑
- azure - 无法在同一 arm 模板中为密钥库引用用户分配身份的 principalId
- google-apps-script - 基于 Google 电子表格中日期的 Slack 消息
- r - 简化嵌套数据框的列表
- c# - 我如何找出对撞机的层?
- hyperledger-fabric - 在所有订购者都关闭的情况下对等点的预期行为
- swift - var 中的 Swift 函数
- python - 将一列的值与另一列的所有行进行比较