python - 使用 ImageDataGenerator 对 Keras 中的视频(4D 张量)进行数据增强
问题描述
我ImageDataGenerator
在 Keras 中有一个,我想在训练期间将其应用于短视频剪辑中的每一帧,这些短视频剪辑表示为形状为(num_frames、width、height、3)的 4D numpy 数组。
对于由每个具有形状(宽度、高度、3)的图像组成的标准数据集,我通常会执行以下操作:
aug = tf.keras.preprocessing.image.ImageDataGenerator(
rotation_range=15,
zoom_range=0.15)
model.fit_generator(
aug.flow(X_train, y_train),
epochs=100)
如何将这些相同的数据增强应用于表示图像序列的 4D numpy 数组的数据集?
解决方案
我想到了。我创建了一个继承自 tensorflow.keras.utils.Sequence 的自定义类,该类使用 scipy 对每个图像执行增强。
class CustomDataset(tf.keras.utils.Sequence):
def __init__(self, batch_size, *args, **kwargs):
self.batch_size = batch_size
self.X_train = args[0]
self.Y_train = args[1]
def __len__(self):
# returns the number of batches
return int(self.X_train.shape[0] / self.batch_size)
def __getitem__(self, index):
# returns one batch
X = []
y = []
for i in range(self.batch_size):
r = random.randint(0, self.X_train.shape[0] - 1)
next_x = self.X_train[r]
next_y = self.Y_train[r]
augmented_next_x = []
###
### Augmentation parameters for this clip.
###
rotation_amt = random.randint(-45, 45)
for j in range(self.X_train.shape[1]):
transformed_img = ndimage.rotate(next_x[j], rotation_amt, reshape=False)
transformed_img[transformed_img == 0] = 255
augmented_next_x.append(transformed_img)
X.append(augmented_next_x)
y.append(next_y)
X = np.array(X).astype('uint8')
y = np.array(y)
encoder = LabelBinarizer()
y = encoder.fit_transform(y)
return X, y
def on_epoch_end(self):
# option method to run some logic at the end of each epoch: e.g. reshuffling
pass
然后我将它传递给fit_generator
方法:
training_data_augmentation = CustomDataset(BS, X_train_L, y_train_L)
model.fit_generator(
training_data_augmentation,
epochs=300)
推荐阅读
- nginx - 让 nginx 在返回页面前等待 x 秒
- swiftui - 如何在循环中更新状态时显示所有中间视图,swiftUI
- .net - .Net Framework 4.7 调用 .net standard 2.0 dll 但出现以下错误
- python - 为什么我不能从 Python 调用 SSH 终端命令?
- .net - 我可以从保存的 XML 文件构建 OData 模型吗?
- javascript - 所有猫鼬模型方法都基于承诺吗?
- python - MongoDB在python中更新对象数组
- python - 如何运行可执行文件,然后在它停止后继续?
- python - 自动绘图的功能。各种子图和各种数字
- php - 在 MAC OS X 上启用 ZTS 重新编译 PHP