machine-learning - 使用 keras 在深度学习模型中训练数据形状误差
问题描述
我正在为多类分类问题创建一个深度学习模型,我的模型包含 46 个独特的类。我的X_train形状是(14382, 183),& y_train是14382
代码-
#Creating Dummy Variables
X=pd.get_dummies(X, prefix=list((X.select_dtypes(include=[object])).columns))
#Splitting the dataset
from sklearn.model_selection import train_test_split
X_train, X_valid, y_train, y_valid= train_test_split(X, y, test_size=0.3, random_state=10)
model = Sequential()
#adding layers to the model
model.add(Dense(units =367, activation ='relu', input_dim =183))
model.add(Dense(units =182, activation ='relu'))
model.add(Dense(units =182, activation='relu'))
#output layer
model.add(Dense(46, activation='softmax'))
model.compile(loss = 'categorical_crossentropy' , optimizer = keras.optimizers.Adam(learning_rate=0.0001) , metrics = ['accuracy'] )
model.fit(X_train, y_train, epochs=20, batch_size = 50, validation_data=(X_valid, y_valid))
运行模型后我遇到了一个错误-
ValueError: Error when checking target: expected dense_110 to have shape (46,) but got array with shape (1,)
我该如何解决这个错误?
解决方案
问题在于您的目标形状
实际上它是一维的,所以你可以保持原样并应用sparse_categorical_crossentropy
为损失函数
X = np.random.randint(0,10, (1000,100))
y = np.random.randint(0,3, 1000)
model = Sequential([
Dense(128, input_dim = 100),
Dense(3, activation='softmax'),
])
model.summary()
model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit(X, y, epochs=3)
否则,您可以对其进行一次热编码pd.get_dummies(y).values
(在训练测试拆分之前)并获得一个 y 形状(n_sample,n_class)。在这种情况下,您可以使用categorical_crossentropy
X = np.random.randint(0,10, (1000,100))
y = pd.get_dummies(np.random.randint(0,3, 1000)).values
model = Sequential([
Dense(128, input_dim = 100),
Dense(3, activation='softmax'),
])
model.summary()
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['accuracy'])
history = model.fit(X, y, epochs=3)
推荐阅读
- c# - 如何按其他对象的属性对对象列表进行排序
- python - 单元测试中的请求对象
- arrays - SQLite JSON_EXTRACT 数组中 1 个对象的所有值
- python - SumTree 实现无法正常工作
- javascript - 比较数组中的图像
- python - 创建一个python包:如何导入模块的内容,而不是模块本身?
- javascript - SuperTest / JEST,如何强制或模拟 400 状态码?
- python - 如何在循环中对多个列表的元素求和?
- postman - 通过 Postman 进行 Etsy OAuth2.0 调用
- flutter - 带有 Google google_ml_kit 的条形扫描仪:^0.6.0