python - 枚举数据加载器时无法访问所有数据
问题描述
我定义了一个 customDataset
和一个 custom Dataloader
,我想使用for i,batch in enumerate(loader)
. 但是这个 for 循环在每个时期给了我不同数量的批次,而且它们都远小于实际的批次数量(等于number_of_samples/batch_size
)。
这是我定义数据集和数据加载器的方式:
class UsptoDataset(Dataset):
def __init__(self, csv_file):
df = pd.read_csv(csv_file)
self.rea_trees = df['reactants_trees'].to_numpy()
self.syn_trees = df['synthons_trees'].to_numpy()
self.syn_smiles = df['synthons'].to_numpy()
self.product_smiles = df['product'].to_numpy()
def __len__(self):
return len(self.rea_trees)
def __getitem__(self, item):
rea_tree = self.rea_trees[item]
syn_tree = self.syn_trees[item]
syn_smile = self.syn_smiles[item]
pro_smile = self.product_smiles[item]
# omit the snippet used to process the data here, which gives us the variables used in the return statement.
return {'input_words': input_words,
'input_chars': input_chars,
'syn_tree_indices': syn_tree_indices,
'syn_rule_nl_left': syn_rule_nl_left,
'syn_rule_nl_right': syn_rule_nl_right,
'rea_tree_indices': rea_tree_indices,
'rea_rule_nl_left': rea_rule_nl_left,
'rea_rule_nl_right': rea_rule_nl_right,
'class_mask': class_mask,
'query_paths': query_paths,
'labels': labels,
'parent_matrix': parent_matrix,
'syn_parent_matrix': syn_parent_matrix,
'path_lens': path_lens,
'syn_path_lens': syn_path_lens}
@staticmethod
def collate_fn(batch):
input_words = torch.tensor(np.stack([_['input_words'] for _ in batch], axis=0), dtype=torch.long)
input_chars = torch.tensor(np.stack([_['input_chars'] for _ in batch], axis=0), dtype=torch.long)
syn_tree_indices = torch.tensor(np.stack([_['syn_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
syn_rule_nl_left = torch.tensor(np.stack([_['syn_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
syn_rule_nl_right = torch.tensor(np.stack([_['syn_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
rea_tree_indices = torch.tensor(np.stack([_['rea_tree_indices'] for _ in batch], axis=0), dtype=torch.long)
rea_rule_nl_left = torch.tensor(np.stack([_['rea_rule_nl_left'] for _ in batch], axis=0), dtype=torch.long)
rea_rule_nl_right = torch.tensor(np.stack([_['rea_rule_nl_right'] for _ in batch], axis=0), dtype=torch.long)
class_mask = torch.tensor(np.stack([_['class_mask'] for _ in batch], axis=0), dtype=torch.float32)
query_paths = torch.tensor(np.stack([_['query_paths'] for _ in batch], axis=0), dtype=torch.long)
labels = torch.tensor(np.stack([_['labels'] for _ in batch], axis=0), dtype=torch.long)
parent_matrix = torch.tensor(np.stack([_['parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
syn_parent_matrix = torch.tensor(np.stack([_['syn_parent_matrix'] for _ in batch], axis=0), dtype=torch.float)
path_lens = torch.tensor(np.stack([_['path_lens'] for _ in batch], axis=0), dtype=torch.long)
syn_path_lens = torch.tensor(np.stack([_['syn_path_lens'] for _ in batch], axis=0), dtype=torch.long)
return_dict = {'input_words': input_words,
'input_chars': input_chars,
'syn_tree_indices': syn_tree_indices,
'syn_rule_nl_left': syn_rule_nl_left,
'syn_rule_nl_right': syn_rule_nl_right,
'rea_tree_indices': rea_tree_indices,
'rea_rule_nl_left': rea_rule_nl_left,
'rea_rule_nl_right': rea_rule_nl_right,
'class_mask': class_mask,
'query_paths': query_paths,
'labels': labels,
'parent_matrix': parent_matrix,
'syn_parent_matrix': syn_parent_matrix,
'path_lens': path_lens,
'syn_path_lens': syn_path_lens}
return return_dict
train_dataset=UsptoDataset("train_trees.csv")
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=UsptoDataset.collate_fn)
当我按如下方式使用数据加载器时,它会在每个时期为我提供不同数量的批次:
epoch_steps = len(train_loader)
for e in range(epochs):
for j, batch_data in enumerate(train_loader):
step = e * epoch_steps + j
日志显示第一个epoch只有5个batch,第二个epoch有3个batch,第三个epoch有5个batch,以此类推。
1 Config:
2 Namespace(batch_size_per_gpu=4, epochs=400, eval_every_epoch=1, hidden_size=128, keep=10, log_every_step=1, lr=0.001, new_model=False, save_dir='saved_model/', workers=1)
3 2021-01-06 15:33:17,909 - __main__ - WARNING - Checkpoints not found in dir saved_model/, creating a new model.
4 2021-01-06 15:33:18,340 - __main__ - INFO - Step: 0, Loss: 5.4213, Rule acc: 0.1388
5 2021-01-06 15:33:18,686 - __main__ - INFO - Step: 1, Loss: 4.884, Rule acc: 0.542
6 2021-01-06 15:33:18,941 - __main__ - INFO - Step: 2, Loss: 4.6205, Rule acc: 0.6122
7 2021-01-06 15:33:19,174 - __main__ - INFO - Step: 3, Loss: 4.4442, Rule acc: 0.61
8 2021-01-06 15:33:19,424 - __main__ - INFO - Step: 4, Loss: 4.3033, Rule acc: 0.6211
9 2021-01-06 15:33:20,684 - __main__ - INFO - Dev Loss: 3.5034, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5970844200679234, in epoch 0
10 2021-01-06 15:33:22,203 - __main__ - INFO - Test Loss: 3.4878, Test Sample Acc: 0.0, Test Rule Acc: 0.6470248053471247
11 2021-01-06 15:33:22,394 - __main__ - INFO - Found better dev sample accuracy 0.0 in epoch 0
12 2021-01-06 15:33:22,803 - __main__ - INFO - Step: 10002, Loss: 3.6232, Rule acc: 0.6555
13 2021-01-06 15:33:23,046 - __main__ - INFO - Step: 10003, Loss: 3.53, Rule acc: 0.6442
14 2021-01-06 15:33:23,286 - __main__ - INFO - Step: 10004, Loss: 3.4907, Rule acc: 0.6498
15 2021-01-06 15:33:24,617 - __main__ - INFO - Dev Loss: 3.3081, Dev Sample Acc: 0.0, Dev Rule Acc: 0.5980878387178693, in epoch 1
16 2021-01-06 15:33:26,215 - __main__ - INFO - Test Loss: 3.2859, Test Sample Acc: 0.0, Test Rule Acc: 0.6466992994149526
17 2021-01-06 15:33:26,857 - __main__ - INFO - Step: 20004, Loss: 3.3965, Rule acc: 0.6493
18 2021-01-06 15:33:27,093 - __main__ - INFO - Step: 20005, Loss: 3.3797, Rule acc: 0.6314
19 2021-01-06 15:33:27,353 - __main__ - INFO - Step: 20006, Loss: 3.3959, Rule acc: 0.5727
20 2021-01-06 15:33:27,609 - __main__ - INFO - Step: 20007, Loss: 3.3632, Rule acc: 0.6279
21 2021-01-06 15:33:27,837 - __main__ - INFO - Step: 20008, Loss: 3.3331, Rule acc: 0.6158
22 2021-01-06 15:33:29,122 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 2
23 2021-01-06 15:33:30,689 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
24 2021-01-06 15:33:32,143 - __main__ - INFO - Dev Loss: 3.0911, Dev Sample Acc: 0.0, Dev Rule Acc: 0.6016287207603455, in epoch 3
25 2021-01-06 15:33:33,765 - __main__ - INFO - Test Loss: 3.0651, Test Sample Acc: 0.0, Test Rule Acc: 0.6531393428643545
26 2021-01-06 15:33:34,359 - __main__ - INFO - Step: 40008, Loss: 3.108, Rule acc: 0.6816
27 2021-01-06 15:33:34,604 - __main__ - INFO - Step: 40009, Loss: 3.0756, Rule acc: 0.6732
28 2021-01-06 15:33:35,823 - __main__ - INFO - Dev Loss: 3.0419, Dev Sample Acc: 0.0, Dev Rule Acc: 0.613776079245976, in epoch 4
len(train_loader.dataset)
仅供参考,的值batch_size
和len(train_loader)
分别是和40008
,这正是我所期望的。所以它是如此令人困惑,以至于 using只给了我几个批次,例如or (是预期的)。4
10002
enumerate
3
5
10002
解决方案
我不确定您的代码有什么问题。据我所知,您要做的collate_fn
是从批处理中收集和堆叠相同特征类型的数据。就像是:
您正在使用input_words
, input_chars
, syn_tree_indices
, syn_rule_nl_left
, syn_rule_nl_left
, syn_rule_nl_right
, rea_tree_indices
, rea_tree_indices
, rea_rule_nl_left
, rea_rule_nl_right
, class_mask
, query_paths
, labels
, parent_matrix
, syn_parent_matrix
, path_lens
, 和syn_path_lens
作为键。a
在我的示例中,我们将仅使用、b
、c
和来保持简单d
。
__getitem__
将从您的数据集中返回一个数据点。在您的情况下,它将是一个字典:{'a': ..., 'b': ..., 'c': ..., 'd': ...}
。collate_fn
:是返回数据时数据集和数据加载器之间的中间层。它需要一个批处理元素列表(已用 逐个收集的元素__getitem__
)。您要在此处返回的是经过整理的批次。将转换[{'a': ..., 'b': ..., 'c': ..., 'd': ...}, ...]
为{'a': [...], 'b': [...], 'c': [...], 'd': [...]}
. 其中 key'a'
将包含该a
功能的所有数据...
现在您可能不知道这种简单的整理类型,您实际上并不需要collate_fn
. 我相信元组和字典是由 PyTorch 数据加载器自动处理的。这意味着如果您从中返回字典__getitem__
,您的数据加载器将通过键自动整理。
在这里,仍然是我们的最小示例:
class D(Dataset):
def __init__(self):
super(D, self).__init__()
self.a = [1,11,111,1111,11111]
self.b = [2,22,222,2222,22222]
self.c = [3,33,333,3333,33333]
self.d = [4,44,444,4444,44444]
def __getitem__(self, i):
return {
'a': self.a[i],
'b': self.b[i],
'c': self.c[i],
'd': self.d[i]
}
def __len__(self):
return len(self.a)
正如您在下面的print中看到的那样,数据是通过键收集的。
>>> ds = D()
>>> dl = DataLoader(ds, batch_size=2, shuffle=True)
>>> for i, x in enumerate(dl):
>>> print(i, x)
0 {'a': tensor([11, 1111]), 'b': tensor([22, 2222]), 'c': tensor([33, 3333]), 'd': tensor([44, 4444])}
1 {'a': tensor([1, 11111]), 'b': tensor([2, 22222]), 'c': tensor([3, 33333]), 'd': tensor([4, 44444])}
2 {'a': tensor([111]), 'b': tensor([222]), 'c': tensor([333]), 'd': tensor([444])}
提供collate_fn
参数将删除此自动整理。
推荐阅读
- python - 如何在字符串中对管道进行单转义:意外的 str.replace 行为
- kotlin - 如何在 rxjava 中对从 Network api 返回的数据进行分组
- c# - 在不使用 Scaffold DbContext 的情况下将表添加到实体
- python - 如何通过套接字从树莓派接收视频流到我的桌面?
- sql - 组合两个返回不同行数的 SQL 查询
- java - org.springframework.util.InvalidMimeTypeException:无效的 mime 类型“application:json;charset=utf8”:不包含 '/'
- node.js - 有没有更好的方法将类参数传递给外部函数?
- symfony - 我想在 add 方法中升级表属性,当我在表中添加某些内容时,我想减少另一个表中的属性
- c++ - LinkedList RemovePosition 删除错误的元素
- oracle - 将转储文件导入到 sql developer 时出现问题