tensorflow - Keras 自定义损失函数
问题描述
我想实现以下自定义损失函数,参数x
作为最后一层的输出。到目前为止,我将这个功能实现为Lambda
层,再加上 kerasmae
损失,但我不再想要那个了
def GMM_UNC2(self, x):
tmp = self.create_mr(x) # get mr series
mr = k.sum(tmp, axis=1) # sum over time
tmp = k.square((1/self.T_i) * mr)
tmp = k.dot(tmp, k.transpose(self.T_i))
tmp = (1/(self.T * self.N)) * tmp
f = self.create_factor(x) # get factor
std = k.std(f)
mu = k.mean(f)
tmp = tmp + std/mu
def loss(y_true, y_pred=tmp):
return k.abs(y_true-y_pred)
return loss
self.y_true = np.zeros((1,1))
self.sdf_net = Model(inputs=[self.in_ma, self.in_mi, self.in_re, self.in_si], outputs=w)
self.sdf_net.compile(optimizer=self.optimizer, loss=self.GMM_UNC2(w))
self.sdf_net.fit([self.macro, self.micro, self.R, self.R_sign], self.y_true, epochs=epochs, verbose=1)
代码实际运行,但实际上并没有tmp
用作损失的输入(我将它乘以某个数字,但损失保持不变)
我究竟做错了什么?
解决方案
如果您想将GMM_UNC2
函数应用于预测,或者仅应用一次来构建损失,您的问题并不完全清楚。如果它是第一个选项,那么所有代码都应该在损失内并应用它y_pred
,比如
def GMM_UNC2(self):
def loss(y_true, y_pred):
tmp = self.create_mr(y_pred) # get mr series
mr = k.sum(tmp, axis=1) # sum over time
tmp = k.square((1/self.T_i) * mr)
tmp = k.dot(tmp, k.transpose(self.T_i))
tmp = (1/(self.T * self.N)) * tmp
f = self.create_factor(x) # get factor
std = k.std(f)
mu = k.mean(f)
tmp = tmp + std/mu
return k.abs(y_true-y_pred)
return loss
如果是第二种选择,一般来说,在 Python 函数定义中将对象作为默认值传递并不是一个好主意,因为它可以在函数定义中更改。此外,您假设 loss 的第二个参数有一个 name y_pred
,但是当被调用时,它是在没有 name 的情况下完成的,作为一个位置参数。总之,您可以尝试在损失内使用显式比较,例如
def loss(y_true, y_pred):
if y_pred is None:
y_pred = tmp
return k.abs(y_true - y_pred)
如果您喜欢忽略预测并强行使用tmp
,那么您可以忽略y_pred
损失的论点而仅使用tmp
,例如
def loss(y_true, _):
return k.abs(y_true - tmp)
推荐阅读
- dynamics-crm - Dynamics CRM 2016 本地 API 身份验证
- docker - Docker build 与主机共享数据
- multithreading - 如何使用 Qt 多线程进行并行列表处理?
- python - Python dateparser.parse 仅在使用 AWS EC2 时返回 NoneType 对象(运行 Python 3.7.9)
- javascript - 在 Chrome 中的 console.log 中显示整个 JSON 对象,而不使用 JSON.stringify()
- nlp - IBM Watson NLU:如何通过 API Endpoint 确定剩余积分?
- python - 如何订阅所有 SharePoint 网站中的文档更改
- java - 如何将方法从一层映射/更改/实现到另一层?
- json - 提取子元素并将父字段添加到其中
- javascript - 在我终止应用程序之前无法接收 Firebase 消息