python - 在 Keras 自定义损失函数中有效使用 SciPy 函数?
问题描述
我正在尝试通过使用自定义损失函数来提高 keras/tensorflow 递归神经网络 (RNN) 的质量。到目前为止,该模型的使用mean_squared_error
取得了一定的成功,但在我的数据中,每个样本时间序列的峰谷点比中间点更重要(我有多个同等重要的特征)。因此,LMS 方法的同等权重是不够的。目的是引入自定义损失函数,其中峰值和谷点在损失计算中被赋予更高的权重,更突出的点具有更高的权重。
我的数据被限制np.ndarray
为 LSTM 的形状(nSamples,nTimesteps,nFeatures),具有 1000-10000 个样本,60-600 个时间步,具体取决于具有 200 个输出特征的模型(正在训练模型以预测200 个信号)。
型号总结:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 61, 20)] 0
_________________________________________________________________
bidirectional (Bidirectional (None, 61, 16) 4736
_________________________________________________________________
bidirectional_1 (Bidirection (None, 61, 64) 41472
_________________________________________________________________
bidirectional_2 (Bidirection (None, 61, 128) 197632
_________________________________________________________________
time_distributed (TimeDistri (None, 61, 128) 16512
_________________________________________________________________
time_distributed_1 (TimeDist (None, 61, 32) 4128
_________________________________________________________________
time_distributed_2 (TimeDist (None, 61, 217) 7161
=================================================================
Total params: 271,641
Trainable params: 271,641
Non-trainable params: 0
在没有样本权重的情况下进行训练可以model.fit(...)
正常工作,但该sample_weight=weights
参数不支持 3 维数组。
weight_by_prominence
使用scipy.signal.find_peaks
和计算权重的自定义损失调用scipy.signal.peak_prominences
,因为我不知道任何 tf 等价物。我发现成功实现这一点的唯一方法是启用急切执行model.compile(optimizer=optimizer, loss=mse_with_prominence, run_eagerly=True)
,这会导致训练时间大幅增加且不可接受(> 20 倍)。
我正在寻找一种保留 tf 的 Graph 功能优势的解决方案。
我在 Python 3.6.6 上使用 Tensorflow v2.0.1。
下面是代码:函数定义...
def weight_by_prominence(all_signals, max_weight=1, base_weight=0):
assert max_weight>base_weight, f'max_weight {max_weight} must be greater than base_weight {base_weight}'
datapoint_weights = np.ones_like(all_signals)
all_prominences = np.zeros_like(all_signals)
for b, batchdata in enumerate(all_signals):
for i, timehist in enumerate(batchdata.T):
# Locations & prominences of peaks
peak_idx, _ = signal.find_peaks(timehist)
peak_prom = signal.peak_prominences(timehist,peak_idx)[0]
all_prominences[b,peak_idx,i] = peak_prom
# Locations & prominences of valleys (repeat on inverted signal)
valley_idx, _ = signal.find_peaks(-1*timehist)
valley_prom = signal.peak_prominences(-1*timehist,valley_idx)[0]
all_prominences[b,valley_idx,iq] = valley_prom
# Normalise prominences [0,1] and create weights in range [base_weight,max_weight]
datapoint_weights = base_weight + ( all_prominences/np.amax(all_prominences) * (max_weight-base_weight) )
return datapoint_weights
def mse_with_prominence(y_true,y_estimate):
tf.config.experimental_run_functions_eagerly(True)
# Set factor for maximum prominence
k = 5
# Calculate the difference
y_diff = tf.math.squared_difference(y_true,y_estimate)
# Weight the difference
weights = weight_by_prominence(y_diff,k,1)
y_diff = tf.math.multiply(y_diff, tf.convert_to_tensor(weights))
# Take the mean
y_mean = tf.math.reduce_mean(y_diff, axis=1)
return y_mean
...以及它的名称...
model = Model(input_layer, outputs, name='model')
optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss=mse_with_prominence)
大多数操作都可以y_true
使用 tf 方法在输入张量上完成,除了 scipyfind_peaks
和peak_prominence
方法。当尝试在不急切执行的情况下进行训练时,其他几个操作enumerate()
——numpy's.T
和 tf's——.numpy()
会产生不同的错误。
我希望有人可以提出一个解决方案,允许使用 scipy 方法而无需全面使用急切执行?
我知道解决方案的几个潜在相关途径,但不确定它们的适当性。
tf.config.experimental.run_functions_eagerly()
本地使用?- 使用
tf.Function
但保留现有代码结构? - 使用
ft.constant
在拟合之前计算的权重,但是如何将这些传递给损失函数?
由于权重是静态值,我喜欢#3 的想法,但不确定如何实现。在没有解决方案的情况下,我的“hack”解决方案是在模型拟合之前使用 scipy.signal 评估权重,将其保存(写入磁盘?),然后让损失函数在报告损失时读取权重。(需要明确的是,我认为这是低效、不优雅、丑陋的,并不是真正的解决方案。)
解决方案
推荐阅读
- jquery - Jquery SerializeArray 和推送对象列表
- python - 网页抓取
- 标签
- python - 如何更改 pip 安装到错误的 python 副本的事实?
- html - 如何在 jquery LightSlidder 中使用居中图像设置固定高度?
- python - 如何从 findHomography 中获取旋转角度?
- mysql - 在启动时创建数据库视图 - JPA
- pdf - Craft Pro 3.1.33 - 在管理员用户配置文件中下载 PDF
- php - 为什么php函数exec()返回状态码2
- here-api - HERE geocode API 拉丁化地址
- javascript - Express 验证器不适用于 post 值?