python - Predictions from a U-Net have the wrong shape
Problem description
I have built a U-Net architecture to predict a single value from time-series data.
My X_train has shape (500, 1024) and Y_train has shape (500,); my X_test has shape (100, 1024) and Y_test has shape (100,). When I predict on one sample, I get an array of shape (1, 1024), but I expected the output to be a single value.
I don't understand why this happens.
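For context, Conv1D expects a trailing channel axis, so I assume the arrays are expanded to (N, 1024, 1) before being fed to the network; the norm_* variables used further down come from a normalization step that is not shown. A hypothetical preprocessing sketch:

import numpy as np

# hypothetical preprocessing, not shown in the original script
norm_x_train = np.expand_dims(X_train, axis=-1)  # (500, 1024, 1)
norm_x_test = np.expand_dims(X_test, axis=-1)    # (100, 1024, 1)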
import os

from tensorflow.keras.layers import (Input, Conv1D, BatchNormalization,
                                     MaxPooling1D, UpSampling1D, concatenate,
                                     Dense)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint


def UNetDS64(length, n_channel=1):
    x = 64  # base number of filters
    inputs = Input((length, n_channel))

    # encoder
    conv1 = Conv1D(x, 3, activation='relu', padding='same')(inputs)
    conv1 = BatchNormalization()(conv1)
    conv1 = Conv1D(x, 3, activation='relu', padding='same')(conv1)
    conv1 = BatchNormalization()(conv1)
    pool1 = MaxPooling1D(pool_size=2)(conv1)

    conv2 = Conv1D(x*2, 3, activation='relu', padding='same')(pool1)
    conv2 = BatchNormalization()(conv2)
    conv2 = Conv1D(x*2, 3, activation='relu', padding='same')(conv2)
    conv2 = BatchNormalization()(conv2)
    pool2 = MaxPooling1D(pool_size=2)(conv2)

    conv3 = Conv1D(x*4, 3, activation='relu', padding='same')(pool2)
    conv3 = BatchNormalization()(conv3)
    conv3 = Conv1D(x*4, 3, activation='relu', padding='same')(conv3)
    conv3 = BatchNormalization()(conv3)
    pool3 = MaxPooling1D(pool_size=2)(conv3)

    conv4 = Conv1D(x*8, 3, activation='relu', padding='same')(pool3)
    conv4 = BatchNormalization()(conv4)
    conv4 = Conv1D(x*8, 3, activation='relu', padding='same')(conv4)
    conv4 = BatchNormalization()(conv4)
    pool4 = MaxPooling1D(pool_size=2)(conv4)

    # bottleneck
    conv5 = Conv1D(x*16, 3, activation='relu', padding='same')(pool4)
    conv5 = BatchNormalization()(conv5)
    conv5 = Conv1D(x*16, 3, activation='relu', padding='same')(conv5)
    conv5 = BatchNormalization()(conv5)
    level4 = Conv1D(1, 1, name="level4")(conv5)  # deep-supervision head (unused below)

    # decoder
    up6 = concatenate([UpSampling1D(size=2)(conv5), conv4], axis=2)
    conv6 = Conv1D(x*8, 3, activation='relu', padding='same')(up6)
    conv6 = BatchNormalization()(conv6)
    conv6 = Conv1D(x*8, 3, activation='relu', padding='same')(conv6)
    conv6 = BatchNormalization()(conv6)
    level3 = Conv1D(1, 1, name="level3")(conv6)  # deep-supervision head (unused below)

    up7 = concatenate([UpSampling1D(size=2)(conv6), conv3], axis=2)
    conv7 = Conv1D(x*4, 3, activation='relu', padding='same')(up7)
    conv7 = BatchNormalization()(conv7)
    conv7 = Conv1D(x*4, 3, activation='relu', padding='same')(conv7)
    conv7 = BatchNormalization()(conv7)
    level2 = Conv1D(1, 1, name="level2")(conv7)  # deep-supervision head (unused below)

    up8 = concatenate([UpSampling1D(size=2)(conv7), conv2], axis=2)
    conv8 = Conv1D(x*2, 3, activation='relu', padding='same')(up8)
    conv8 = BatchNormalization()(conv8)
    conv8 = Conv1D(x*2, 3, activation='relu', padding='same')(conv8)
    conv8 = BatchNormalization()(conv8)
    level1 = Conv1D(1, 1, name="level1")(conv8)  # deep-supervision head (unused below)

    up9 = concatenate([UpSampling1D(size=2)(conv8), conv1], axis=2)
    conv9 = Conv1D(x, 3, activation='relu', padding='same')(up9)
    conv9 = BatchNormalization()(conv9)
    conv9 = Conv1D(x, 3, activation='relu', padding='same')(conv9)
    conv9 = BatchNormalization()(conv9)

    # Dense(1) acts on the last axis only, so the time axis is preserved:
    # conv9 is (batch, length, 64) and out is (batch, length, 1)
    out = Dense(1, name="out")(conv9)

    model = Model(inputs=[inputs], outputs=[out])
    return model
model_dict = {}  # registry of model builders
model_dict['UNetDS64'] = UNetDS64
#model_dict['MultiResUNet1D'] = MultiResUNet1D

mdlName1 = 'UNetDS64'        # approximation network
mdlName2 = 'MultiResUNet1D'  # refinement network

length = 1024  # length of the signal

os.makedirs('models', exist_ok=True)   # directory to save models
os.makedirs('History', exist_ok=True)  # directory to save training history
# single fold shown here (the original script ran 10-fold cross validation)
for foldname in range(1):
    print('----------------')
    print('Training Fold {}'.format(foldname+1))
    print('----------------')

    # loading training data
    #Y_train = prepareLabel(Y_train)  # prepare labels for training deep supervision
    #Y_val = prepareLabel(Y_val)      # prepare labels for training deep supervision

    mdl1 = model_dict[mdlName1](length)  # create approximation network

    # loss = MAE; no deep-supervision loss weights, since only 'out' is an output here
    mdl1.compile(loss='mean_absolute_error', optimizer='adam',
                 metrics=['mean_squared_error'])

    checkpoint1_ = ModelCheckpoint(
        os.path.join('models', '{}_model1_fold{}.h5'.format(mdlName1, foldname)),
        verbose=1, monitor='val_loss', save_best_only=True, mode='auto')

    # train approximation network for 10 epochs
    history1 = mdl1.fit(norm_x_train, norm_y_train, epochs=10, batch_size=32,
                        validation_data=(norm_x_test, norm_y_test),
                        callbacks=[checkpoint1_], verbose=1)
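For reference, inspecting the compiled model shows where the extra dimension comes from (a quick check using the standard Keras API; None is the batch dimension):

mdl1.summary()            # the last layer 'out' reports output shape (None, 1024, 1)
print(mdl1.output_shape)  # (None, 1024, 1): 1024 predictions per sample, not one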
Solution
My guess is that the 500 in X_train and Y_train is the batch dimension, which Keras handles one sample at a time, leaving the 1024 that comes from X_train: the final Dense(1) layer only acts on the last (channel) axis, so the time axis of length 1024 survives and each sample yields 1024 predictions instead of one.
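One way to get a single scalar per sample is to collapse the time axis before the final Dense layer. Below is a minimal sketch, assuming the encoder/decoder of UNetDS64 is kept unchanged up to conv9; GlobalAveragePooling1D is one option, and Flatten followed by Dense(1) would work as well:

from tensorflow.keras.layers import GlobalAveragePooling1D

# ...encoder/decoder identical to UNetDS64, up to conv9...
pooled = GlobalAveragePooling1D()(conv9)  # (batch, 64): averages out the 1024 time steps
out = Dense(1, name="out")(pooled)        # (batch, 1): one value per sample
model = Model(inputs=[inputs], outputs=[out])

With this change, model.predict(norm_x_test) returns an array of shape (100, 1) instead of (100, 1024, 1).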