python - 预处理准确度指标
问题描述
我有一个预测 5 个类别的模型。我想更改准确度指标,如下例所示:
def accuracy(y_pred,y_true):
#our pred tensor
y_pred = [ [0,0,0,0,1], [0,1,0,0,0], [0,0,0,1,0], [1,0,0,0,0], [0,0,1,0,0]]
# make some manipulations with tensor y_pred
# actons description :
for array in y_pred :
if array[3] == 1 :
array[3] = 0
array[0] = 1
if array[4] == 1 :
array[4] = 0
array[1] = 1
else :
continue
#this nice work with arrays but howe can i implement it with tensors ?
#after manipulations result->
y_pred = [ [0,1,0,0,0], [0,1,0,0,0], [1,0,0,0,0], [1,0,0,0,0],[0,0,1,0,0] ]
#the same ations i want to do with y_true
# and after it i want to run this preprocess tensors the same way as simple tf.keras.metrics.Accuracy metric
我认为 tf.where 可以帮助过滤张量,但不幸的是不能正确地做到这一点。
如何使用张量制作这个预处理精度指标?
解决方案
如果要将这些向左移动 3 个索引,可以执行以下操作:
import numpy as np
y_pred = [ [0,0,0,0,1], [0,1,0,0,0], [0,0,0,1,0], [1,0,0,0,0], [0,0,1,0,0]]
y_pred = np.array(y_pred)
print(y_pred)
shift = 3
one_pos = np.where(y_pred==1)[1] # indices where the y_pred is 1
# updating the new positions with 1
y_pred[range(y_pred.shape[1]),one_pos - shift] = np.ones((y_pred.shape[1],))
# making the old positions zero
y_pred[range(y_pred.shape[1]),one_pos] = np.zeros((y_pred.shape[1],))
print(y_pred)
[[0 0 0 0 1]
[0 1 0 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 0 1 0 0]]
[[0 1 0 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 0 1 0 0]
[0 0 0 0 1]]
更新:
如果您只想移动索引 3 和 4。
import numpy as np
y_pred = [ [0,0,0,0,1], [0,1,0,0,0], [0,0,0,1,0], [1,0,0,0,0], [0,0,1,0,0]]
y_pred = np.array(y_pred)
print(y_pred)
shift = 3
one_pos = np.where(y_pred==1)[1]# indices where the y_pred is 1
print(one_pos)
y_pred[range(y_pred.shape[1]),one_pos - shift] = [1 if (i == 3 or i == 4) else 0 for i in one_pos]
y_pred[range(y_pred.shape[1]),one_pos] = [0 if (i == 3 or i == 4) else 1 for i in one_pos]
print(y_pred)
[[0 0 0 0 1]
[0 1 0 0 0]
[0 0 0 1 0]
[1 0 0 0 0]
[0 0 1 0 0]]
[4 1 3 0 2]
[[0 1 0 0 0]
[0 1 0 0 0]
[1 0 0 0 0]
[1 0 0 0 0]
[0 0 1 0 0]]
推荐阅读
- sql - 批量插入数据库,跳过数据库中发现的重复项
- asp.net - 提交表单后,我在控制器中收到空值
- python - PyCharm 上的 virualenv 从 Linux 到 Windows
- c++ - 在 VS2019 解决方案中包含 MRPT 库时出现错误 C2039
- clojure - 如何根据 uri 在 reframe 中提供正确的面板?
- snowflake-cloud-data-platform - Snowflake 中的值对:变体还是对象?
- flutter - 在 Flutter 中合并流
- google-apps-script - 如何使用归档脚本将 Importrange 单元格转换为硬值,而无需将源设置为公共共享
- r - Shiny 中的可下载表格
- ruby-on-rails - Selenium 在 Heroku 上不起作用(Selenium::WebDriver::Error::SessionNotCreatedError)