python - 如何更改pytorch数据文件夹中的标签?
问题描述
我首先加载一个未标记的数据集,如下所示:
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
现在因为我正在尝试进行半监督学习:我正在尝试定义以下功能。输入“数据集”是我刚刚加载的 unlabeled_set。
由于我想将数据集的标签更改为我预测的标签,而不是原始标签(所有原始标签都是 1),我该怎么办?
我曾尝试使用 dataset.targets 更改标签,但它根本不起作用。以下是我的功能:
import torch
def get_pseudo_labels(dataset, model, threshold=0.07):
# This functions generates pseudo-labels of a dataset using given model.
# It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.
# You are NOT allowed to use any models trained on external data for pseudo-labeling.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = []
y = []
# print(dataset.targets[0])
# Construct a data loader.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# Make sure the model is in eval mode.
model.eval()
# Define softmax function.
softmax = nn.Softmax()
counter = 0
# Iterate over the dataset by batches.
for batch in tqdm(data_loader):
img, _ = batch
# Forward the data
# Using torch.no_grad() accelerates the forward process.
with torch.no_grad():
logits = model(img.to(device))
# Obtain the probability distributions by applying softmax on logits.
probs = softmax(logits)
count = 0
# ---------- TODO ----------
# Filter the data and construct a new dataset.
dataset.targets = torch.tensor(dataset.targets)
for p in probs:
if torch.max(p) >= threshold:
if not(counter in x):
x.append(counter)
dataset.targets[counter] = torch.argmax(p)
counter += 1
# Turn off the eval mode.
model.train()
# dat = DataLoader(ImgDataset(x,y), batch_size=batch_size, shuffle=False)
print(dataset.targets[10])
new = torch.utils.data.Subset(dataset, x)
return new```
解决方案
PyTorch 数据集可以返回值的元组,但它们没有固有的“特征”/“目标”区别。您可以像这样创建修改后的 DataSet:
labeled_data = [*zip(dataset, labels)]
data_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=False)
for imgs, labels in data_loader: # per batch
...
推荐阅读
- javascript - 根据两个属性返回数组中的目标对象,一个最小值/最大值和一个布尔值
- css - 汉堡下拉css显示属性
- flutter - 从 DataTable 创建 pdf
- javascript - 如何在 Vuejs 和 Expressjs 中上传文件
- r - 如何基于现有数据集创建新数据集
- javascript - 无效的正则表达式:“/^[+]?[0-9]{0,1}[-.]?\(?([0-9]{3})\)?[-.]?([0 -9]{3})[-. ]?([0-9]{4})$/gm" 在 javscript
- image - 第一个项目问题图像未显示在 github 上
- python - Python数据框将时间日期'SylmiSeb'(2018-12-31 23:43:02+00:00)转换为日期时间
- webhooks - 是否有适用于 Google 我的商家 (GMB) 的 Webhook?
- python-3.x - 在字典Python3中计算值