TensorFlow Keras (v2.2) model fit with multiple outputs and losses fails

Problem description

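I want to fit a TensorFlow Keras (v2.2) model with multiple outputs and multiple losses on MNIST, but it fails. My custom model returns a list [logits, embedding]: logits is a 2D tensor of shape [batch, 10], and embedding is also a 2D tensor, of shape [batch, 64].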

import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self):
    super(MyModel, self).__init__()
    self.reshape = tf.keras.layers.Reshape((28, 28, 1))
    self.conv2D1 = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu')
    self.maxPool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')
    self.conv2D2 = tf.keras.layers.Conv2D(filters=8, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu')
    self.maxPool2 = tf.keras.layers.MaxPooling2D(pool_size=2)
    self.flatten = tf.keras.layers.Flatten(data_format='channels_last')
    # A plain float rate is enough in TF2; no TF1 placeholder is needed.
    self.dropout = tf.keras.layers.Dropout(0.25)
    self.dense1 = tf.keras.layers.Dense(64, activation=None)
    self.dense2 = tf.keras.layers.Dense(10, activation=None)

  def call(self, inputs, training=None):
    x = self.reshape(inputs)
    x = self.conv2D1(x)
    x = self.maxPool1(x)
    # Dropout is a no-op unless training=True, so no explicit `if` is needed.
    x = self.dropout(x, training=training)
    x = self.conv2D2(x)
    x = self.maxPool2(x)
    x = self.dropout(x, training=training)
    x = self.flatten(x)
    x = self.dense1(x)
    embedding = tf.math.l2_normalize(x, axis=1)  # [batch, 64]
    logits = self.dense2(embedding)              # [batch, 10]
    return [logits, embedding]

loss_0 is the ordinary cross-entropy loss:

def loss_0(y_true, y_pred):
    loss_0 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred[0]))
    return loss_0

loss_1 is the triplet semi-hard loss (triplet_semihard_loss from TensorFlow Addons):

import tensorflow_addons as tfa

def loss_1(y_true, y_pred):
    loss_1 = tfa.losses.triplet_semihard_loss(y_true=y_true, y_pred=y_pred[1], distance_metric="L2")
    return loss_1

When I call model.fit, each loss function only receives the logits tensor; I cannot get at the embedding tensor. Neither y_pred[0] nor y_pred[1] works. Any suggestions?

model = MyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=[loss_0, loss_1], loss_weights=[0.1, 0.1])
history = model.fit(train_dataset, epochs=5)
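When loss is given as a list, Keras pairs the i-th model output with the i-th loss function, so each loss receives only its own output as y_pred, not the whole [logits, embedding] list. Indexing y_pred[0] inside a loss therefore selects the first row of the batch rather than the first output. The minimal sketch below illustrates this pairing under stated assumptions: random arrays stand in for MNIST, a small two-layer model replaces MyModel, and TinyModel, logits_loss, embedding_loss and seen_shapes are hypothetical names. The L2 penalty in embedding_loss is only a placeholder for tfa.losses.triplet_semihard_loss, so the example runs without TensorFlow Addons.

```python
import numpy as np
import tensorflow as tf

# Records the static shape of the y_pred each loss sees while Keras traces it.
seen_shapes = {}

class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for MyModel with the same two outputs."""
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(64)
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, inputs, training=None):
        embedding = tf.math.l2_normalize(self.dense1(inputs), axis=1)
        logits = self.dense2(embedding)
        return [logits, embedding]

def logits_loss(y_true, y_pred):
    # y_pred is ONLY the first output (logits), so no [0] indexing.
    seen_shapes["logits"] = tuple(y_pred.shape)
    labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=y_pred))

def embedding_loss(y_true, y_pred):
    # y_pred is ONLY the second output (embedding), so no [1] indexing.
    # Placeholder for tfa.losses.triplet_semihard_loss: a simple L2 penalty.
    seen_shapes["embedding"] = tuple(y_pred.shape)
    return tf.reduce_mean(tf.reduce_sum(tf.square(y_pred), axis=1))

# Toy data (assumption): 64 flattened 28x28 "images" with integer labels.
x = np.random.rand(64, 784).astype("float32")
y = np.random.randint(0, 10, size=(64,)).astype("int32")

model = TinyModel()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=[logits_loss, embedding_loss],
              loss_weights=[0.1, 0.1])
# One label array per output: y[i] is paired with output i and loss i.
model.fit(x, [y, y], batch_size=32, epochs=1, verbose=0)
print(seen_shapes)
```

Under this pairing, loss_1 would presumably pass y_pred straight to tfa.losses.triplet_semihard_loss, and a tf.data pipeline like train_dataset would need to yield one label per output, for example (images, (labels, labels)); that dataset layout is an assumption, not something the question shows.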


Solution

