deep-learning - 关于 Softmax 函数作为预测中的输出层
问题描述
我知道softmax激活函数:输出层与softmax激活的总和总是等于1,也就是说:输出向量是标准化的,这也是必要的,因为最大累积概率不能超过1。好的,这很清楚。
但是我的问题是:当softmax用作分类器时,是使用argmax函数来获取类的索引。那么,如果重要参数是获得正确类别的指标,那么获得一个或更高的累积概率有什么区别?
python中的一个示例,我在其中制作了另一个softmax(实际上不是softmax函数),但分类器的工作方式与具有真正softmax函数的分类器的工作方式相同:
import numpy as np
classes = 10
classes_list = ['dog', 'cat', 'monkey', 'butterfly', 'donkey',
'horse', 'human', 'car', 'table', 'bottle']
# This simulates and NN with her weights and the previous
# layer with a ReLU activation
a = np.random.normal(0, 0.5, (classes,512)) # Output from previous layer
w = np.random.normal(0, 0.5, (512,1)) # weights
b = np.random.normal(0, 0.5, (classes,1)) # bias
# correct solution:
def softmax(a, w, b):
a = np.maximum(a, 0) # ReLU simulation
x = np.matmul(a, w) + b
e_x = np.exp(x - np.max(x))
return e_x / e_x.sum(axis=0), np.argsort(e_x.flatten())[::-1]
# approx solution (probability is upper than one):
def softmax_app(a, w, b):
a = np.maximum(a, 0) # ReLU simulation
w_exp = np.exp(w)
coef = np.sum(w_exp)
matmul = np.exp(np.matmul(a,w) + b)
res = matmul / coef
return res, np.argsort(res.flatten())[::-1]
teor = softmax(a, w, b)
approx = softmax_app(a, w, b)
class_teor = classes_list[teor[-1][0]]
class_approx = classes_list[approx[-1][0]]
print(np.array_equal(teor[-1], approx[-1]))
print(class_teor == class_approx)
两种方法之间获得的类总是相同的(我说的是预测,而不是训练)。我问这个是因为我在 FPGA 设备中实现 softmax 并且使用第二种方法不需要 2 次运行来计算 softmax 函数:首先找到指数矩阵及其和,然后执行除法。
解决方案
让我们回顾一下 的用途softmax
:
你应该使用
softmax
if:- 您正在训练NN,并希望在训练期间限制输出值的范围(您可以改用其他激活函数)。这可以稍微有助于剪裁渐变。
- 您正在对 NN 执行推理,并且希望获得关于分类结果“置信度”的度量(范围为 0-1)。
- 您正在对 NN 执行推理并希望得到
top K
结果。在这种情况下,建议使用“置信度”指标来比较它们。 - 您正在对几种 NN(集成方法)进行推理,并希望将它们平均化(否则它们的结果将不容易比较)。
softmax
如果出现以下情况,则不应使用(或删除):- 您正在对 NN 执行推理,并且您只关心顶级类。请注意,NN 可以使用 Softmax 进行训练(以获得更好的准确性、更快的收敛等)。
在您的情况下,您的见解是正确的:Softmax
如果您的问题只需要您在推理阶段获得最大值的索引,那么最后一层的激活函数是没有意义的。此外,由于您的目标是 FPGA 实现,这只会让您更加头疼。
推荐阅读
- java - 查找 Checkbox id android 时出错
- php - Windows 10 中的 cron 权限和 xampp 设置
- css - 为什么我的 CSS 链接需要 ?vh=# 才能在我的网站上更新?
- mysql - SQL Join 表 A 上的重复值
- python - 如何在pdfplumber中打开多个文件?
- python - 需要帮助创建复利计算器
- c++ - 无法使用 Visual Studio 2019 中的模块链接功能时间
- python-3.x - 如何将十六进制值 81869400 转换为十进制值 948681000
- python - 从 NetCDF 中的多个纬度中心查找半径内的值
- javascript - Redux Store:根据下拉菜单将值输入到存储中