python - TensorFlow: why is my ANN model not learning?
Question
Here is my very basic ANN code:
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize
data = pd.read_csv("home_data.csv")
x = data.drop(['id', 'date', 'price'], axis=1).values
y = data['price'].values
x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.33)
model = Sequential()
model.add(Dense(18, input_shape=(18,), activation="sigmoid"))
model.add(Dense(36, input_shape=(18,), activation="sigmoid"))
model.add(Dense(1, input_shape=(18,), activation="sigmoid"))
model.compile(optimizer='sgd', loss='mean_squared_error')
r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs=50)
plt.plot(r.history['loss'], label="loss")
plt.plot(r.history['val_loss'], label="val_loss")
plt.show()
However, my loss is very high (about 426470263086) and it never decreases over time. Here is my loss plot.
Update
Here is part of the data I am trying to work with.
id date price bedrooms ... lat long sqft_living15 sqft_lot15
0 7129300520 20141013T000000 221900.0 3 ... 47.5112 -122.257 1340 5650
1 6414100192 20141209T000000 538000.0 3 ... 47.7210 -122.319 1690 7639
2 5631500400 20150225T000000 180000.0 2 ... 47.7379 -122.233 2720 8062
3 2487200875 20141209T000000 604000.0 4 ... 47.5208 -122.393 1360 5000
4 1954400510 20150218T000000 510000.0 3 ... 47.6168 -122.045 1800 7503
[5 rows x 21 columns]
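Answer
It looks like you are trying to predict a continuous value. When predicting a continuous value, the activation of your last layer should be linear (which is the same as having no activation), or leaky ReLU if the predicted values are positive. A sigmoid squashes the output into the range 0 to 1, so the network can never produce values on the scale of house prices.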
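To make the point about the output activation concrete, here is a minimal sketch in plain Python (no TensorFlow needed). It uses only the five `price` values from the sample rows in the question; the helper names `sigmoid` and `mse` are illustrative and not part of any library. Because a sigmoid never outputs more than 1, the loss has a large floor no matter what the weights are.

```python
import math

# Prices taken from the five sample rows shown in the question.
prices = [221900.0, 538000.0, 180000.0, 604000.0, 510000.0]

def sigmoid(z):
    """Logistic function; its output always lies between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-z))

def mse(targets, predictions):
    """Mean squared error between two equal-length lists."""
    return sum((t - p) ** 2 for t, p in zip(targets, predictions)) / len(targets)

# Even with a very large pre-activation, a sigmoid output stays at most 1.
best_sigmoid_output = sigmoid(30.0)

# Best case for a sigmoid output layer: predict (almost) 1 for every house.
sigmoid_floor = mse(prices, [best_sigmoid_output] * len(prices))

# A linear output can at least reach the mean price, a far better constant guess.
mean_price = sum(prices) / len(prices)
linear_constant = mse(prices, [mean_price] * len(prices))

print(f"best sigmoid output: {best_sigmoid_output:.6f}")
print(f"MSE floor with sigmoid output: {sigmoid_floor:.3e}")
print(f"MSE predicting the mean price: {linear_constant:.3e}")
```

On these five rows, the sigmoid floor is several times larger than the error of simply predicting the mean price, which a linear output unit can learn through its bias alone.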
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Sequential
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler,normalize
# Load the data; drop the identifier, the date string, and the target column
data = pd.read_csv("home_data.csv")
x = data.drop(['id','price' ,'date'], axis=1).values
y = data['price'].values
x_train, x_test, y_train, y_test = train_test_split(x,y,test_size=0.33)
# Standardize the features: fit on the training split only, then reuse it for the test split
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train)
x_test = scaler.transform(x_test)
# Small regression network: ReLU hidden layers and a linear output unit
model = Sequential()
model.add(Dense(12, input_shape=(18,), activation="relu"))
model.add(Dense(6, activation="relu"))
model.add(Dense(1, activation="linear"))
model.compile(optimizer='sgd', loss='mean_squared_error', metrics = [tf.keras.metrics.RootMeanSquaredError()])
r = model.fit(x_train, y_train, validation_data=(x_test,y_test), epochs=10)
# Plot training and validation loss per epoch
plt.plot(r.history['loss'], label="loss")
plt.plot(r.history['val_loss'], label="val_loss")
plt.legend()
plt.show()
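You do not need to specify the input shape for the hidden layers; only the first layer needs it.
The loss computed by the model is very high because the values in the dataset span very different ranges (compare `lat` at about 47 with `price` in the hundreds of thousands).
After applying StandardScaler to the features, the loss decreased.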
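For reference, the following sketch shows what standardization does to a single feature column. It re-implements the per-column formula (x - mean) / std in plain Python rather than calling scikit-learn; the `standardize` helper is a hypothetical name, and the values are the `sqft_lot15` entries from the sample rows in the question. StandardScaler applies the same idea to every column, using the population standard deviation.

```python
import statistics

# sqft_lot15 values from the five sample rows in the question.
sqft_lot15 = [5650.0, 7639.0, 8062.0, 5000.0, 7503.0]

def standardize(column):
    """Return (scaled values, mean, std), with scaled = (x - mean) / std.

    Uses the population standard deviation (ddof=0).
    """
    mean = statistics.fmean(column)
    std = statistics.pstdev(column)
    return [(x - mean) / std for x in column], mean, std

scaled, mean, std = standardize(sqft_lot15)

print("mean:", mean, "std:", round(std, 2))
print("scaled:", [round(v, 3) for v in scaled])
```

After scaling, the column has mean 0 and standard deviation 1, so features measured in thousands no longer dominate features measured in single digits. As in the code above, the mean and standard deviation should come from the training split only and then be reused for the test split.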
Output:
Epoch 1/10
453/453 [==============================] - 1s 3ms/step - loss: 6093344963084377128960.0000 - root_mean_squared_error: 78059880448.0000 - val_loss: 9416156905472.0000 - val_root_mean_squared_error: 3068575.7500
Epoch 2/10
453/453 [==============================] - 1s 3ms/step - loss: 639826591744.0000 - root_mean_squared_error: 799891.6250 - val_loss: 155623915520.0000 - val_root_mean_squared_error: 394491.9688
Epoch 3/10
453/453 [==============================] - 1s 2ms/step - loss: 124726026240.0000 - root_mean_squared_error: 353165.7188 - val_loss: 155318534144.0000 - val_root_mean_squared_error: 394104.7188
Epoch 4/10
453/453 [==============================] - 1s 3ms/step - loss: 124705193984.0000 - root_mean_squared_error: 353136.2188 - val_loss: 155418017792.0000 - val_root_mean_squared_error: 394230.9062
Epoch 5/10
453/453 [==============================] - 1s 3ms/step - loss: 124720766976.0000 - root_mean_squared_error: 353158.2812 - val_loss: 155389984768.0000 - val_root_mean_squared_error: 394195.3750
Epoch 6/10
453/453 [==============================] - 1s 3ms/step - loss: 124696051712.0000 - root_mean_squared_error: 353123.2812 - val_loss: 155291697152.0000 - val_root_mean_squared_error: 394070.6875
Epoch 7/10
453/453 [==============================] - 1s 3ms/step - loss: 124681125888.0000 - root_mean_squared_error: 353102.1562 - val_loss: 155307376640.0000 - val_root_mean_squared_error: 394090.5625
Epoch 8/10
453/453 [==============================] - 1s 3ms/step - loss: 124710920192.0000 - root_mean_squared_error: 353144.3438 - val_loss: 155327266816.0000 - val_root_mean_squared_error: 394115.8125
Epoch 9/10
453/453 [==============================] - 1s 3ms/step - loss: 124708052992.0000 - root_mean_squared_error: 353140.2812 - val_loss: 155288338432.0000 - val_root_mean_squared_error: 394066.4062
Epoch 10/10
453/453 [==============================] - 1s 3ms/step - loss: 124725968896.0000 - root_mean_squared_error: 353165.6250 - val_loss: 155315683328.0000 - val_root_mean_squared_error: 394101.0938
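When reading this log, it helps to remember that the loss is a mean squared error, so it is in squared price units, while the `root_mean_squared_error` metric is its square root and is in the same units as `price`. The short check below uses the epoch 10 numbers copied from the output above; the variable names are illustrative.

```python
import math

# Epoch 10 values copied from the training log above.
train_loss = 124725968896.0      # mean squared error on the training split
train_rmse_logged = 353165.6250  # root_mean_squared_error reported by Keras
val_loss = 155315683328.0
val_rmse_logged = 394101.0938

# The RMSE metric is the square root of the MSE loss.
train_rmse = math.sqrt(train_loss)
val_rmse = math.sqrt(val_loss)

print(f"sqrt(train loss) = {train_rmse:.1f} (logged {train_rmse_logged})")
print(f"sqrt(val loss)   = {val_rmse:.1f} (logged {val_rmse_logged})")
```

So a loss of about 1.2e11 corresponds to a typical error of roughly 350,000 in price units, which is of the same order as the prices in the sample rows.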