python-3.x - 使用 keras 的多项式回归
问题描述
嗨,我是 keras 的新手,我只是想知道 ann 对多项式回归任务有什么好处,或者我们应该只使用 sklearn 例如我写这个脚本
import numpy as np
import keras
from keras.layers import Dense
from keras.models import Sequential
x=np.arange(1, 100)
y=x**2
model = Sequential()
model.add(Dense(units=200, activation = 'relu',input_dim=1))
model.add(Dense(units=200, activation= 'relu'))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error',optimizer=keras.optimizers.SGD(learning_rate=0.001))
model.fit(x, y,epochs=2000)
但是在对一些数字进行测试后,我没有得到像这样的好结果:
model.predict([300])
array([[3360.9023]], dtype=float32)
我的代码中是否有任何问题,或者我不应该将 ann 用于多项式回归。谢谢你。
解决方案
我不是 100% 肯定,但我认为你得到如此糟糕的预测的原因是因为你没有扩展你的数据。人工神经网络的计算量非常大,因此必须进行缩放。缩放数据,如下所示:
import numpy as np
import keras
from keras.layers import Dense
from keras.models import Sequential
x=np.arange(1, 100)
y=x**2
from sklearn.preprocessing import StandardScaler
sc_x = StandardScaler()
x = sc_x.fit_transform(x)
sc_y = StandardScaler()
y = sc_y.fit_transform(y)
model = Sequential()
model.add(Dense(units=5, activation = 'relu',input_dim=1))
model.add(Dense(units=5, activation= 'relu'))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error',optimizer=keras.optimizers.SGD(learning_rate=0.001))
model.fit(x, y,epochs=75, batch_size=10)
prediction = sc_y.inverse_transform(model.predict(sc_x.transform([300])))
print(prediction)
请注意,我将 epoch 的数量从 2000 更改为 75。这是因为 2000 epoch 对于神经网络来说太高了,并且需要大量时间来训练。您的 X 数据集仅包含 100 个值,因此我建议的最大时期数为 75。
此外,我还将每个隐藏层中的神经元数量从 200 个更改为 5 个。这是因为 200 个神经元对于大多数数据集来说远远不够,更不用说长度为 100 的小数据集了。
这些更改应确保您的神经网络产生更准确的预测。
希望有帮助。
推荐阅读
- php - 注意:未定义索引:第 9 行 C:\xampp\htdocs\template\cek_login.php 中的用户名
- java - 将 onMouseEnter 和 onMouseLeave 事件添加到音乐播放器
- csv - Pyspark:无法在 Zeppilin 实例中导入 csv 文件
- python - 无法在 Pandas groupby 聚合中使用某些基本统计功能
- python - 将未分配给变量的列表添加到指向不同列表的另一个变量是否会创建新的列表对象?
- user-input - 我的问题是关于为什么 R 的 rbind 函数不适用于古吉拉特语文本输入
- javascript - JQuery Mobile 1.4.5 似乎没有在 ipad 上调用 javascript
- javascript - 引用自身的 Javascript 对象...有什么问题吗?
- css - 'a:link {color}' CSS 选择器的问题
- c# - Selenium WebDriver - 突出显示解决方案中的所有 Web 元素