首页 > 解决方案 > 当我在 Keras 中加载保存的模型时,输入会丢失,其中包含需要多个输入的自定义损失

问题描述

tf。版本 “1.12.0”

我有一个需要多个输入的自定义损失函数。它工作正常,除非我尝试保存和加载模型。这是一个简单的例子,展示了输入是如何丢失的。请看一看。

x = tf.keras.Input(shape=(5,), name='input')
y_true = tf.keras.Input(shape=(5,), name='y_true' )
y_pred = tf.keras.layers.Dense(5)(x)
other_data = tf.keras.Input(shape=(5,), name='other_data' )
model = tf.keras.Model(inputs=[x, y_true, other_data],  outputs=y_pred)

def custom_loss(y_true, y_pred):
    return tf.reduce_sum(tf.pow(y_true -y_pred,2)) + tf.reduce_sum(tf.multiply(y_pred,other_data))

model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01, beta_2=0.99), loss=custom_loss)

data = np.random.rand(5,5)
model.fit([data, data, data], data)

model.save('tmp.h5')
print(model.input_names)
model1 = tf.keras.models.load_model('tmp.h5', custom_objects={'custom_loss':custom_loss})
print(model1.input_names)

model1.fit([data, data, data], data)

纪元 1/1 5/5 [===============================] - 0s 42ms/步 - 损失:9.4392

['input', 'y_true', 'other_data'] <---------------这很好

['input'] <----------- 这里发生了什么?

回溯(最近一次通话最后):

文件“”,第 21 行,在 model1.fit([data, data, data], data)

文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 1536 行,适合 validation_split=validation_split)

文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 992 行,在 _standardize_user_data class_weight,batch_size)

文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training.py”,第 1117 行,在 _standardize_weights exception_prefix='input')

文件“C:\src\Anaconda3\envs\deepema\lib\site-packages\tensorflow\python\keras\engine\training_utils.py”,第 293 行,在 standardize_input_data str(len(data)) + ' arrays: ' + str(数据)[:200] + '...')

ValueError:检查模型输入时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。预计将看到1个阵列,但得到以下3个阵列的列表:[[[[[0.12768201,0.06106967,0.99779087,0.50767692,0.21839594] 0.40644681、0.69308081、0.30091417、0.776...

标签: pythontensorflowmachine-learningkeras

解决方案


推荐阅读