python - Keras 神经网络为所有输入预测相同的数字
问题描述
我正在尝试创建一个 keras 神经网络来预测城市两点之间的道路距离。我正在使用谷歌地图来获取旅行距离,然后训练神经网络来做到这一点。
import pandas as pd
arr=[]
for i in range(0,100):
arr.append(generateTwoPoints(55.901819,37.344735,55.589537,37.832254))
df=pd.DataFrame(arr,columns=['p1Lat','p1Lon','p2Lat','p2Lon', 'distnaceInMeters', 'timeInSeconds'])
print(df)
神经网络架构:
from keras.optimizers import SGD
sgd = SGD(lr=0.00000001)
from keras.models import Sequential
from keras.layers import Dense, Activation
model = Sequential()
model.add(Dense(100, input_dim=4 , activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1,activation='sigmoid'))
model.compile(loss='mse', optimizer='sgd', metrics=['mse'])
然后我分集进行测试/训练
Xtrain=train[['p1Lat','p1Lon','p2Lat','p2Lon']]/100
Ytrain=train[['distnaceInMeters']]/100000
Xtest=test[['p1Lat','p1Lon','p2Lat','p2Lon']]/100
Ytest=test[['distnaceInMeters']]/100000
然后我将数据拟合到模型中,但损失保持不变:
history = model.fit(Xtrain, Ytrain,
batch_size=1,
epochs=1000,
# We pass some validation for
# monitoring validation loss and metrics
# at the end of each epoch
validation_data=(Xtest, Ytest))
我稍后打印数据:
prediction = model.predict(Xtest)
print(prediction)
print (Ytest)
但是所有输入的结果都是相同的:
[[0.26150784]
[0.26171574]
[0.2617755 ]
[0.2615582 ]
[0.26173398]
[0.26166356]
[0.26185763]
[0.26188275]
[0.2614446 ]
[0.2616575 ]
[0.26175532]
[0.2615183 ]
[0.2618127 ]]
distnaceInMeters
2 0.13595
6 0.27998
7 0.48849
16 0.36553
21 0.37910
22 0.40176
33 0.09173
39 0.24542
53 0.04216
55 0.38212
62 0.39972
64 0.29153
87 0.08788
我找不到问题。它是什么?我是机器学习的新手。
解决方案
您犯了一个非常基本的错误:由于您处于回归设置中,因此您不应在最后一层使用sigmoid
激活(这用于二元分类情况);将最后一层更改为
model.add(Dense(1,activation='linear'))
甚至
model.add(Dense(1))
因为,根据文档,如果您不指定activation
参数,则默认为linear
.
其他答案中已经提供了各种其他建议,并且评论可能很有用(较低的 LR、更多层、其他优化器,例如Adam
),您当然需要增加批量大小;但是sigmoid
您当前用于最后一层的激活功能将无法使用。
与问题无关,但在回归设置中,您不需要将损失函数作为指标重复;这个
model.compile(loss='mse', optimizer='sgd')
就足够了。
推荐阅读
- r - R中指定长度的所有简单路径
- django - Django 限制模型对象的用户查看权限
- java - 我应该使用什么 Hibernate Query 来检索 MS SQL 中的最新记录?
- python - 如何检查一个角色是否在另一个角色前面?
- realex-payments-api - 全球支付 HPP 沙箱。使用 ngrok 时 MERCHANT_RESPONSE_URL 中的 508 个无效字符
- python-3.x - Python Pandas用动态名称修剪列的空白
- python - 简单非线性回归的 Keras 预测
- javascript - Moment.js - 如何获取用户时区?
- sql - SQL更新(案例)多个值
- python - 使用 joblib 保存的模型给出不同的分数