GradientTape returns None when using a custom loss function

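Problem description

Given a specific text input, I am trying to create an output that has the same semantic encoding as the input. To do this, I trained an autoencoder and kept only the encoder part, which I use to compare sequence embeddings. This is the code for training the new decoder: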

with tf.GradientTape() as gen_tape:
    enc_output, enc_hidden = enc(input_batch, enc_hidden)
    gen_hidden = enc_hidden
    all_outputs = [[tokenizer.word_index[START_TOKEN]] * BATCH_SIZE]
    gen_input = tf.expand_dims([tokenizer.word_index[START_TOKEN]] * BATCH_SIZE, 1) #First input is list of start tokens
    gen_loss = 0
    for t in range(1, input_batch.shape[1]):
        predictions, gen_hidden, _ = gen(gen_input, gen_hidden, enc_output)
        predictions_am = tf.expand_dims(tf.argmax(predictions, 1), 1) #take most likely prediction for each row
        all_outputs.append(tf.argmax(predictions, 1))
        gen_input = predictions_am #predicted IDs are fed back into the model

    all_outputs = tf.stack(all_outputs, 1) #build list of full length predictions
    #Get the embedding vectors for original and predictions
    e1 = enc(all_outputs, enc.get_def_hidden_state())[0]
    e2 = enc_output
    gen_loss = -tf.keras.losses.cosine_similarity(e1, e2) + 1 #calculate loss based on how similar they are

gen_grads = gen_tape.gradient(gen_loss, gen.trainable_weights)
gen_optimizer.apply_gradients(zip(gen_grads, gen.trainable_weights))

gen_grads always ends up being a list of None values.

Tags: python-3.x, keras, tensorflow2.0

Solution


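Argmax is not differentiable, so the tape cannot propagate gradients from the loss back through the predicted token IDs to the generator's weights. That is why every entry in gen_grads is None. You can't use argmax results as model outputs for the loss calculation. You need to keep the predictions as full per-class vectors (logits or probabilities) until the loss is computed.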
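
The effect can be reproduced in isolation. The sketch below is a toy setup, not the asker's enc/gen models: w, embedding, x, target and cosine_loss are hypothetical stand-ins chosen only for illustration. It compares a hard path (argmax followed by an embedding lookup) with one possible soft alternative (softmax probabilities multiplied by the embedding matrix), under the assumption that the downstream encoder can accept embedded vectors instead of integer IDs.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical toy sizes and tensors; none of these come from the question.
VOCAB_SIZE, EMBED_DIM, BATCH_SIZE, FEATURES = 5, 3, 2, 4
w = tf.Variable(tf.random.normal([FEATURES, VOCAB_SIZE]))          # stand-in for the generator's weights
embedding = tf.Variable(tf.random.normal([VOCAB_SIZE, EMBED_DIM]))  # stand-in for an embedding table
x = tf.random.normal([BATCH_SIZE, FEATURES])
target = tf.random.normal([BATCH_SIZE, EMBED_DIM])


def cosine_loss(a, b):
    # 1 - cosine similarity, computed by hand. Note that
    # tf.keras.losses.cosine_similarity returns the negative similarity.
    a = tf.math.l2_normalize(a, axis=-1)
    b = tf.math.l2_normalize(b, axis=-1)
    return tf.reduce_mean(1.0 - tf.reduce_sum(a * b, axis=-1))


# Hard path: argmax yields integer IDs, which cuts the gradient path to w.
with tf.GradientTape() as tape:
    logits = tf.matmul(x, w)
    ids = tf.argmax(logits, axis=1)
    hard_embedded = tf.gather(embedding, ids)
    hard_loss = cosine_loss(hard_embedded, target)
hard_grads = tape.gradient(hard_loss, [w])
print(hard_grads)  # [None]

# Soft path: keep the full distribution and take the expected embedding.
with tf.GradientTape() as tape:
    logits = tf.matmul(x, w)
    probs = tf.nn.softmax(logits, axis=-1)
    soft_embedded = tf.matmul(probs, embedding)
    soft_loss = cosine_loss(soft_embedded, target)
soft_grads = tape.gradient(soft_loss, [w])
print([g is not None for g in soft_grads])  # [True]
```

In the question's loop, the same idea would mean appending predictions (or their softmax) to all_outputs rather than tf.argmax(predictions, 1). Whether enc can consume such soft inputs depends on how its embedding layer is built, which the question does not show.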

