python - Softmax 不会导致 Python 实现中的概率分布
问题描述
我有一个简单的 softmax 实现:
softmax = np.exp(x) / np.sum(np.exp(x), axis=0)
对于 x 在此处设置为数组:https ://justpaste.it/6wis7
您可以将其加载为:
import numpy as np
x = np.as (just copy and paste the content (starting from array))
我得到:
softmax.mean(axis=0).shape
(100,) # now all elements must be 1.0 here, since its a probability
softmax.mean(axis=0) # all elements are not 1
array([0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158,
0.05263158, 0.05263158, 0.05263158, 0.05263158, 0.05263158])
为什么这个实现是错误的?如何解决?
解决方案
在我看来很好:
import numpy as np
def softmax(x):
return np.exp(x) / np.sum(np.exp(x), axis=0)
logits = softmax(np.random.rand(4))
print(logits)
softmax 动作的所有元素的总和应该等于 1。
对于分类任务,通常采用具有最高值np.argmax()
(
class_index = np.argmax(logits) # Assuming logits is the output of a trained model
print('Most likely class: %d' % class_index)
正如 JosepJoestar 在评论中指出的那样,softmax 函数的定义可以在这里找到。
推荐阅读
- scala - Sbt 包装脚本
- webstorm - WebStorm - 在每次提交之前删除 console.log
- python - python中是否有一个函数可以将散点图拟合到sin函数?
- javascript - 尝试将 Magicmouse.js 添加到自定义 wordpress 主题 - 无法加载
- php - 在 php Cpanel 中管道恢复电子邮件
- python - 在特定时间间隔内删除行
- c++ - 如何将小数规范化为范围内的值
- c# - 如何测试动态 JSON 文件中是否存在属性
- postgresql - plpgsql循环使用多行
- html - 共享具有元 Open Graph 标签的音频 URL 后,iMessage 中未显示音频播放按钮