python - ANN回归精度和损失卡住
问题描述
我有一个预测太阳能发电的数据集,该数据集有 20 个独立变量和 1 个依赖变量。我的模型的准确率停留在 60%。我已经尝试了几种模型,但这种准确性是最好的,我可以得到其他更糟糕的东西。这是我的代码:
data_path = r'drive/My Drive/Proj/S.P.F./solarpowergeneration.csv'
dts = pd.read_csv('solarpowergeneration.csv')
dts.head()
X = dts.iloc[:, :-1].values
y = dts.iloc[:, -1].values
print(X.shape, y.shape)
y = np.reshape(y, (-1,1))
y.shape
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)
from sklearn.preprocessing import StandardScaler
sc= StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
y_train = sc.fit_transform(y_train)
y_test = sc.transform(y_test)
import keras.backend as K
def calc_accu(y_true, y_pred):
return K.mean(K.equal(K.round(y_true), K.round(y_pred)))
def get_spfnet():
ann = tf.keras.models.Sequential()
ann.add(Dense(X_train.shape[1], activation='relu'))
# ann.add(BatchNormalization())
ann.add(Dropout(0.3))
ann.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
# ann.add(BatchNormalization())
ann.add(Dropout(0.3))
ann.add(Dense(32, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
# ann.add(BatchNormalization())
ann.add(Dropout(0.3))
ann.add(Dense(1))
ann.compile(loss='mse',
optimizer='adam',
metrics=[tf.keras.metrics.RootMeanSquaredError(), calc_accu])
return ann
spfnet = get_spfnet()
#spfnet.summary()
hist = spfnet.fit(X_train, y_train, batch_size=32, epochs=250, verbose=2)
准确率和损失图是
plt.plot(hist.history['calc_accu'])
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.show()
plt.plot(hist.history['root_mean_squared_error'])
plt.title('Model error')
plt.xlabel('Epochs')
plt.ylabel('error')
plt.show()
在 50 个 epoch 之后似乎没有任何改善,两条曲线似乎都没有过度拟合数据我尝试了其他模型,例如减少层和删除内核正则化,使用
kernel_initlizers='normal' and 'he-normal'
但他们表现不佳,停留在 20%。
解决方案
最常见的原因是梯度接近于零。您可能会陷入局部最小值或鞍点。请尝试增加batch_size
(https://www.tensorflow.org/api_docs/python/tf/keras/Sequential/#evaluate)
推荐阅读
- c# - 提高大 EF 多级包含的性能
- github-api - 如何使用 github api 获取我的 github 帐户的所有 github pull 请求
- android - java.lang.RuntimeException:传递结果失败*** 原因:java.lang.NullPointerException:uri
- mysql - 如何在 VueJS 中遍历 JSON?
- swift - 将 AppleScript 字符串列表转换为 Swift 数组
- swift - 随机表格视图单元格高度无法正常工作
- r - Keras 函数式 CNN 模型出错:主输入层的图形断开连接
- c++ - 访问班级内的私人成员?
- python - 重复相同的公式?
- matlab - 在样本的子集中划分样本以进行交叉验证(按学校)(Matlab)