tensorflow - 如何创建自定义损失函数,将 RNN 的中间训练输出(张量 y_pred)馈送到另一个预定义的 RNN?
问题描述
我希望创建一个自定义损失函数,它不直接使用 RNN(y_pred)的中间输出,而是将 y_pred 作为输入提供给另一个 RNN(比如 RNN2,它已经被定义和训练),并使用这些预测值作为损失函数的参数。
我尝试从 model.compile 函数调用我的自定义损失函数,这会产生错误。是因为我无法将张量数据类型的对象输入 RNN2 吗?假设 y_pred 具有训练的中间输出,我错了吗?使用 sess 的 y_pred 的简单打印命令也会引发错误!IE
sess=tf.Session()
print(sess.eval(y_pred))
那么问题是 y_pred 的根本问题吗?
无论如何,这是代码:
def custom_loss(y_true, y_pred):
predicted=rnn2.predict(y_pred)
return K.mean(K.abs( predicted-y_true), axis=-1)
input_tensor = Input(shape=(1,1))
hidden = LSTM(100, activation='softmax',return_sequences=False)(input_tensor)
out = Dense(1, activation='softmax')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss, optimizer='adam')
错误
You must feed a value for placeholder tensor 'input_19' with dtype float and shape [?,1,1]
[[{{node input_19}}]]
解决方案
这可能是你打电话时要做的事情model.compile
。您是否尝试将它(RNN2)作为新层传递。
RNN2.trainable = False # [1]
model = Sequential()
model.add(RNN1)
model.add(RNN2)
def custom_loss(y_true, y_pred):
# predicted=rnn2.predict(y_pred)
return K.mean(K.abs( y_pred-y_true), axis=-1)
编辑
您能否详细说明我如何将模型(RNN2)用作 RNN1 中的层?
如果是我,我会做这样的事情。
from keras import models, layers
inp = layers.Input((None, 1))
x = layers.LSTM(512, return_sequences=True)(inp)
x = layers.LSTM(256)(x)
x = layers.Dense(32, name='rnn2_output')(x)
rnn2 = models.Model(inp, x)
rnn2.trainable = False # [2]
inp2 = layers.Input((None, 32))
x = layers.LSTM(256, return_sequences=True)(inp2)
x = layers.Dense(1)(x)
x = rnn2(x)
rnn1 = models.Model(inp2, x)
rnn1.summary()
请注意,最近添加的代码 [2] 和旧的代码 [1](最近编辑)都有trainable = False
,这意味着这个模型根本不会被训练。假设您RNN2.predict
输入了损失函数。如果您还想训练它,请删除这些行。
推荐阅读
- assembly - 如何让 BSR 指令在 64 位上工作?
- unix - 使用 awk 比较来自两个不同文件的两个字段
- javascript - 具有多个过滤器的 Javascript 过滤器数组
- vbscript - 使用 vbscript 仅删除文本文件的第一行
- c# - 关闭返回值集中的计数
- javascript - 用正则表达式(正则表达式)条件匹配javascript的变量替换字符串的一部分
- reactjs - 无法从 react js 访问 httponly cookie,但可以在邮递员应用程序中访问!这怎么可能?
- if-statement - 如果用户使用 if/else 在我的代码中写出“是”“否”以外的其他词,我想再次提问
- python - 为什么 md5 算法在哈希时接受 numpy 的 int64 而不是整数?
- c# - 在三元条件语句中更改动态的类型