首页 > 解决方案 > 批量大小> 1的自定义TensorFlow损失函数?

问题描述

我有一些带有以下代码片段的神经网络,请注意 batch_size == 1 和 input_dim == output_dim:

net_in = tf.Variable(tf.zeros(shape = [batch_size, input_dim]), dtype=tf.float32)
input_placeholder = tf.compat.v1.placeholder(shape = [batch_size, input_dim], dtype=tf.float32)
assign_input = net_in.assign(input_placeholder) 
# Some matmuls, activations, dropouts, normalizations...
net_out = tf.tanh(output_before_activation)


def loss_fn(output, input):
    #input.shape = output.shape = (batch_size, input_dim)
    output = tf.reshape(output, [input_dim,]) # shape them into 1d vectors
    input = tf.reshape(input, [input_dim,])
    return my_fn_that_only_takes_in_vectors(output, input)

# Create session, preprocess data ...

for epoch in epoch_num:
    for batch in range(total_example_num // batch_size):
        sess.run(assign_input, feed_dict = {input_placeholder : some_appropriate_numpy_array})
        sess.run(optimizer.minimize(loss_fn(net_out, net_in)))

目前,上面的神经网络工作正常,但它非常慢,因为它会更新每个样本的梯度(批量大小 = 1)。我想设置批量大小 > 1,但 my_fn_that_only_takes_in_vectors 无法容纳第一维不是 1 的矩阵。由于我的自定义损失的性质,将批量输入展平为长度向量 (batch_size * input_dim) 似乎不起作用。

既然输入和输出是 N x input_dim 其中 N > 1,我将如何编写新的自定义 loss_fn?在 Keras 中,这不是问题,因为 keras 以某种方式取了批次中每个示例的梯度的平均值。对于我的 TensorFlow 函数,我是否应该将每一行单独作为一个向量,将它们传递给 my_fn_that_only_takes_in_vectors,然后取结果的平均值?

标签: tensorflowneural-network

解决方案


您可以使用一个函数来计算整个批次的损失,并独立处理批次大小。基本上,操作应用于输入的整个第一维(第一维表示批次中的元素编号)。这是一个例子,我希望这有助于了解操作是如何进行的:

    def my_loss(y_true, y_pred):
       dx2 = tf.math.squared_difference(y_true[:, 0], y_true[:, 2])  # shape (BatchSize, )
       dy2 = tf.math.squared_difference(y_true[:, 1], y_true[:, 3])  # shape: (BatchSize, )
       denominator = dx2 + dy2 # shape: (BatchSize, )

       dst_vec = tf.math.squared_difference(y_true, y_pred)  # shape: (Batch, n_labels)
       numerator = tf.reduce_sum(dst_vec, axis=-1)  # shape: (BatchSize,)

       loss_vector = tf.cast(numerator / denominator, dtype="float32") # shape: (BatchSize,) this is a vector containing the loss of each element of the batch

       loss = tf.reduce_sum(loss_vector ) #if you want to sum the losses

       return loss

我不确定您是否需要返回批次损失的总和或平均值。如果求和,请确保使用具有相同批大小的验证数据集,否则损失无法比较。


推荐阅读