tensorflow - Keras:空间金字塔池不起作用,训练中的问题
问题描述
我们从这篇论文https://arxiv.org/pdf/1406.4729v4.pdf开始研究空间金字塔池。我们实现了这段代码,但它不能正常工作(损失没有减少),可能我们在训练方面遗漏了一些东西,你能帮助我们吗?
这是 SPP 层
from keras.engine.topology import Layer
import keras.backend as K
import numpy as np from keras.models import Sequential
from keras.layers import Convolution2D, Activation, MaxPooling2D, Dense
class SpatialPyramidPooling(Layer):
"""Spatial pyramid pooling layer for 2D inputs.
See Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition,
K. He, X. Zhang, S. Ren, J. Sun
# Arguments
pool_list: list of int
List of pooling regions to use. The length of the list is the number of pooling regions,
each int in the list is the number of regions in that pool. For example [1,2,4] would be 3
regions with 1, 2x2 and 4x4 max pools, so 21 outputs per feature map
# Input shape
4D tensor with shape:
`(samples, channels, rows, cols)` if dim_ordering='th'
or 4D tensor with shape:
`(samples, rows, cols, channels)` if dim_ordering='tf'.
# Output shape
2D tensor with shape:
`(samples, channels * sum([i * i for i in pool_list])`
def __init__(self, pool_list, **kwargs):
self.dim_ordering = K.image_data_format()
assert self.dim_ordering in {'channels_last', 'channels_first'}, 'dim_ordering must be in {channels_last, channels_first}'
self.pool_list = pool_list
self.num_outputs_per_channel = sum([i * i for i in pool_list])
super(SpatialPyramidPooling, self).__init__(**kwargs)
def build(self, input_shape):
if self.dim_ordering == 'channels_first':
self.nb_channels = input_shape[1]
elif self.dim_ordering == 'channels_last':
self.nb_channels = input_shape[3]
def compute_output_shape(self, input_shape):
return (input_shape[0], self.nb_channels * self.num_outputs_per_channel)
def get_config(self):
config = {'pool_list': self.pool_list}
base_config = super(SpatialPyramidPooling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, x, mask=None):
input_shape = K.shape(x)
if self.dim_ordering == 'channels_first':
num_rows = input_shape[2]
num_cols = input_shape[3]
elif self.dim_ordering == 'channels_last':
num_rows = input_shape[1]
num_cols = input_shape[2]
row_length = [K.cast(num_rows, 'float32') / i for i in self.pool_list]
col_length = [K.cast(num_cols, 'float32') / i for i in self.pool_list]
outputs = []
if self.dim_ordering == 'channels_first':
for pool_num, num_pool_regions in enumerate(self.pool_list):
for jy in range(num_pool_regions):
for ix in range(num_pool_regions):
x1 = ix * col_length[pool_num]
x2 = ix * col_length[pool_num] + col_length[pool_num]
y1 = jy * row_length[pool_num]
y2 = jy * row_length[pool_num] + row_length[pool_num]
x1 = K.cast(K.round(x1), 'int32')
x2 = K.cast(K.round(x2), 'int32')
y1 = K.cast(K.round(y1), 'int32')
y2 = K.cast(K.round(y2), 'int32')
new_shape = [input_shape[0], input_shape[1],
y2 - y1, x2 - x1]
x_crop = x[:, :, y1:y2, x1:x2]
xm = K.reshape(x_crop, new_shape)
pooled_val = K.max(xm, axis=(2, 3))
outputs.append(pooled_val)
elif self.dim_ordering == 'channels_last':
for pool_num, num_pool_regions in enumerate(self.pool_list):
for jy in range(num_pool_regions):
for ix in range(num_pool_regions):
x1 = ix * col_length[pool_num]
x2 = ix * col_length[pool_num] + col_length[pool_num]
y1 = jy * row_length[pool_num]
y2 = jy * row_length[pool_num] + row_length[pool_num]
x1 = K.cast(K.round(x1), 'int32')
x2 = K.cast(K.round(x2), 'int32')
y1 = K.cast(K.round(y1), 'int32')
y2 = K.cast(K.round(y2), 'int32')
new_shape = [input_shape[0], y2 - y1,
x2 - x1, input_shape[3]]
x_crop = x[:, y1:y2, x1:x2, :]
xm = K.reshape(x_crop, new_shape)
pooled_val = K.max(xm, axis=(1, 2))
outputs.append(pooled_val)
if self.dim_ordering == 'channels_first':
outputs = K.concatenate(outputs)
elif self.dim_ordering == 'channels_last':
#outputs = K.concatenate(outputs,axis = 1)
#outputs = K.concatenate(outputs)
#outputs = K.reshape(outputs,(len(self.pool_list),self.num_outputs_per_channel,input_shape[0],input_shape[1]))
#outputs = K.permute_dimensions(outputs,(3,1,0,2))
outputs = K.reshape(outputs,(input_shape[0], self.num_outputs_per_channel * self.nb_channels))
return outputs
这是模型
batch_size = 32
num_channels = 3
num_classes = 5
model = Sequential()
# uses theano ordering. Note that we leave the image size as None to allow multiple image sizes
model.add(Convolution2D(32, (3, 3), padding='same', input_shape=(None, None,3)))
model.add(Activation('relu'))
model.add(Convolution2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Convolution2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Convolution2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(SpatialPyramidPooling([1, 2, 4]))
print(model.output_shape)
print(model.summary())
model.add(Dense(num_classes))
model.add(Activation('softmax'))
print(model.output_shape)
print(model.summary())
#model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
model.compile(
optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
这是模型的数据上传和拟合
batch_size = 32
img_height = 180
img_width = 180
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url,
fname='flower_photos',
untar=True)
data_dir = pathlib.Path(data_dir)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
from tensorflow.keras import layers
normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
print(train_ds)
#MODEL FIT
model.fit(
train_ds,
validation_data=val_ds,
epochs=10
)
解决方案
推荐阅读
- javascript - KeystoneJS updateItems , keystone.updateItem is not a function TypeError: keystone.updateItem is not a function
- c - 我可以比较函数指针和函数的相等性吗?
- ansible - 需要语法才能将 Ansible 元模块添加到现有 Playbook
- docker - 将框架 Gin 切换到 Echo 后服务器不再响应
- c++ - 在菜单栏中实现 QAction
- fluentd - FluentD,如何仅 grep 特定日志
- php - 如何将外部表列分组到不同的表下?
- r - R中的条件累积和时间序列列
- html - 前景文本与背景具有相同的过滤器
- android - 如何在 android 中强制设置“ImageView”字段?