machine-learning - 将 sample_weights 与 fit_generator() 一起使用
问题描述
在自回归连续问题中,当零点过多时,可以将这种情况视为零膨胀问题(即 ZIB)。换句话说,f(x)
我们想要拟合我们想要逼近g(x)*f(x)
的f(x)
函数,而不是拟合 ,即y
,并且g(x)
是一个输出0到1之间的值的函数,具体取决于值是零还是非零。
目前,我有两个模型。一个给我g(x)
的模型和另一个适合的模型g(x)*f(x)
。
第一个模型给了我一组权重。这是我需要你帮助的地方。我可以将sample_weights
参数与model.fit()
. 当我处理大量数据时,我需要使用model.fit_generator()
. 不过,fit_generator()
没有论据sample_weights
。
有什么办法可以在sample_weights
里面工作fit_generator()
吗?g(x)*f(x)
否则,知道我已经有一个训练有素的模型,我怎么能适应g(x)
?
解决方案
您可以提供样本权重作为生成器返回的元组的第三个元素。来自 Keras 文档fit_generator
:
生成器:生成器或
Sequence
(keras.utils.Sequence
) 对象的实例,以避免在使用多处理时出现重复数据。生成器的输出必须是
- 一个元组
(inputs, targets)
- 一个元组
(inputs, targets, sample_weights)
。
更新:这是一个生成器的粗略草图,它返回输入样本和目标以及从模型获得的样本权重g(x)
:
def gen(args):
while True:
for i in range(num_batches):
# get the i-th batch data
inputs = ...
targets = ...
# get the sample weights
weights = g.predict(inputs)
yield inputs, targets, weights
model.fit_generator(gen(args), steps_per_epoch=num_batches, ...)
推荐阅读
- highcharts - 如何使用 json 数组对象显示向下钻取列高图 - 3 级向下钻取列高图
- compilation - 如何使用 afl-gcc 编译 openssl
- docker - 502 Bad Gateway nginx/1.13.12 on localhost 同时在同上创建新策略
- django - 购物车 ID 是否应该与用户 ID 绑定?
- wordpress - 如何为 acf_register_block() 实现 JS 回调
- ubuntu - 如何通过 Xdebug - Magento 2 在 PhpStorm 中调试完整项目
- javascript - 为什么将 .map() 的结果插入 JSX 会导致错误?
- python - 将 Json 放入 DynamoDB 项目
- outlook-web-addins - Outlook Web 加载项 > 控制菜单分隔符和无图标条目
- javascript - 如何在一定时间后删除不和谐机器人的角色?(Javascript)