tensorflow - 批量大小> 1的自定义TensorFlow损失函数?
问题描述
我有一些带有以下代码片段的神经网络,请注意 batch_size == 1 和 input_dim == output_dim:
net_in = tf.Variable(tf.zeros(shape = [batch_size, input_dim]), dtype=tf.float32)
input_placeholder = tf.compat.v1.placeholder(shape = [batch_size, input_dim], dtype=tf.float32)
assign_input = net_in.assign(input_placeholder)
# Some matmuls, activations, dropouts, normalizations...
net_out = tf.tanh(output_before_activation)
def loss_fn(output, input):
#input.shape = output.shape = (batch_size, input_dim)
output = tf.reshape(output, [input_dim,]) # shape them into 1d vectors
input = tf.reshape(input, [input_dim,])
return my_fn_that_only_takes_in_vectors(output, input)
# Create session, preprocess data ...
for epoch in epoch_num:
for batch in range(total_example_num // batch_size):
sess.run(assign_input, feed_dict = {input_placeholder : some_appropriate_numpy_array})
sess.run(optimizer.minimize(loss_fn(net_out, net_in)))
目前,上面的神经网络工作正常,但它非常慢,因为它会更新每个样本的梯度(批量大小 = 1)。我想设置批量大小 > 1,但 my_fn_that_only_takes_in_vectors 无法容纳第一维不是 1 的矩阵。由于我的自定义损失的性质,将批量输入展平为长度向量 (batch_size * input_dim) 似乎不起作用。
既然输入和输出是 N x input_dim 其中 N > 1,我将如何编写新的自定义 loss_fn?在 Keras 中,这不是问题,因为 keras 以某种方式取了批次中每个示例的梯度的平均值。对于我的 TensorFlow 函数,我是否应该将每一行单独作为一个向量,将它们传递给 my_fn_that_only_takes_in_vectors,然后取结果的平均值?
解决方案
您可以使用一个函数来计算整个批次的损失,并独立处理批次大小。基本上,操作应用于输入的整个第一维(第一维表示批次中的元素编号)。这是一个例子,我希望这有助于了解操作是如何进行的:
def my_loss(y_true, y_pred):
dx2 = tf.math.squared_difference(y_true[:, 0], y_true[:, 2]) # shape (BatchSize, )
dy2 = tf.math.squared_difference(y_true[:, 1], y_true[:, 3]) # shape: (BatchSize, )
denominator = dx2 + dy2 # shape: (BatchSize, )
dst_vec = tf.math.squared_difference(y_true, y_pred) # shape: (Batch, n_labels)
numerator = tf.reduce_sum(dst_vec, axis=-1) # shape: (BatchSize,)
loss_vector = tf.cast(numerator / denominator, dtype="float32") # shape: (BatchSize,) this is a vector containing the loss of each element of the batch
loss = tf.reduce_sum(loss_vector ) #if you want to sum the losses
return loss
我不确定您是否需要返回批次损失的总和或平均值。如果求和,请确保使用具有相同批大小的验证数据集,否则损失无法比较。
推荐阅读
- x86 - 为什么只有在存在存储初始化循环时才计算用户模式 L1 存储未命中事件?
- java - 投射到客户时出现条纹 Java API Webhook 错误
- flask-appbuilder - Flask appbuilder 大型应用程序
- java - 调用 Driver#connect 时出错,并且无法创建与数据库服务器的连接
- javascript - 如何使用 lodash 根据值对嵌套数组进行排序?
- fonts - 无法从 Windows 10 中删除 Google 字体
- c# - 试图理解函数之间的 C#、WPF 和用户输入
- javascript - 选择字段中的默认值
- coq - 证明在列表中找到相同元素的另一个属性
- nginx - K8S - Ingress - Route 2 不同的应用程序