python - 当我在 Keras 中加载保存的模型时,输入会丢失,其中包含需要多个输入的自定义损失
问题描述
tf。版本 “1.12.0”
我有一个需要多个输入的自定义损失函数。它工作正常,除非我尝试保存和加载模型。这是一个简单的例子,展示了输入是如何丢失的。请看一看。
x = tf.keras.Input(shape=(5,), name='input')
y_true = tf.keras.Input(shape=(5,), name='y_true' )
y_pred = tf.keras.layers.Dense(5)(x)
other_data = tf.keras.Input(shape=(5,), name='other_data' )
model = tf.keras.Model(inputs=[x, y_true, other_data], outputs=y_pred)
def custom_loss(y_true, y_pred):
return tf.reduce_sum(tf.pow(y_true -y_pred,2)) + tf.reduce_sum(tf.multiply(y_pred,other_data))
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01, beta_2=0.99), loss=custom_loss)
data = np.random.rand(5,5)
model.fit([data, data, data], data)
model.save('tmp.h5')
print(model.input_names)
model1 = tf.keras.models.load_model('tmp.h5', custom_objects={'custom_loss':custom_loss})
print(model1.input_names)
model1.fit([data, data, data], data)
纪元 1/1 5/5 [===============================] - 0s 42ms/步 - 损失:9.4392
['input', 'y_true', 'other_data'] <---------------这很好
['input'] <----------- 这里发生了什么?
回溯(最近一次通话最后):
文件“”,第 21 行,在 model1.fit([data, data, data], data)
文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 1536 行,适合 validation_split=validation_split)
文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 992 行,在 _standardize_user_data class_weight,batch_size)
文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 1117 行,在 _standardize_weights exception_prefix='input')
文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training_utils.py”,第 293 行,在 standardize_input_data str(len(data)) + ' arrays: ' + str(数据)[:200] + '...')
ValueError:检查模型输入时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。预计将看到1个阵列,但得到以下3个阵列的列表:[[[[[0.12768201,0.06106967,0.99779087,0.50767692,0.21839594] 0.40644681、0.69308081、0.30091417、0.776...
解决方案
推荐阅读
- python - Discord.py -- 使用不同的参数执行相同的命令导致不同的结果
- typescript - 明确赋值断言和环境声明有什么区别?
- javascript - 获取请求给出成功响应,但未获取数据
- python - 如何通过迭代行来预测数据框中的每一行?
- user-defined-functions - 有没有办法将雪花视图标记为“安全”以便结果重用?
- python - 左上角的精灵消失了,我不知道如何修复它。PyGame 随机生成的精灵迷宫
- python - 这里的所有输出都存储在一个漂亮的汤变量中,应该将它分成一个数组,有谁知道我该如何修复它?
- python - 需要用 lambda py 函数解释
- terminal - 像less或top这样的程序的底层库/程序是什么
- python - 使用Python对数据帧中的特征列表进行分类编码的for循环