python - 如何在 softmax 分数中添加阈值
问题描述
在进行多分类时,通常我得到一个 softmax 分数和下面的预测,
softmax_scores = tf.nn.softmax(logits=self.scores, dim=-1)
prediction=tf.argmax(self.scores, 1, name="predictions")
如果我得到的 softmax_socres 是 .The[0.5,0.2,0.3]
预测是[0]
. 现在我想为0.6
softmax_socres 添加阈值。这意味着这里预期的预测是[4]
其他的。我做了如下
threshold=0.6
self.predictions = tf.argmax(self.scores, 1, name="predictions")
x = tf.constant([num_classes], shape=self.predictions.shape, dtype=tf.int64)
self.predictions1 =tf.where(tf.reduce_max(tf.nn.softmax(logits=self.scores, dim=-1),1)>=threshold,self.predictions,x)
并得到:
File "E:\ai\wide-and-shallow cnn\text_cnn.py", line 102, in __init__
x = tf.constant([num_classes], shape=self.predictions.shape, dtype=tf.int64)
File "E:\Python\Python36\lib\site-packages\tensorflow\python\framework\constant_op.py", line 214, in constant
value, dtype=dtype, shape=shape, verify_shape=verify_shape))
File "E:\Python\Python36\lib\site-packages\tensorflow\python\framework\tensor_util.py", line 430, in make_tensor_proto
if shape is not None and np.prod(shape, dtype=np.int64) == 0:
File "E:\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 2566, in prod
out=out, **kwargs)
File "E:\Python\Python36\lib\site-packages\numpy\core\_methods.py", line 35, in _prod
return umr_prod(a, axis, dtype, out, keepdims)
TypeError: __int__ returned non-int (type NoneType)
它在这个演示中工作。
import tensorflow as tf
import numpy as np
a=tf.constant(np.arange(6),shape=(3,2))
b=tf.reduce_max(a,1)
#c=tf.to_int32(a>3)
c=tf.argmax(a,1)
d=b>=3
f=tf.constant([5],shape=c.shape,dtype=tf.int64)
e=tf.where(d,c,f)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(a.eval(),b.eval(),c.eval(),d.eval(),f.eval(),e.eval())
解决方案
这样做怎么样,使用tf.where
threshold = 0.6
softmax_scores = tf.nn.softmax(logits=self.scores, dim=-1)
other_class_idx = tf.cast(tf.shape(softmax_scores)[0] + 1, tf.int64)
other_class_idx = tf.tile( \
tf.expand_dims(other_class_idx, 0), \
[tf.shape(softmax_scores)[0]] \
)
is_other = tf.reduce_max(tf.cast(softmax_scores > threshold, tf.int8), axis=1)
predictions = tf.where( \
is_other>0, \
tf.argmax(softmax_scores, 1), \
other_class_idx \
) # 4
推荐阅读
- c++ - 如果我在 C++ Targeting Linux 中开发来自 Microsoft 的应用程序,Miscrosft 库好吗?
- excel - 如果满足条件,则将文本从 Excel 单元格复制到单词
- arrays - 去掉图片、js、css url部分
- cassandra - 我应该如何选择大小分层压缩策略的参数?
- docker - 后门反向 shell 在 Docker 用户定义的网桥中不起作用
- html - 如何使用 CSS“nth-child”选择器选择自定义行数?
- python - 如何在不输入手机(或机器人令牌)的情况下使用 Telethon 连接到 Telegram?
- javascript - 仅允许在剃刀文本框中输入印地语数字
- python - 我可以直接使用 InMemoryUploadedFile 将图像上传到 Imgur 吗?
- elasticsearch - ELASTICSEARCH - 两个查询合二为一,输出继承