python - 将带有自定义度量函数的 keras 配置保存到 JSON
问题描述
我正在尝试为 Keras 模型保存我的配置。我希望能够从文件中读取配置,以便能够重现培训。
在函数中实现自定义指标之前,我可以按照下面显示的方式进行操作,而无需mean_pred
. 现在我遇到了问题TypeError: Object of type 'function' is not JSON serializable
。
在这里我读到可以将函数名作为字符串获取custom_metric_name = mean_pred.__name__
。我不仅希望能够保存名称,而且如果可能的话,还希望能够保存对函数的引用。
也许我应该像这里提到的那样考虑不仅将我的配置存储在 .py 文件中,而且使用ConfigObj
. 除非这能解决我当前的问题,否则我稍后会实施。
问题的最小工作示例:
import keras.backend as K
import json
def mean_pred(y_true, y_pred):
return K.mean(y_pred)
config = {'epochs':500,
'loss':{'class':'categorical_crossentropy'},
'optimizer':'Adam',
'metrics':{'class':['accuracy', mean_pred]}
}
# Do the training etc...
config_filename = 'config.txt'
with open(config_filename, 'w') as f:
f.write(json.dumps(config))
非常感谢有关此问题的帮助以及以最佳方式保存我的配置的其他方法。
解决方案
为了解决我的问题,我将函数的名称保存为配置文件中的字符串,然后从字典中提取函数以将其用作模型中的指标。还可以使用:'class':['accuracy', mean_pred.__name__]
将函数的名称保存为配置中的字符串。这也适用于多个自定义函数和更多指标键(例如,在进行回归和分类时为“reg”定义指标,如“class”)。
import keras.backend as K
import json
from collections import defaultdict
def mean_pred(y_true, y_pred):
return K.mean(y_pred)
config = {'epochs':500,
'loss':{'class':'categorical_crossentropy'},
'optimizer':'Adam',
'metrics':{'class':['accuracy', 'mean_pred']}
}
custom_metrics= {'mean_pred':mean_pred}
metrics = defaultdict(list)
for metric_type, metric_functions in config['metrics'].items():
for function in metric_functions:
if function in custom_metrics.keys():
metrics[metric_type].append(custom_metrics[function])
else:
metrics[metric_type].append(function)
# Do the training, use metrics
config_filename = 'config.txt'
with open(config_filename, 'w') as f:
f.write(json.dumps(config))
推荐阅读
- python - 如何将 NaN(在 pandas df Float 列中)显示为空单元格?
- php - 我无法汇总个别学生的数据
- algorithm - 6个nxn阶矩阵相乘的时间复杂度是多少?
- python - 为什么我需要在调用函数后分配这个变量?
- c# - .NET Core 2.2 Shared Cookie 在登录时导致 Bad Request 错误
- swift - 错误:尝试转换时无法将“Ninjumper.GameScene”类型的值转换为“SKSpriteNode”
- c# - 如何管理与用户无关信息的 Web 服务器场中的状态
- java - 这个 ObjectWeb 错误是什么意思:“不支持的类文件主要版本 56”,我该如何解决?
- django - 关于基本用户模型 Django 的查询集问题
- jdbc - 与 Kerberos 的 JDBC 连接启用了 Apache phoenix