python - 从 one-hot 表示到标签
问题描述
我的预测在一个张量之下pred
,并且pred.shape
是(4254, 10, 3)
。所以我们有4254
维度矩阵(10, 3)
。让我们来看看其中一个矩阵。
W = array([[0.04592975, 0.09632163, 0.85774857],
[0.03408821, 0.27141285, 0.6944989 ],
[0.02538731, 0.4691383 , 0.50547445],
[0.01959289, 0.6456455 , 0.33476162],
[0.01333424, 0.7494791 , 0.23718661],
[0.0109237 , 0.77042925, 0.218647 ],
[0.01438793, 0.7796771 , 0.20593494],
[0.01474626, 0.6817438 , 0.30350992],
[0.02189695, 0.57687664, 0.40122634],
[0.03810155, 0.5130332 , 0.44886518]], dtype=float32)
正如您在上面的示例中看到的那样,有 10 个向量表示标签的 one-hot 表示。例如,np.argmax([0.04592975, 0.09632163, 0.85774857]) = 2
.
为什么我要批量处理 10 个向量?我正在研究一个时间序列预测问题,t_0
我有时会预测接下来的 10 个标签。t_1
t_10
对于这些矩阵中的每一个,我都会有兴趣取回原始标签。所以对于矩阵W
,我应该得到数组 array([2, 2, 2, 1, 1, 1, 1, 1, 1, 1])
。
让我们定义阈值数组threshold_array = np.array([0.6, 0.65, 0.70, 0.75, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80])
并收回labels = array([2, 2, 2, 1, 1, 1, 1, 1, 1, 1])
。假设中性位置是1
并且动作是0
或2
。这里的目标是labels
根据threshold_array
和我们的矩阵进行修改W
。
如果我采取W[0]
,我们知道np.argmax(W[0]) = 2
并且W[0][2] = 0.85774857
。那样W[0][2] >= threshold_array[0]
的话,labels[0]
就会留下来2
。
另一个例子有点不同。如果我采取W[2]
,我们知道np.argmax(W[2]) = 2
并且W[2][2] = 0.50547445
。As W[2][2] < threshold_array[2]
, thenlabels[2]
将从2
变为0
。
如果我将该策略应用于 中的每个向量W
,labels
则现在设置为array([2, 2, 0, 1, 1, 1, 1, 1, 1, 1])
。请注意,只有一个动作可以成为中立位置,而不是相反。
如何在 python 中将该策略编码到W
内部的每个矩阵pred
以获得维度的标签矩阵(4254, 10)
?
解决方案
我不确定这是处理这个问题的最佳方法,但这里有一个答案。
import numpy as np
threshold_array = np.array([0.6, 0.65, 0.70, 0.75, 0.80, 0.80, 0.80, 0.80, 0.80, 0.80])
def get_labels(W, threshold_array):
labels = []
for i, vect in enumerate(W):
neutral_position = 1
label = np.argmax(vect)
if label in [0, 2]:
if vect[label] < threshold_array[i]:
labels.append(neutral_position)
else:
labels.append(label)
else:
labels.append(label)
return np.array(labels)
if __name__ == "__main__":
labels = []
for matrix in pred:
labels.append(get_labels(matrix, theshold_array))
labels = np.array(labels)
推荐阅读
- flutter - 为什么 Ad Mob 在 Flutter 应用中不显示广告?
- xml - 使用 xpath 从 BaseX 中选择所有标题
- c++ - 如何从/向二进制文件C++读取和写入具有动态数组成员的类
- c# - 以编程方式更改 xaml 的内容以更改控件的属性
- javascript - 如何访问 XMLHttpRequest 中的类方法?
- python - 对特定列中二级索引的最后一行中的每个一级索引求和
- android - 如何重新创建 Android 应用程序类?
- java - ATG baseline_update.sh 失败并出现错误:没有数据记录输出
- python - 避免重复将数据帧传递给递归函数?
- javascript - jQuery验证不适用于Webforms表单