python - 不使用 keras 后端库的自定义损失函数
问题描述
我正在将 ML 模型应用于实验设置以优化驱动信号。驱动信号本身是被优化的东西,但它的质量是间接评估的(它被应用于实验装置以产生不同的信号)。
我能够通过 python 中的函数从实验中运行和收集数据。
我想建立一个带有自定义损失函数的 ML 模型,该模型使用优化的信号调用实验驱动函数,以获取用于反向传播的错误。
我已经研究过使用 keras,但是必须使用 keras 后端函数的限制意味着我不能在函数中调用我的驱动程序函数。
我想知道如果我在没有 keras 前端的情况下使用张量流,是否有办法做我想做的事情,以及不同的 ML API 是否允许这样做?
谢谢。
解决方案
如果我理解了这个问题,您希望能够根据模型评估损失函数时运行的代码生成损失。
这将是一个例子:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
FACTORS = np.array([[0.5, 2.0, 4.0]])
def ext_function(inputs):
""" This can be an arbitrary python function of the inputs
inputs is a tf.EagerTensor which can be converted into a numpy array.
"""
r = np.dot(inputs, FACTORS.T)
return r
class LossFunction(object):
def __init__(self, model):
# Use model to obtain the inputs
self.model = model
def __call__(self, y_true, y_pred, sample_weight=None):
""" ignore y_true value from fit params and compute it instead using
ext_function
"""
y_true = tf.py_function(ext_function, [self.model.inputs[0]], Tout=tf.float32)
v = keras.losses.mean_squared_error(y_true, y_pred)
return K.mean(v)
def make_model():
inp = Input(shape=(3,))
out = Dense(1, use_bias=False)(inp)
model = Model(inp, out)
model.compile('adam', LossFunction(model))
return model
model = make_model()
model.summary()
测试:
import numpy as np
N_SAMPLES=100
X = np.random.rand(N_SAMPLES, 3)
Y_dummy = np.random.rand(N_SAMPLES)
history = model.fit(X, Y_dummy, epochs=1000, verbose=False)
print(history.history['loss'][-1])
它实际上做了一些事情:
model.layers[1].get_weights()
请注意,简单地生成正确的 Y 值作为输入会简单得多。我不确切知道你的问题的条件。但如果可能的话,尝试预先生成 Y。而不是使用上面的示例。
我已经使用上面的技巧来创建由类加权的自定义指标。即在其中一个输入参数是一个类并且所需的损失函数是每个类的加权平均损失的情况下。
推荐阅读
- javascript - 可以使用 ember-browserify 导入 bpmn-js、bpmn-js-properties-panel 模块
- java - 在 shepfile .shp 中提取几何图形
- ios - UNNotificationAction 按钮未显示,为什么?
- python - 继续用 keras 中保存的模型训练 CNN
- esp8266 - ESP8266 连接到服务器后立即断开连接
- c# - 当您在 xamarin android 中附加了水平布局管理器时,如何获取当前的 recyclerview 项目位置?
- django - ReactJS & Django:如何以正确的方式使用 axios 发送 csrf 令牌?
- c - Wireshark中服务器没有响应,但收到响应,解压失败错误
- php - PHP如何在重新加载后增加变量
- android - 将 lat 和 lng 存储在 arraylist 中