python - 在 keras 中加载具有自定义损失的模型(缺少成员)
问题描述
我是 Keras 的新手,并检查了许多与负载模型相关的问题,但没有一个问题(例如eg1 eg2 } 让我解决了我的问题。
很抱歉这篇文章很长,但我想提供尽可能多的数据来帮助您重现错误
我在 google colab 中运行代码
我有一个具有以下自定义损失函数的模型:
def wasserstein_loss(y_true, y_pred):
return K.mean(y_true * y_pred)
def gradient_penalty_loss(y_true, y_pred, averaged_samples, gradient_penalty_weight):
gradients = K.gradients(y_pred, averaged_samples)[0]
gradients_sqr = K.square(gradients)
gradients_sqr_sum = K.sum(gradients_sqr,
axis=np.arange(1, len(gradients_sqr.shape)))
gradient_l2_norm = K.sqrt(gradients_sqr_sum)
gradient_penalty = gradient_penalty_weight * K.square(1 -
gradient_l2_norm)
return K.mean(gradient_penalty)
partial_gp_loss = partial(gradient_penalty_loss,
averaged_samples=averaged_samples,
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
partial_gp_loss.__name__ = 'gradient_penalty' # Functions need names or Keras will throw an error
使用损失函数:
discriminator_model = Model(inputs=[real_samples, generator_input_for_discriminator],
outputs=[discriminator_output_from_real_samples,discriminator_output_from_generator,averaged_samples_out])
discriminator_model.compile(optimizer=Adam(0.0001, beta_1=0.5, beta_2=0.9),
loss=[wasserstein_loss,
wasserstein_loss,
partial_gp_loss])
我保存到模型的方式:
discriminator_model.save('D_' + str(epoch) + '.h5')
generator_model.save('G_' + str(epoch) + '.h5')
我加载模型的方式:
generator_model = models.load_model(Gh5files[-1],custom_objects={'wasserstein_loss': wasserstein_loss})
discriminator_model = models.load_model(Dh5files[-1],custom_objects={'wasserstein_loss': wasserstein_loss ,
'RandomWeightedAverage': RandomWeightedAverage ,
'gradient_penalty':partial_gp_loss(gradient_penalty_loss,
averaged_samples=averaged_samples,
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
})
不,当我尝试上传保存的模型时,我收到以下错误
Loading pretrained models
about to load follwoing files: ./G_31.h5 ./D_31.h5
/usr/local/lib/python3.6/dist-packages/keras/engine/saving.py:327: UserWarning: Error in loading the saved optimizer state. As a result, your model is starting with a freshly initialized optimizer.
warnings.warn('Error in loading the saved optimizer '
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-30-5ed3e08a8fce> in <module>()
12 'gradient_penalty':partial_gp_loss(gradient_penalty_loss,
13 averaged_samples=averaged_samples,
---> 14 gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
15 })
16
TypeError: gradient_penalty_loss() missing 1 required positional argument: 'y_pred'
我错过了什么,我该如何介绍 y_pred ?
解决方案
Keras 自定义损失函数的格式必须为my_loss_function(y_true, y_pred)
. 您的gradient_penalty_loss
函数无效,因为它有附加参数。
正确的方法如下:
def get_gradient_penalty_loss(averaged_samples, gradient_penalty_weight):
def gradient_penalty_loss(y_true, y_pred):
gradients = K.gradients(y_pred, averaged_samples)[0]
gradients_sqr = K.square(gradients)
gradients_sqr_sum = K.sum(gradients_sqr,
axis=np.arange(1, len(gradients_sqr.shape)))
gradient_l2_norm = K.sqrt(gradients_sqr_sum)
gradient_penalty = gradient_penalty_weight * K.square(1 -
gradient_l2_norm)
return K.mean(gradient_penalty)
return gradient_penalty_loss
gradient_penalty_loss= get_gradient_penalty_loss(
gradient_penalty_loss,
averaged_samples=averaged_samples,
gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
然后通过models.load_model(..., custom_objects={'gradient_penalty_loss':gradient_penalty_loss})
看起来您可能正在尝试使用该partial
函数执行类似的操作,但是由于您尚未定义它,所以我不知道是否是这种情况。
无论哪种方式,还有一个问题是您正在调用partial_gp_loss = partial(...)
which returns gradient_penalty_loss
。然后,当你加载你调用的模型时partial_gp_loss(...)
,但此时你应该调用任何东西,你应该只是传递函数!
您收到错误TypeError: gradient_penalty_loss() missing 1 required positional argument: 'y_pred'
是因为此时您正在尝试执行gradient_penalty_loss
并且您将其两个命名参数传递给它(averaged_samples
和gradient_penalty_weight
),除了传递一个位置参数(转到y_true
)并寻找第二个位置参数,y_pred
它不见了。
推荐阅读
- python - Python 通过使用 pd.groupby 查找学生评估的数量并插入每个评估的分数
- transactions - 带有查询生成器的 TypeORM 事务
- xamarin.forms - 找不到 Xamarin.forms 中的 CarouselView
- iphone - 苹果手机连接电脑,电脑无法识别设备
- java - 使用 Java 访问 Windows 中的环境变量失败
- javascript - JavaScript 中的 do/while 混淆
- javascript - Next.js API 将 FormData 转发到外部 API
- swift - Swift中大括号后跟类名的语法含义
- sql - 有没有办法在不使用子查询的情况下根据不同的行计算平均值?
- algorithm - 如何通过替换循环来最小化图的顶点?