首页 > 解决方案 > 生成器函数总是两次产生相同的东西

问题描述

我有这个生成器函数,我希望用它来筛选图像数据集。这些图像位于批处理大小为 16 的 PyTorch 数据加载器中。我循环遍历数据加载器以抓取批处理(16 个图像),然后循环遍历批处理以抓取图像。

我想要做的是将图像标签存储在图像yield中,同时在图中绘制 16 个图像。所以我想做images = next(show_batch(dataloader, labels_dataframe, nrows, ncols)),每次我得到 16 个存储在图像中的图像标签和 16 个图像的图。通过这种方式,我可以识别不良图像并准备好从我的数据集中丢弃它们的标签。该代码不断生成相同的(第一个)16 个图像两次。我怀疑这与每次都创建一个新列表有关,所以我要重新启动发电机?

为什么代码会连续两次生成相同的 16 张图像,如何修复它以一次生成 16 张图像,同时存储从labels_dataframe图像中获取的标签?

def show_batch(dataloader, labels_dataframe, nrows, ncols):
    fig = plt.figure(figsize=(30,15))
    for i, batch in enumerate(train_dl):
        images = []
        for j, image in enumerate(batch['image']):
            ax = fig.add_subplot(nrows, ncols, j+1)
            ax.imshow(image.permute(2, 1, 0))
            images.append(labels_dataframe.loc[i*16+j, 'id_code'])
        yield images

标签: pythongeneratoryield

解决方案


Marco Bonelli 指出我应该展示我是如何调用生成器的。当我这样做时,我发现我做错了什么,并修复了我如何调用它和函数。

我在打电话next(show_batch(dataloader, labels_dataframe, nrows, ncols)),所以每次都在调用一个新的生成器函数。我没有制作生成器对象。

然后当我创建了一个生成器对象并开始调用它时,它只显示了前 16 张图像,然后只产生了之后的标签,所以我将图形对象移动到每个批处理循环内。修改后的代码以及我如何称呼它:

def show_batch(dataloader, labels_dataframe, nrows, ncols):
    for i, batch in enumerate(train_dl):
        fig = plt.figure(figsize=(30,15))
        images = []
        for j, image in enumerate(batch['image']):
            ax = fig.add_subplot(nrows, ncols, j+1)
            ax.imshow(image.permute(2, 1, 0))
            images.append(labels_dataframe.loc[i*16+j, 'id_code'])
        yield images

sample_images = show_batch(dataloader, labels_dataframe, nrows, ncols)
next(sample_images)

谢谢,马可。


推荐阅读