deep-learning - 为什么模型的损失在每个时期总是围绕 1 旋转?
问题描述
在训练期间,我的模型损失围绕“1”旋转。它没有收敛。我尝试了各种优化器,但它仍然显示相同的模式。我正在使用带有 tensorflow 后端的 keras。可能的原因是什么?任何帮助或参考链接都将是可观的。
def model_vgg19():
vgg_model = VGG19(weights="imagenet", include_top=False, input_shape=(128,128,3))
for layer in vgg_model.layers[:10]:
layer.trainable = False
intermediate_layer_outputs = get_layers_output_by_name(vgg_model, ["block1_pool", "block2_pool", "block3_pool", "block4_pool"])
convnet_output = GlobalAveragePooling2D()(vgg_model.output)
for layer_name, output in intermediate_layer_outputs.items():
output = GlobalAveragePooling2D()(output)
convnet_output = concatenate([convnet_output, output])
convnet_output = Dense(2048, activation='relu')(convnet_output)
convnet_output = Dropout(0.6)(convnet_output)
convnet_output = Dense(2048, activation='relu')(convnet_output)
convnet_output = Lambda(lambda x: K.l2_normalize(x,axis=1)(convnet_output)
final_model = Model(inputs=[vgg_model.input], outputs=convnet_output)
return final_model
model=model_vgg19()
这是我的损失函数:
def hinge_loss(y_true, y_pred):
y_pred = K.clip(y_pred, _EPSILON, 1.0-_EPSILON)
loss = tf.convert_to_tensor(0,dtype=tf.float32)
g = tf.constant(1.0, shape=[1], dtype=tf.float32)
for i in range(0, batch_size, 3):
try:
q_embedding = y_pred[i+0]
p_embedding = y_pred[i+1]
n_embedding = y_pred[i+2]
D_q_p = K.sqrt(K.sum((q_embedding - p_embedding)**2))
D_q_n = K.sqrt(K.sum((q_embedding - n_embedding)**2))
loss = (loss + g + D_q_p - D_q_n)
except:
continue
loss = loss/(batch_size/3)
zero = tf.constant(0.0, shape=[1], dtype=tf.float32)
return tf.maximum(loss,zero)
解决方案
绝对是一个问题,您将数据打乱,然后尝试从中学习三元组。
正如您在此处看到的:https: //keras.io/models/model/model.fit 在每个 epoch 中打乱您的数据,使您的三元组设置过时。尝试将 shuffle 参数设置为 false 看看会发生什么,也可能会有不同的错误。
推荐阅读
- c# - PostSharp 5.1.9 连接到管道服务器时出错
- c# - EditorGUILayout.TextField 不返回更新的 inputText
- django-rest-framework - 我已经嵌套了序列化程序,并希望通过覆盖 create 方法为其创建一个实例。
- python - 为什么 numpy.where 给我这个输出?
- python - Python仅保留列表中的字母数字单词
- javascript - 如何在反应路由器 3 中设置主应用程序外部的路由?
- java - JPA Left Join 正在生成返回太多行的 SQL
- regex - QRegExp 多行带引号
- java - 带有彗星处理器的此 URL 不支持 HTTP 方法 GET
- azure - Azure 搜索 - SearchMode:ANY - 未按预期工作