python - 使用经过训练的模型进行预测时,如何测量每个节点的激活水平?
问题描述
我很感兴趣,当一些输入输入到训练模型时,哪些节点被激活以及激活的强度有多大。
下图显示了我想从模型中得到什么。(我想知道每个节点的激活程度)
据我所知,有一些技术可以可视化节点(或过滤器)正在关注的内容。(尤其是在 CNN 中)
有什么好方法可以衡量每个节点的活跃程度吗?
我通常使用 Keras。但是 pyTorch 也可以。
解决方案
您正在寻找激活图/grad cam ...,您可以查看:https ://keras.io/examples/vision/grad_cam/和https://keras.io/examples/vision/visualizing_what_convnets_learn/
您也可以尝试keract:https ://github.com/philipperemy/keract
还有其他关于 AI 领域可解释性主题的 github 存储库,例如:https ://github.com/XAI-ANITI/ethik
或者,您可以在 keras 和 pytorch 中自己完成,例如通过在 pytorch 中注册一个钩子:
def get_feat_vector(self, img, model):
with torch.no_grad():
my_output = None
def my_hook(module_, input_, output_):
nonlocal my_output
my_output = output_
a_hook = model.layers[0].register_forward_hook(my_hook)
model(img)
a_hook.remove()
return my_output
...
for element in val_dataloader:
model.eval()
feature_vect = get_feat_vector(element[0].float().cuda(), model)
推荐阅读
- javascript - React 测试:是否应该测试文本内容?
- highcharts - 适合长 Y 轴标签的 HighCharts
- python - 在 Pandas 数据框中创建重复值索引
- python - 从 json 获取价值 - API
- spyder - Spyder 可以在悬停时显示对象文档吗?
- android - 用户更改语言时如何更改fontFamily?
- r - 工作区无法加载到服务器中,文件具有幻数“RDX3”
- r - 将特定数值列除以对应于另一个因子列的最大观察值
- node.js - 在我的网站上注册 serviceWorker 失败
- javascript - 轮播右键单击按钮无法正常工作