python - 生成器函数总是两次产生相同的东西
问题描述
我有这个生成器函数,我希望用它来筛选图像数据集。这些图像位于批处理大小为 16 的 PyTorch 数据加载器中。我循环遍历数据加载器以抓取批处理(16 个图像),然后循环遍历批处理以抓取图像。
我想要做的是将图像标签存储在图像yield
中,同时在图中绘制 16 个图像。所以我想做images = next(show_batch(dataloader, labels_dataframe, nrows, ncols))
,每次我得到 16 个存储在图像中的图像标签和 16 个图像的图。通过这种方式,我可以识别不良图像并准备好从我的数据集中丢弃它们的标签。该代码不断生成相同的(第一个)16 个图像两次。我怀疑这与每次都创建一个新列表有关,所以我要重新启动发电机?
为什么代码会连续两次生成相同的 16 张图像,如何修复它以一次生成 16 张图像,同时存储从labels_dataframe
图像中获取的标签?
def show_batch(dataloader, labels_dataframe, nrows, ncols):
fig = plt.figure(figsize=(30,15))
for i, batch in enumerate(train_dl):
images = []
for j, image in enumerate(batch['image']):
ax = fig.add_subplot(nrows, ncols, j+1)
ax.imshow(image.permute(2, 1, 0))
images.append(labels_dataframe.loc[i*16+j, 'id_code'])
yield images
解决方案
Marco Bonelli 指出我应该展示我是如何调用生成器的。当我这样做时,我发现我做错了什么,并修复了我如何调用它和函数。
我在打电话next(show_batch(dataloader, labels_dataframe, nrows, ncols))
,所以每次都在调用一个新的生成器函数。我没有制作生成器对象。
然后当我创建了一个生成器对象并开始调用它时,它只显示了前 16 张图像,然后只产生了之后的标签,所以我将图形对象移动到每个批处理循环内。修改后的代码以及我如何称呼它:
def show_batch(dataloader, labels_dataframe, nrows, ncols):
for i, batch in enumerate(train_dl):
fig = plt.figure(figsize=(30,15))
images = []
for j, image in enumerate(batch['image']):
ax = fig.add_subplot(nrows, ncols, j+1)
ax.imshow(image.permute(2, 1, 0))
images.append(labels_dataframe.loc[i*16+j, 'id_code'])
yield images
sample_images = show_batch(dataloader, labels_dataframe, nrows, ncols)
next(sample_images)
谢谢,马可。
推荐阅读
- c++ - std::make_shared 接口?
- postgresql - PostgreSql 过程结构
- r - 使用 gtsummary 中因变量的值命名模型
- python -
对象返回类型错误:+ 的不支持的操作数类型:“int”和“str” - python - VSCode 使用可选参数自动完成
- python - *.grid() 布局管理器无法按我预期的那样工作
- azure - 实施基于 Power BI 和 Azure Synapse 的现代端到端报告系统
- python - 带条件的 Dajngo 模型外场
- protocols - 如何通过 ms-rd 协议启动服务器
- excel - 将系列动态添加到图表