tensorflow - 如何在 Keras 中使用 Convolutional LSTM 实现位置估计?
问题描述
我在 Keras 和 Tensorflow 中使用了 LSTM。
我想实现位置估计。
我想输入电影(1个场景是15帧)并估计移动电影中方块的位置。
输入为 15 帧。输出是 2 个变量 (x, y)。
在下面的代码中,估计精度太差了。我应该怎么办?而且,我不明白AveragePooling3D/Reshape(没有这个它不会执行。)。
# We create a layer which take as input movies of shape
# (n_frames, width, height, channels) and returns a movie
# of identical shape.
seq = Sequential()
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
input_shape=(None, 80, 80, 1),
padding='same', return_sequences=True))
seq.add(BatchNormalization())
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
padding='same', return_sequences=True))
seq.add(BatchNormalization())
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
padding='same', return_sequences=True))
seq.add(BatchNormalization())
seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
padding='same', return_sequences=True))
seq.add(BatchNormalization())
#seq.add(Flatten())
seq.add(AveragePooling3D((1, 80, 80)))
seq.add(Reshape((-1, 40)))
seq.add(Dense(2))
#seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
# activation='sigmoid',
# padding='same', data_format='channels_last'))
seq.compile(loss='mean_squared_error', optimizer='adam')
def generate_movies(n_samples=1200, n_frames=15):
row = 80
col = 80
noisy_movies = np.zeros((n_samples, n_frames, row, col, 1), dtype=np.float)
shifted_movies = np.zeros((n_samples, n_frames, row, col, 1),
dtype=np.float)
square_x_y = np.zeros((n_samples, n_frames, 2), dtype=np.float)
for i in range(n_samples):
for j in range(1):
# Initial position
xstart = np.random.randint(20, 60)
ystart = np.random.randint(20, 60)
# Direction of motion
directionx = np.random.randint(0, 3) - 1
directiony = np.random.randint(0, 3) - 1
# Size of the square
w = np.random.randint(2, 4)
for t in range(n_frames):
x_shift = xstart + directionx * t
y_shift = ystart + directiony * t
noisy_movies[i, t, x_shift - w: x_shift + w,
y_shift - w: y_shift + w, 0] += 1
# Make it more robust by adding noise.
# The idea is that if during inference,
# the value of the pixel is not exactly one,
# we need to train the network to be robust and still
# consider it as a pixel belonging to a square.
if np.random.randint(0, 2):
noise_f = (-1)**np.random.randint(0, 2)
noisy_movies[i, t,
x_shift - w - 1: x_shift + w + 1,
y_shift - w - 1: y_shift + w + 1,
0] += noise_f * 0.1
# Shift the ground truth by 1
x_shift = xstart + directionx * (t + 1)
y_shift = ystart + directiony * (t + 1)
shifted_movies[i, t, x_shift - w: x_shift + w,
y_shift - w: y_shift + w, 0] += 1
square_x_y[i, t, 0] = x_shift/row
square_x_y[i, t, 1] = y_shift/col
# Cut to a 40x40 window
#noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]
#shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]
#noisy_movies[noisy_movies >= 1] = 1
#shifted_movies[shifted_movies >= 1] = 1
return noisy_movies, shifted_movies, square_x_y
# Train the network
noisy_movies, shifted_movies, sq_x_y = generate_movies(n_samples = 1200)
seq.fit(noisy_movies[:1000], sq_x_y[:1000], batch_size=10,
epochs=1, validation_split=0.05)
解决方案
看你seq.summary()
了解形状。
你的最后ConvLSTM2D
一层正在输出形状的东西
(movies, length, side1, side2, channels)
(None, None, 80, 80, 40)
要使 Dense 层起作用,您需要保留“movies”和“length”维度,并将其他维度合并为一个。
请注意,维度40
(通道)很重要,因为它代表不同的“概念”,而80
s(边)是纯粹的 2D 位置。
由于您不想触摸“长度”维度,因此您需要一个AveragePooling2D
(不是 3D)。但是由于您要检测一个非常明显的位置特征及其位置,我建议您根本不要折叠空间维度。最好只是重塑并添加一个Dense
考虑到这些位置的。
因此,而不是AveragePooling
我建议你使用:
seq.add(Reshape((-1, 80*80*40)) #new shape is = (movies, length, 80*80*40)
然后添加一个仍然具有位置概念的 Dense 层:
seq.add(Dense(2)) #output shape is (movies, length, 2)
这不能保证让你的模型表现良好,但我相信“更好”。
您可以做的其他事情是添加更密集的层,并顺利地从 80*80*40 特征变为 2。
这更复杂,但您可能会研究函数式 API 模型并了解“U-net”并尝试在那里制作类似 u-net 的东西。
推荐阅读
- postgresql - 如果它说我没有权限这样做,我该如何在 postgresql 中使用 COPY?
- javascript - 在 ngAfterViewInit 中启动的带有 CanvasJS 图表的 Angular
- python-3.x - 动态创建小部件的滚动条 - Python Tkinter
- lottie - Lottie 本地动画未显示在网页上
- python - Balanced_accuracy 不是 scikit-learn 中的有效评分值
- docusignapi - DocuSign CreateTemplate API 未返回已创建模板的 TemplateId
- sql - Snowflake SQL 编译错误:位置 XX 处的语法错误行 XX 意外 '('
- matlab - 如何在 MATLAB 中以编程方式将命令发送到命令窗口?
- python - Django 请求有时不包含通过 SSL 的会话
- dataframe - 通过填充现有列在 Pyspark Dataframe 中创建新列