keras - 有没有办法只保存 Keras 的 softmax 预测中的最大概率?
问题描述
我正在使用tensorflow和keras构建神经网络。问题是我有 40k 类别和 1M 条目用于分类问题,当我尝试使用 softmax 预测每个类别的所有概率时,出现内存错误(我认为是因为它无法保存这么大的数组1M x 40k)。
但是,对我来说,例如知道我的数据集中每个主题的三个最可能的类别就足够了。这可以大大减少数组的大小(1M x 3)。那么有没有什么方法可以只得到 Keras 预测中最有可能的三个类别呢?或者我强制必须保存每个类别的所有概率?
谢谢你们!
解决方案
您应该尝试向您的预测函数添加自定义回调。下面的代码从每个批次(从自动创建的日志中)获取预测,并允许您对它们执行操作。
除了按分数排序之外,您还需要首先获取前 N 个分数的索引,以便能够将单热编码转换回原始标签,但这超出了本问题的范围。如果您使用的是标签编码器,您可以:
- 在回调中执行
inverse_transform
以获取实际标签 - 然后将它们与乐谱一起压缩
- 然后按元组内的分数排序
- 然后提取前 N 个标签
一旦找到了用于排序和获取索引或标签的适当算法,您就可以将结果扩展到外部列表(extend
因为您希望将每个单独的批次添加为一堆单独的列表,而不仅仅是添加一个额外的元素每批次(如append
)),然后将其加入到您的原始测试集。
del logs
最后将清除内存中的所有预测,然后继续下一批。
top2_predictions=[]
class CustomCallback(keras.callbacks.Callback):
def on_predict_batch_end(self, batch, logs=None):
preds = logs['outputs']
<insert indices sorting code, maybe np.argpartition>
top2 = preds[:, :2]
top2_predictions.extend(top2)
del logs
model.predict(test_dataset, callbacks=[CustomCallback()])
推荐阅读
- visual-studio-code - Visual Studio Code:远程 SSH 连接到 Windows Server 2019
- bash - 正则表达式如何扩展?
- javascript - 在反应组件中添加脚本标签不起作用?
- reactjs - 我们是否可以在不使用 Datepicker 的情况下将材质 UI 选择器中的 Calender 组件用作独立组件 - 如果可以,如何?
- excel - xFiledlg As FileDialog 不工作有解决方法吗?
- jquery - 当我在popperjs上单击外部时如何关闭弹出窗口?
- c# - Blazor 服务器端组件中的@attribute [AllowAnonymous] 无效
- json - JSONpath通过过滤另一个值来获取一个值
- javascript - 页面重新加载时使用节点 js 动态更改数据库中的内容
- node.js - 在 nodejs 中的 AWS Lambda 函数之间共享模型的最佳实践是什么?