machine-learning - CNN incompatible shapes
Problem description
My data has the following shapes:
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)
(942, 32, 32, 1) (236, 32, 32, 1) (942, 3, 3) (236, 3, 3)
Whenever I try to run my CNN, I get the following error:
from tensorflow.keras import layers
from tensorflow.keras import Model
img_input = layers.Input(shape=(32, 32, 1))
x = layers.Conv2D(16, (3,3), activation='relu', strides = 1, padding = 'same')(img_input)  # -> (None, 32, 32, 16)
x = layers.Conv2D(32, (3,3), activation='relu', strides = 2)(x)   # -> (None, 15, 15, 32)
x = layers.Conv2D(128, (3,3), activation='relu', strides = 2)(x)  # -> (None, 7, 7, 128)
x = layers.MaxPool2D(pool_size=2)(x)                               # -> (None, 3, 3, 128)
x = layers.Conv2D(3, 3, activation='linear', strides = 2)(x)      # -> (None, 1, 1, 3)
output = layers.Flatten()(x)                                       # -> (None, 3), but Y has shape (None, 3, 3)
model = Model(img_input, output)
model.summary()
model.compile(loss='mean_squared_error',optimizer= 'adam', metrics=['mse'])
history = model.fit(X_train,Y_train,validation_data=(X_test, Y_test), epochs = 100,verbose=1)
Error:
InvalidArgumentError: Incompatible shapes: [32,3] vs. [32,3,3]
[[node BroadcastGradientArgs_2 (defined at /usr/local/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_distributed_function_7567]
Function call stack:
distributed_function
What am I missing here?
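A minimal NumPy sketch of what the error message means, assuming the loss boils down to an element-wise difference between prediction and target (TensorFlow follows the same broadcasting rules as NumPy). The function name squared_error_or_none is hypothetical, and the zero arrays only stand in for one batch of 32 samples, the batch size shown in the error:

```python
import numpy as np

def squared_error_or_none(y_pred, y_true):
    """Hypothetical helper: element-wise squared error, or None if the shapes cannot be broadcast."""
    try:
        return np.square(y_pred - y_true)
    except ValueError:  # NumPy raises ValueError for non-broadcastable operands
        return None

batch = 32                        # batch size reported in the error message
y_pred = np.zeros((batch, 3))     # stand-in for the Flatten() output of the model
y_true = np.zeros((batch, 3, 3))  # stand-in for one batch of Y_train

# Shapes are aligned from the right: (_, 32, 3) vs (32, 3, 3).
# Last axis: 3 vs 3 is fine; next axis: 32 vs 3 is not, so the subtraction fails.
print(squared_error_or_none(y_pred, y_true))  # None

# When prediction and target have the same shape, the subtraction works.
print(squared_error_or_none(np.zeros((batch, 3, 3, 1)),
                            np.zeros((batch, 3, 3, 1))).shape)  # (32, 3, 3, 1)
```

Two axes are compatible only if they are equal or one of them is 1, which is why [32,3] and [32,3,3] are rejected.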
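Solution
You are not handling the dimensions inside the network correctly. With your layers, the last Conv2D outputs a 1x1x3 feature map, which Flatten turns into shape (batch, 3), while each target has shape (3, 3), so the loss cannot combine [32,3] with [32,3,3]. First, expand the dimensions of your y so that it has the format (n_sample, 3, 3, 1). Then adjust the network so that it outputs a 3x3x1 map: I removed the Flatten and MaxPool2D layers and changed the last Conv2D to a single filter. Here is a version with dummy data: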
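import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D

# create dummy data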
n_sample = 10
X = np.random.uniform(0,1, (n_sample, 32, 32, 1))
y = np.random.uniform(0,1, (n_sample, 3, 3))
# expand y dim
y = y[...,np.newaxis]
print(X.shape, y.shape)
img_input = Input(shape=(32, 32, 1))
x = Conv2D(16, (3,3), activation='relu', strides = 1, padding = 'same')(img_input)  # -> (None, 32, 32, 16)
x = Conv2D(32, (3,3), activation='relu', strides = 2)(x)   # -> (None, 15, 15, 32)
x = Conv2D(128, (3,3), activation='relu', strides = 2)(x)  # -> (None, 7, 7, 128)
x = Conv2D(1, (3,3), activation='linear', strides = 2)(x)  # -> (None, 3, 3, 1), matching y
model = Model(img_input, x)
model.summary()
model.compile(loss='mean_squared_error',optimizer= 'adam', metrics=['mse'])
model.fit(X,y, epochs=3)
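The shape reasoning above can be checked without running TensorFlow. This is a hypothetical pure-Python sketch (conv_out and trace are illustrative names, not Keras functions) that applies the usual Keras output-size formulas for 'valid' and 'same' padding, assuming dilation 1, to both the question's network and the answer's network:

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size along one axis for Conv2D / MaxPool2D (dilation 1 assumed)."""
    if padding == "same":
        return -(-size // stride)          # ceil(size / stride)
    return (size - kernel) // stride + 1   # 'valid' padding, the Keras default

def trace(size, layers):
    """Return the spatial size before the first layer and after each (kernel, stride, padding) layer."""
    sizes = [size]
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
        sizes.append(size)
    return sizes

# Question's model: three convs, MaxPool2D(pool_size=2) (strides default to the pool size), final conv.
original = [(3, 1, "same"), (3, 2, "valid"), (3, 2, "valid"), (2, 2, "valid"), (3, 2, "valid")]
# Answer's model: the same three convs, no pooling, final conv with one filter.
fixed = [(3, 1, "same"), (3, 2, "valid"), (3, 2, "valid"), (3, 2, "valid")]

print(trace(32, original))  # [32, 32, 15, 7, 3, 1] -> 1x1x3, flattened to 3 values
print(trace(32, fixed))     # [32, 32, 15, 7, 3]    -> 3x3x1, same as y of shape (n, 3, 3, 1)
```

The spatial size is the same in both traces up to 7x7; the pooling layer is what pushes the original network down to 1x1 at the final stride-2 convolution.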