python - 超参数调整以决定最佳神经网络
问题描述
我想根据某些标准找出哪个是最佳神经网络。标准如下:
使用一、二、三、四隐藏层 + 输出层测试 4 种架构
要测试的学习率:0.1,0.01,0.001
要测试的时期:10,50,100
输入尺寸 = 20
输出应该是一个表格,显示每个组合(36 行)。例如,有一个隐藏层,lr = 0.1,epochs = 10,准确率是 X。
请看我下面的代码:
#Function to create the model
def create_model(layers,learn_rate):
model = Sequential()
for i, nodes in enumerate(layers):
if i==0:
model.add(Dense(nodes),input_dim = 20,activation = 'relu')
else:
model.add(Dense(nodes),activation = 'relu')
model.add(Dense(units = 4,activation = 'softmax'))
model.compile(optimizer=adam(lr=learn_rate), loss='categorical_crossentropy',metrics=['accuracy'])
return model
#Initialization of variables
#Here there are the four possible types of layers with the neurons in each.
layers = [[20], [40, 20], [45, 30, 15],[32,16,8,4]]
learn_rate = [0.1,0.01,0.001]
epochs = [10,50,100]
#GridSearchCV for hyperparameter tuning
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV
model = KerasClassifier(build_fn = create_model, verbose = 0)
param_grid = dict(layers = layers,learn_rate = learn_rate,epochs = epochs)
grid = GridSearchCV(estimator = model, param_grid = param_grid,cv = 3)
grid_result = grid.fit(train_x,train_y)
但是当我运行代码时,出现以下错误:
RuntimeError: Cannot clone object <keras.wrappers.scikit_learn.KerasClassifier object at 0x000001AA272C7748>, as the constructor either does not set or modifies parameter layers
解决方案
无法克隆对象不是主要问题。这是模型生成器函数中另一个错误的结果。您在create_model()
中有一些语法错误。请查看输出中“克隆问题”之前的错误。
这是固定功能:
from keras import optimizers
def create_model(layers, learn_rate):
model = Sequential()
for i, nodes in enumerate(layers):
if i==0:
model.add(Dense(nodes,input_dim = 20,activation = 'relu'))
else:
model.add(Dense(nodes,activation = 'relu'))
model.add(Dense(units = 4,activation = 'softmax'))
model.compile(optimizer=optimizers.adam(lr=learn_rate), loss='categorical_crossentropy',metrics=['accuracy'])
return model
推荐阅读
- python - 如何在 y=x 的函数的一行中编写 plt.scatter(x, y) 函数
- c# - C# - 由另一个数组拆分数组
- javascript - 如何将完整的字符串内容放入js
- sql - 查询表达式中的语法错误(缺少运算符) - 我在哪里将它括起来?
- statsmodels - 阐明用于时间序列预测的 statsmodels AutoReg()、ARMA() 和 SARIMAX()
- django - 无法访问 Django Graphene 中的字段
- python - 网络抓取问题(空列表)
- python - 更改列表列表的元素
- angular - 在继续之前解决 for 循环中的多个承诺
- python-3.x - 斐波那契。序列 UnboundLocalError