首页 > 解决方案 > 如何在不耗尽 RAM 的情况下将 SMOTE 应用于 tensorflow 数据集

问题描述

我一直在研究一个包含近 17k 图像的不平衡数据集,并且我一直在尝试使用不平衡学习库来实现过采样技术,例如 SMOTE。图像和标签作为张量加载,而不平衡学习库中可用的方法需要 numpy 数组。我已经尝试从 tensorflow 数据集中提取图像,但是在大约 10000 张图像之后,我在 google colab 上的会话由于内存不足而崩溃。我也试图寻找不同的方法,但我找不到其他任何东西。这就是为什么我想知道您是否有任何建议可以真正帮助我克服这个问题。

我按照以下步骤操作:

我使用 tf.keras.preprocessing.image_dataset_from_directory 导入数据。

def create_dataset(folder_path, name, split, seed, shuffle=True):
  return tf.keras.preprocessing.image_dataset_from_directory(
    folder_path, labels='inferred', label_mode='categorical', color_mode='rgb',
    batch_size=32, image_size=(320, 320), shuffle=shuffle, interpolation='bilinear',
    validation_split=split, subset=name, seed=seed)

valid_split = 0.3
train_set = create_dataset(dir_path, 'training', valid_split, 42, shuffle=True).prefetch(1)
valid_set = create_dataset(dir_path, 'validation', valid_split, 42, shuffle=True).prefetch(1)

# output:
# Found 16718 files belonging to 38 classes.
# Using 11703 files for training.
# Found 16718 files belonging to 38 classes.
# Using 5015 files for validation.

然后我运行这行代码以将图像作为 numpy 数组从 tf 数据集中取出,但正如我已经说过的,此时我的会话崩溃了。

X_train = np.concatenate([x for x, y in train_set], axis=0)

谢谢您的支持。

标签: numpytensorflowramimbalanced-datasmote

解决方案


推荐阅读