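python - Input and output shapes in a Keras ANN
Question
I am trying to implement an ANN for a multi-class classification task with Keras. This is my dataset:
- Features shape: (9498, 17)
- Labels shape: (9498,)
Here 9498 is the number of pixels and 17 is the number of timestamps, and I have 24 classes to predict. I want to start with something very basic. This is the code I am using: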
import keras
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
# Loading the data
X_train, X_test, y_train, y_test = train_test_split(NDVI, labels, test_size=0.15, random_state=42)
# Building the model
model = Sequential([
    Dense(128, activation='relu', input_shape=(17,), name="layer1"),
    Dense(64, activation='relu', name="layer2"),
    Dense(24, activation='softmax', name="layer3"),
])
print(model.summary())
# Compiling the model
model.compile(
    optimizer='adam',  # gradient-based optimizer
    loss='categorical_crossentropy',  # (>2 classes)
    metrics=['accuracy'],
)
# Training the model
model.fit(
    X_train,  # training data
    y_train,  # training targets
    epochs=5,
    batch_size=32,
)
This results in the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-17-2f4cf6510b24> in <module>()
23 y_train, # training targets
24 epochs=5,
---> 25 batch_size=32,
26 )
2 frames
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
1152 sample_weight=sample_weight,
1153 class_weight=class_weight,
-> 1154 batch_size=batch_size)
1155
1156 # Prepare validation data.
/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
619 feed_output_shapes,
620 check_batch_axis=False, # Don't enforce the batch size.
--> 621 exception_prefix='target')
622
623 # Generate sample-wise weight values given the `sample_weight` and
/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
143 ': expected ' + names[i] + ' to have shape ' +
144 str(shape) + ' but got array with shape ' +
--> 145 str(data_shape))
146 return data
147
ValueError: Error when checking target: expected layer3 to have shape (24,) but got array with shape (1,)
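I don't know why this error comes up. Also, even after reading other posts on the same topic, I still don't seem to understand input and output shapes in Keras.
Edit: here is a quick look at my labels:
Should I do something like one-hot encoding?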
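Answer
The problem here is that your label data is a one-dimensional vector of shape (9498,). That means each sample has a single target value, which is what you would have in a regression problem, where you predict only one value per sample. Your last layer, however, outputs 24 values per sample, which is why Keras reports that it expected shape (24,) but got (1,).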
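To make the mismatch concrete, here is a minimal NumPy sketch. The `labels` array is randomly generated stand-in data, not the asker's dataset, and the reshape to `(n, 1)` is only my reading of what the error message reports (a per-sample target shape of `(1,)`), so take it as an illustration rather than a description of Keras internals.

```python
import numpy as np

# Stand-in for the real labels: 9498 integer class ids in [0, 24).
# (Assumption: synthetic data, used only to illustrate the shapes.)
rng = np.random.default_rng(0)
labels = rng.integers(0, 24, size=9498)

# A 1-D target holds one value per sample, i.e. a per-sample shape of (1,).
targets = labels.reshape(-1, 1)
per_sample_target_shape = targets.shape[1:]

# The final Dense(24, activation='softmax') layer emits 24 values per sample.
per_sample_output_shape = (24,)

print(labels.shape)             # (9498,)
print(per_sample_target_shape)  # (1,)
print(per_sample_output_shape)  # (24,)
print(per_sample_target_shape == per_sample_output_shape)  # False
```

The two per-sample shapes have to agree before `categorical_crossentropy` can compare targets with predictions.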
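On the other hand, that vector contains labels for 24 different classes, and what you want to build is a multi-class classifier. So the first thing to do is convert each value in the labels vector into a 24-dimensional vector. You can do that with, for example, scikit-learn's LabelBinarizer. Here is basically how it works: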
from sklearn.preprocessing import LabelBinarizer
binarizer = LabelBinarizer()
labels = binarizer.fit_transform(labels)
print("binarized_labels.shape:", labels.shape) #-> should now return (9498, 24)
Now you can feed it to your model.
This could be a complete solution:
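To see what this kind of transformation produces, here is a NumPy-only sketch of the same idea on a tiny, hypothetical label vector. It is not how LabelBinarizer is implemented, just an equivalent one-hot encoding for a case with more than two classes. The names `toy_labels`, `classes`, and `one_hot` are made up for the example.

```python
import numpy as np

# Hypothetical toy labels with three distinct classes (not the real data).
toy_labels = np.array([3, 7, 3, 9])

# classes: sorted unique values; idx: position of each label in `classes`.
classes, idx = np.unique(toy_labels, return_inverse=True)

# Row i of the identity matrix is the one-hot vector for class index i.
one_hot = np.eye(len(classes), dtype=int)[idx]

print(classes)        # [3 7 9]
print(one_hot.shape)  # (4, 3)
print(one_hot)
```

With 24 distinct classes the result has 24 columns, which matches the 24 units of the final softmax layer.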
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense
# Loading the data
X_train, X_test, y_train, y_test = train_test_split(NDVI, labels, test_size=0.15, random_state=42)
# Binarize the output to have 24 class labels
binarizer = LabelBinarizer()
y_train = binarizer.fit_transform(y_train)
y_test = binarizer.transform(y_test)  # N.B. use only '.transform()' here, not '.fit_transform()'
# Building the model
model = Sequential([
    Dense(128, activation='relu', input_shape=(17,), name="layer1"),
    Dense(64, activation='relu', name="layer2"),
    Dense(24, activation='softmax', name="layer3"),
])
model.summary()  # prints the summary itself; no print() needed
# Compiling the model
model.compile(
    optimizer='adam',  # gradient-based optimizer
    loss='categorical_crossentropy',  # (>2 classes)
    metrics=['accuracy'],
)
# Training the model
model.fit(
    X_train,  # training data
    y_train,  # one-hot training targets, shape (n_samples, 24)
    epochs=5,
    batch_size=32,
)
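Once the model is trained on one-hot targets, its predictions are rows of 24 probabilities, so you will probably want to map them back to the original class labels. Here is a small sketch of that step with made-up numbers: `classes` stands in for `binarizer.classes_` and `probs` for the output of `model.predict(X_test)`. Both are hypothetical values, reduced to three classes for readability.

```python
import numpy as np

# Hypothetical stand-ins: `classes` plays the role of binarizer.classes_,
# and `probs` the role of model.predict(X_test) for three samples.
classes = np.array([3, 7, 9])
probs = np.array([
    [0.1, 0.7, 0.2],
    [0.8, 0.1, 0.1],
    [0.2, 0.2, 0.6],
])

# Pick the most probable column per row, then look up its class label.
pred_idx = probs.argmax(axis=1)
pred_labels = classes[pred_idx]

print(pred_labels)  # [7 3 9]
```

LabelBinarizer also provides `inverse_transform`, which can serve the same purpose directly on the fitted `binarizer`.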