python - tensorflow / keras中批量大小的自定义损失w权重数组
问题描述
我正在创建一个自定义损失函数,它是一个 MAE( y_true , y_pred ),由两个数组a和b加权,其中所有四个数组的大小相同(10000 个样本/时间步长)。
def custom_loss(y_true, y_pred, a, b):
mae = K.abs(y_true - y_pred)
loss = mae * a * b
return loss
问题:如何将a和b输入到函数中?两者都应该像 y_true 和 y_pred 一样被拆分和洗牌。
到目前为止,我正在使用一个 LSTM 训练的数据 X 形状(样本 x 时间步长 x 变量)。在这里,我尝试了 tf 的add_loss函数来完成这项工作,当将a和b作为进一步的输入层传递时,由于数据形状不同而导致错误。
#LSTM
input_layer = Input(shape=input_shape)
in = LSTM(20, activation='relu', return_sequences=True)(input_layer)
out = LSTM(1, activation='linear', return_sequences=False)(in)
layer_a = Input(shape=(10000))
layer_b = Input(shape=(10000))
model = Model(inputs = [input_layer, layer_a, layer_b], outputs = out)
model.add_loss(custom_loss(input_layer, out, layer_a, layer_b))
model.compile(loss=None, optimizer=Adam(0.01))
# X=data of shape 20 variables x 10000 timesteps, y, a, b = data of shape 10000 timesteps
model.fit(x=[X, a, b], y=y, batch_size=1, shuffle=True)
我该如何正确地做到这一点?
解决方案
如果您只需要计算损失函数,那么我会在您的自定义损失函数周围编写一个包装器,并传递一个元组a
作为您的标签。b
(y,a,b)
像这样的东西:
n_sample = 100
timesteps = 30
features = 5
X = np.random.uniform(0,1, (n_sample,timesteps,features))
y = np.random.uniform(0,1, n_sample)
a = np.random.uniform(0,1, n_sample)
b = np.random.uniform(0,1, n_sample)
def custom_loss_wrapper(y_true, y_pred):
def custom_loss(y_true, y_pred, a, b):
mae = K.abs(y_true - y_pred)
loss = mae * a * b
return loss
return custom_loss(y_true[0], y_pred, y_true[1], y_true[2])
input_layer = Input(shape=(timesteps, features))
x = LSTM(20, activation='relu', return_sequences=True)(input_layer)
out = LSTM(1, activation='linear')(x)
model = Model(inputs =input_layer, outputs = out)
model.compile(loss=custom_loss_wrapper, optimizer=Adam(0.01))
model.fit(x=X, y=(y,a,b), shuffle=True, epochs=3)
它简化了网络架构并消除了不必要layer_a
的layer_b
推理时间。
推荐阅读
- node.js - 当我尝试从 HomeScreen 打开产品时出现以下错误 - 将循环结构转换为 JSON -->
- excel - 如何循环遍历 Excel 行
- symfony - 以 symfony 形式预设集合,只保存不清空
- python - 如何将列表转换为字典并为键添加随机值
- c# - 无法从程序集“Microsoft.AspNetCore.Mvc.Formatters.Json”加载类型“Microsoft.AspNetCore.Mvc.MvcJsonOptions”,
- npm - 如何在私有 npm 注册表上托管 TS 声明文件包
- database - 为什么函数 weekOfYear(2019.12.31) 返回 1?
- javascript - 增加日期并在日期字段中指定为最大值
- javascript - onClick 不适用于自定义创建的元素
- javascript - 如何在没有表单的情况下使我的重置按钮在 HTML 中工作?