python - keras 输入形状:输入与图层不兼容
问题描述
我看过一些类似的问题,但我仍然不明白如何解决我的问题。
我正在尝试构建一个 CNN,它根据探测器随时间释放的能量的示波器轨迹来估计有多少粒子撞击探测器。
我有 1024 个时间样本的 100,000 个事件,我将它们分成 80/20 作为训练/测试,如下所示:
from sklearn.model_selection import train_test_split
train_to_test_ratio=0.8 #proportion of the dataset to include in the train split
X_train,X_test,Y_train,Y_test=train_test_split(NormSignals,labels,train_size=train_to_test_ratio)
no_outputs = 14 # maximum number of particles expected
# force the labels to have 14 binary digits, one for each of the possible outputs
Y_train=tf.one_hot(Y_train,no_outputs)
Y_test=tf.one_hot(Y_test,no_outputs)
当我尝试为网络定义输入形状时,我会这样做(下面的完整 CNN 代码):
# Define input to neural network (tensors of 1024 time samples x 1 amplitude per sample)
inputs = keras.Input(shape=(1024,1))
但它给了我错误:“Conv_1 层的输入 0 与该层不兼容:预期 ndim=4,发现 ndim=3。收到完整形状:[None, 1024, 1]”
我认为输入形状与传递给网络的数据数组的形状一样简单。有人可以解释我的数据的正确形状应该是什么吗?
非常感谢您!
完整的CNN:
from tensorflow import keras
# Following the architecture of the CNN from the image recognition lab (14/5/2020):
# Simple CNN:
class noiseLayer(keras.layers.Layer):
def __init__(self,mean):
super(noiseLayer, self).__init__()
self.mean = mean
def call(self, input):
mean = self.mean
return input + (np.random.poisson(mean))/mean
# Add data augmentation to produce a random flip of the data (the ECal is symmetrical)
# and add poissonian noise to all of the crystals - using large N and dividing by N normalises
# the noise to be approximately continuous between 0 and 1
data_augmentation = keras.Sequential([
noiseLayer(mean = 1000)
], name='DataAugm')
# Define input to neural network (tensors of 1024 time samples x 1 amplitude per sample)
inputs = keras.Input(shape=(1024,1))
#x=inputs
x = data_augmentation(inputs)
# primo blocco Convoluzionale
x = keras.layers.Conv2D(16, kernel_size=(3,3), name='Conv_1')(x)
x = keras.layers.LeakyReLU(0.1)(x)
x = keras.layers.MaxPool2D((2,2), name='MaxPool_1')(x)
# secondo blocco Convoluzionale
x = keras.layers.Conv2D(16, kernel_size=(3,3), name='Conv_2')(x)
x = keras.layers.LeakyReLU(0.1)(x)
x = keras.layers.MaxPool2D((2,2), name='MaxPool_2')(x)
# terzo blocco convoluzionale
x = keras.layers.Conv2D(32, kernel_size=(3,3), name='Conv_3')(x)
x = keras.layers.LeakyReLU(0.1)(x)
x = keras.layers.MaxPool2D((2,2), name='MaxPool_3')(x)
# Flatten output tensor of the last convolutional layer so it can be used as
# input to the dense layers
x = keras.layers.Flatten(name='Flatten')(x)
# dense network: 2 dense hidden layer with 256 neurons, with ReLU activation
# Classifier
x = keras.layers.Dense(64, name='Dense_1')(x)
x = keras.layers.ReLU(name='ReLU_dense_1')(x)
#x = keras.layers.Dropout(0.2)(x)
x = keras.layers.Dense(64, name='Dense_2')(x)
x = keras.layers.ReLU(name='ReLU_dense_2')(x)
outputs = keras.layers.Dense(no_outputs, activation='softmax', name='Output')(x)
# Model definition
model = keras.Model(inputs=inputs, outputs=outputs, name='VGGlike_CNN')
# Print model summary
model.summary()
# Show model structure
keras.utils.plot_model(model, show_shapes=True)
解决方案
问题是我正在使用 2D 图层来尝试解决 1D 问题。
现在将所有 2D 图层更改为 1D 编译不会出错:
x = keras.layers.Conv1D(16, kernel_size=(3), name='Conv_1')(x)
x = keras.layers.LeakyReLU(0.1)(x)
x = keras.layers.MaxPool1D((2), name='MaxPool_1')(x)
# secondo blocco Convoluzionale
x = keras.layers.Conv1D(16, kernel_size=(3), name='Conv_2')(x)
x = keras.layers.LeakyReLU(0.1)(x)
x = keras.layers.MaxPool1D((2), name='MaxPool_2')(x)
# terzo blocco convoluzionale
x = keras.layers.Conv1D(32, kernel_size=(3), name='Conv_3')(x)
x = keras.layers.LeakyReLU(0.1)(x)
x = keras.layers.MaxPool1D((2), name='MaxPool_3')(x)
# Flatten output tensor of the last convolutional layer so it can be used as
# input to the dense layers
x = keras.layers.Flatten(name='Flatten')(x)
# dense network: 2 dense hidden layer with 256 neurons, with ReLU activation
# Classifier
x = keras.layers.Dense(64, name='Dense_1')(x)
x = keras.layers.ReLU(name='ReLU_dense_1')(x)
#x = keras.layers.Dropout(0.2)(x)
x = keras.layers.Dense(64, name='Dense_2')(x)
x = keras.layers.ReLU(name='ReLU_dense_2')(x)
推荐阅读
- android-studio - 重复类 android.support
- django - 以 django 格式附加邮件错误:不是 pdf 或损坏
- mysql - 这个sql函数的正确版本是什么?
- joomla3.0 - 使用批处理时生成多个重复内容
- swift - 将接受闭包的函数转换为接受函数
- ios - react-native run-ios 失败 - 找不到模拟器
- html - 将普通复选框更改为自定义复选框,并使它们表现为单选按钮
- android - 使用 Android Studio/Kotlin 从 URL 到 textView 的文本崩溃应用程序
- c++ - string += s1 和 string = string + s1 之间的区别
- php - 我如何在 php 上正确使用 else 和 if ?