tensorflow - 由 TensorFlow2 训练的简单线性回归模型的池性能
问题描述
我的模型很简单y = 2*x + 200 + error
,但我无法以简单的方式得到正确的结果。我不知道发生了什么。
import numpy as np
from tensorflow import keras
x = np.arange(100)
error = np.random.rand(100,1).ravel()
y = 2*x + 200 + error
opt = keras.optimizers.Adam(lr=0.0005)
model = keras.Sequential([keras.layers.Dense(1, input_shape=[1])])
model.compile(optimizer=opt, loss='mse', metrics=['mae'])
early_stopping_callback = keras.callbacks.EarlyStopping(
patience=10,
monitor='val_loss',
mode='min',
restore_best_weights=True)
history = model.fit(x, y, epochs=2000, batch_size=16, verbose=1,
validation_split=0.2, callbacks=[early_stopping_callback])
当验证损失很大时,我的模型总是停止:
纪元 901/2000 5/5 [===============================] - 0s 3ms/步 - 损失:14767.1357 - val_loss : 166.8979
而且我不断收到不正确的训练后参数:
model.weights
[<tf.Variable 'dense_28/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[4.2019334]], dtype=float32)>,
<tf.Variable 'dense_28/bias:0'形状=(1,)dtype=float32,numpy=array([2.611792],dtype=float32)>]
请帮我弄清楚我的代码有什么问题。
我使用 tensorflow-v2.3.0
解决方案
我明白了,主要问题是EarlyStopping
太早停止了我的训练过程!另一个问题是学习率太小。
因此,当我更改两个参数设置时,我得到了正确的结果:
import numpy as np
from tensorflow import keras
x = np.arange(100)
error = np.random.rand(100,1).ravel()
y = 2*x + 200 + error
opt = keras.optimizers.Adam(lr=0.8) # <--- bigger lr
model = keras.Sequential([keras.layers.Dense(1, input_shape=[1])])
model.compile(optimizer=opt, loss='mse', metrics=['mae'])
early_stopping_callback = keras.callbacks.EarlyStopping(
patience=100, # <--- longer patience to training
monitor='val_loss',
mode='min',
restore_best_weights=True)
history = model.fit(x, y, epochs=2000, batch_size=16, verbose=1,
validation_split=0.2, callbacks=[early_stopping_callback])
推荐阅读
- java - 材料设计在android studio中不起作用
- function - Woocommerce:在功能中更改产品类型
- jquery - 如何让 jQuery 在页面加载时单击随机的 li 元素?
- amazon-web-services - Google Analytics 无法在 AWS 托管的网站上运行
- reactjs - 使用 react-ga 触发事件返回“命令被忽略。未知目标:未定义”
- c# - AutoMapper:避免使用 MaxDepth 进行无限递归?
- powershell - 在 PowerShell 或 Batch 中格式化日志文件
- excel - 如何在 Excel 中可用的 VBA 中使用 TRUNC 函数
- sql - 在非常大的数据库中查询字符串的最佳方法?
- c++ - 关于类成员的 C++ 问题