tensorflow - memmap 数组到 pytorch 和梯度累积
问题描述
处理后我有一个大型数据集(> 62 GiB),保存为两个 NumPy.memmap 数组,其中一个数据和另一个用于标签,数据集具有这些形状 (7390,60,224,224,3) 和 (7390) 并且没有洗牌所以我需要先洗牌。
现在我使用 tensorflow2 并将这段代码与我的生成器一起使用来管理 memmap 数组
def my_generator():
for i in range(len(numpy_array)):
yield numpy_array[i,:,:,:,:],np.array(labels[i]).reshape(1)
full_dataset = tf.data.Dataset.from_generator(
generator=my_generator,
output_types=(np.uint8,np.int32),
output_shapes=((60,224,224,3),(1))
)
full_dataset = full_dataset.shuffle(SHUFFLE_BUFFER_SIZE, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
val_dataset = test_dataset.skip(test_size)
test_dataset = test_dataset.take(test_size)
这样我就可以在不通过洗牌和批处理将整个数据集加载到内存的情况下进行训练。
现在使用这个当前的模型和数据集,vram 不足以加载超过 2 个批次作为张量。而且我无法使用 2 的批量进行训练。
我想到了梯度累积,但我不能用 TF2 来做,我发现用 pytorch 很容易,但我找不到如何处理 memmap 数组,就像在 tensorflow 中使用生成器一样。
所以我需要知道如何从 pytorch 加载数据集,并在 pytorch 中进行相同的改组和批处理。
或者如果有人在 TF2 上有一个现成的GA代码
解决方案
我将只解决洗牌问题。
不要使用 tf.data.Dataset 进行改组,而是在生成器级别进行。这应该有效:
class Generator(object):
def __init__(self, images, labels, batch_size):
self.images = images
self.labels = labels
self.batch_size = batch_size
self.idxs = np.arange(len(self.images))
self.on_epoch_end()
def on_epoch_end(self):
# Shuffle the indices
np.random.shuffle(self.idxs)
def generator(self):
i = 0
while i < len(self.idxs):
idx = self.idxs[i]
yield (self.images[idx], self.labels[i])
i += 1
self.on_epoch_end()
def batch_generator(self):
it = iter(self.generator)
while True:
vals = [next(it) for i in range(self.batch_size)]
images, labels = zip(*vals)
yield images, labels
然后你可以使用它
gen = Generator(...)
it = iter(gen)
batch = next(it) # Call this every time you want a new batch
我确信 pytorch 已经为这类东西内置了方法
推荐阅读
- c++ - GCC 中 -faligned-new 的值
- python - Python resource_tracker:在清晰的环境中使用多处理“spawn”方法时进程意外死亡
- python - 如何在python中访问嵌套字典。我如何访问 total_cases
- sql-server - 用于开发的命名实例的优缺点
- python - Python使用数组两次导致错误的返回值
- java - 什么是 github 上的“BIN”文件?
- r - 仅在 R 中更改一个会话的默认库路径
- discord - 尝试启动我的不和谐机器人时不断收到错误消息
- google-app-engine - 在 gcloud 上部署 nestjs 应用程序时出错
- pycharm - pycharm pytest 在交互模式下使用 bash 挂起 python subprocess.run()