python - TensorFlow computes an incorrect loss for a tf.keras model when using sample weights
Problem description
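When using tf.keras, the loss is computed incorrectly once sample weights are supplied. After building the model, model.fit_generator should accept a generator that yields (inputs, targets, sample_weights). However, if I multiply sample_weights by 10000, the reported loss does not change.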
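The bug seems to have been introduced in TensorFlow versions after 1.10, e.g. 1.11 and 1.12: in the logs below, v1.10 scales the loss with the weights, while v1.12 does not.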
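To make the expected behavior concrete, here is a minimal pure-Python sketch (no TensorFlow) of two ways a per-sample weighted loss can be reduced to one number. With reduction A, the mean of weighted per-sample losses, multiplying every weight by 10000 multiplies the reported loss by 10000, which is what the v1.10 logs show. With reduction B, a weighted average normalized by the sum of the weights, a uniform scale factor cancels out and the result equals the unweighted mean, which would match what v1.12 reports. The function names are hypothetical, and this is an illustration of the two conventions, not TensorFlow's actual implementation.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample binary cross-entropy for a scalar label and probability."""
    p = min(max(y_pred, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))


def mean_of_weighted(losses, weights):
    """Reduction A: average of loss_i * w_i over the batch; scales with the weights."""
    return sum(l * w for l, w in zip(losses, weights)) / len(losses)


def weighted_average(losses, weights):
    """Reduction B: sum(loss_i * w_i) / sum(w_i); a uniform weight scale cancels."""
    return sum(l * w for l, w in zip(losses, weights)) / sum(weights)


# Illustrative labels and predicted probabilities (made-up values).
y_true = [0, 1, 1, 0]
y_pred = [0.4, 0.6, 0.3, 0.5]
losses = [binary_crossentropy(t, p) for t, p in zip(y_true, y_pred)]

for scale in (1, 10000):
    weights = [1.0 * scale] * len(losses)  # same as np.ones(batch_size) * WEIGHT_VARIABLE
    print(scale, mean_of_weighted(losses, weights), weighted_average(losses, weights))
```

Running it prints the same reduction-B value for both scales, while the reduction-A value grows by a factor of 10000.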
Code to reproduce
import numpy as np
import tensorflow as tf

WEIGHT_VARIABLE = 1  # uniform per-sample weight; change this to e.g. 10000
no_of_features = 10
timesteps = 3
batch_size = 32


def data_gen():
    """Infinite generator yielding (inputs, targets, sample_weights) batches."""
    while True:
        numerical = np.random.randint(5, size=(batch_size, timesteps, no_of_features))
        y = np.random.randint(2, size=batch_size)
        w = np.ones(batch_size) * WEIGHT_VARIABLE  # every sample gets the same weight
        # The dict key must match the name of the Input layer below.
        yield {'numeric_input': numerical}, y, w


def build_model():
    numerical_input = tf.keras.layers.Input(shape=(timesteps, no_of_features), name='numeric_input')
    rnn_out = tf.keras.layers.GRU(32, return_sequences=False)(numerical_input)
    dense = tf.keras.layers.Dense(1, activation='sigmoid', name='main_output')(rnn_out)
    model = tf.keras.models.Model(numerical_input, dense)
    params = {
        'loss': 'binary_crossentropy',
        'optimizer': tf.keras.optimizers.Adam(),
        'metrics': [tf.keras.metrics.binary_crossentropy, tf.keras.metrics.binary_accuracy]
    }
    model.compile(**params)
    return model


def train_model():
    gen1 = data_gen()
    model = build_model()
    model.fit_generator(gen1, epochs=30, steps_per_epoch=10)


if __name__ == '__main__':
    train_model()
In the code above, just change WEIGHT_VARIABLE = 1 from 1 to a large value (10000 in the logs below) and rerun the file.
Logs
v1.10
WEIGHT_VARIABLE = 1
Epoch 1/5
10/10 [==============================] - 1s 128ms/step - loss: 0.7407 - binary_crossentropy: 0.7407 - binary_accuracy: 0.5031
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7043 - binary_crossentropy: 0.7043 - binary_accuracy: 0.5125
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7055 - binary_crossentropy: 0.7055 - binary_accuracy: 0.5219
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7002 - binary_crossentropy: 0.7002 - binary_accuracy: 0.5250
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 0.6944 - binary_crossentropy: 0.6944 - binary_accuracy: 0.5375
WEIGHT_VARIABLE = 10000
Epoch 1/5
10/10 [==============================] - 1s 131ms/step - loss: 7235.5976 - binary_crossentropy: 0.7236 - binary_accuracy: 0.4562
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 7271.9184 - binary_crossentropy: 0.7272 - binary_accuracy: 0.4844
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 7276.9147 - binary_crossentropy: 0.7277 - binary_accuracy: 0.4500
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 7052.0121 - binary_crossentropy: 0.7052 - binary_accuracy: 0.4625
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 7187.0285 - binary_crossentropy: 0.7187 - binary_accuracy: 0.4969
v1.12
WEIGHT_VARIABLE = 1
Epoch 1/5
10/10 [==============================] - 1s 68ms/step - loss: 0.7188 - binary_crossentropy: 0.7188 - binary_accuracy: 0.5312
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7044 - binary_crossentropy: 0.7044 - binary_accuracy: 0.4969
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7086 - binary_crossentropy: 0.7086 - binary_accuracy: 0.4844
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7075 - binary_crossentropy: 0.7075 - binary_accuracy: 0.4500
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 0.6950 - binary_crossentropy: 0.6950 - binary_accuracy: 0.5187
WEIGHT_VARIABLE = 10000
Epoch 1/5
10/10 [==============================] - 1s 74ms/step - loss: 0.9084 - binary_crossentropy: 0.9084 - binary_accuracy: 0.4719
Epoch 2/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7120 - binary_crossentropy: 0.7120 - binary_accuracy: 0.5062
Epoch 3/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7024 - binary_crossentropy: 0.7024 - binary_accuracy: 0.5344
Epoch 4/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7257 - binary_crossentropy: 0.7257 - binary_accuracy: 0.4500
Epoch 5/5
10/10 [==============================] - 0s 4ms/step - loss: 0.7013 - binary_crossentropy: 0.7013 - binary_accuracy: 0.4844
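One quick way to read these logs is to divide the reported loss by the binary_crossentropy metric, which, as the v1.10 numbers suggest, is reported without the sample weights. The sketch below does this for the first epoch of each WEIGHT_VARIABLE = 10000 run, using progress lines copied from the logs above; the helper name and regexes are illustrative assumptions, not part of Keras.

```python
import re

# First-epoch progress lines copied from the logs above (WEIGHT_VARIABLE = 10000).
LOG_LINES = {
    "v1.10": "1s 131ms/step - loss: 7235.5976 - binary_crossentropy: 0.7236 - binary_accuracy: 0.4562",
    "v1.12": "1s 74ms/step - loss: 0.9084 - binary_crossentropy: 0.9084 - binary_accuracy: 0.4719",
}


def loss_to_metric_ratio(line):
    """Return (reported loss) / (binary_crossentropy metric) for one progress line."""
    loss = float(re.search(r"- loss: ([0-9.]+)", line).group(1))
    bce = float(re.search(r"binary_crossentropy: ([0-9.]+)", line).group(1))
    return loss / bce


for version, line in LOG_LINES.items():
    print(version, round(loss_to_metric_ratio(line), 1))
```

This prints 9999.4 for v1.10 and 1.0 for v1.12. A ratio close to WEIGHT_VARIABLE means the weights reach the reported loss; a ratio of exactly 1.0 means they either are ignored or cancel out.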