python - Python: TypeError: '(0, slice(None, None, None))' is an invalid key
问题描述
我正在尝试对已在图像数据集上训练的 EfficientNet 模型进行推理。应该产生输出的for loop
那个validation_fn
给出一个TypeError
. 似乎数据加载器没有加载任何东西,因为循环的输入是None
该数据集包含动物图像以及许多元数据特征。以下是代码的相关部分。
代码
def validate_fn(val_loader, model, params):
model.eval()
stream = tqdm(val_loader)
predictions = []
with torch.no_grad():
for i, (images, dense, target) in enumerate(stream, start=1): # **ERROR OCCURS HERE**
images = images.to(config.params['device'], non_blocking=True)
dense = dense.to(config.params['device'], non_blocking=True)
output = model(images, dense)
outputs = (torch.sigmoid(output).detach().cpu().numpy()*100).tolist()
predictions.extend(outputs)
predictions = np.concatenate(predictions)
gc.collect()
return predictions
def inference_fn(model_paths, dataloader, params):
final_preds = []
for i, path in enumerate(model_paths):
model = PetNet(params['model'], pretrained=False)
model.to(params['device'])
# model.load_state_dict(torch.load(path))
print(f"Getting predictions for model {i+1}")
preds = validate_fn(dataloader, model, params)
final_preds.append(preds)
final_preds = np.array(final_preds)
final_preds = np.mean(final_preds, axis=0)
return final_preds
class PetDataset(Dataset):
def __init__(self, images_filepaths, dense_features, targets, transform=None):
self.images_filepaths = images_filepaths
self.dense_features = dense_features
self.targets = targets
self.transform = transform
def __len__(self):
return len(self.images_filepaths)
def __getitem__(self, idx):
image_filepath = self.images_filepaths[idx]
image = cv2.imread(image_filepath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.transform is not None:
image = self.transform(image=image)['image']
dense = self.dense_features[idx, :]
return image, dense
class PetNet(nn.Module):
def __init__(self, model_name=config.params['model'], out_features=config.params['out_features'], inp_channels=config.params['inp_channels'],
pretrained=config.params['pretrained'], num_dense=len(config.params['dense_features'])):
super().__init__()
self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=inp_channels)
n_features = self.model.classifier.in_features
self.model.classifier = nn.Linear(n_features, 128)
self.dropout = nn.Dropout(0.1)
self.out = nn.Linear(128 + 12, 1)
def forward(self, image, dense):
x = self.model(image)
x = self.dropout(x)
x = torch.cat([x, dense], dim=1)
output = self.out(x)
return output
model_paths = [
"/models/efficientnet-b1/tf_efficientnet_b1_ns_1_epoch_f1_19.407_rmse.pth"
]
df_test = pd.read_csv(config.TEST_FILE_PATH)
df_test['image_path'] = df_test['Id'].apply(lambda x: utils.return_filpath(x, folder = config.TEST_DIRECTORY))
sample = pd.read_csv(config.SAMPLE_FILE_PATH)
test_dataset = PetDataset(
images_filepaths = df_test['image_path'].values,
dense_features = df_test[config.params['dense_features']],
targets = sample["Pawpularity"].values,
transform = utils.get_valid_transforms()
)
test_loader = DataLoader(
test_dataset,
batch_size=config.params['batch_size'],
# batch_size = len(df_test),
shuffle=False,
num_workers=config.params['num_workers'], pin_memory=True
)
preds = inference_fn(model_paths, test_loader, config.params)
错误
TypeError Traceback (most recent call last)
<ipython-input-6-1de534f94594> in <module>()
138 )
139
--> 140 preds = inference_fn(model_paths, test_loader, config.params)
141
142 # sample = pd.read_csv(config.SAMPLE_FILE_PATH)
7 frames
<ipython-input-6-1de534f94594> in inference_fn(model_paths, dataloader, params)
60
61 print(f"Getting predictions for model {i+1}")
---> 62 preds = validate_fn(dataloader, model, params)
63 # final_preds.append(preds)
64
<ipython-input-6-1de534f94594> in validate_fn(val_loader, model, params)
7
8 with torch.no_grad():
----> 9 for i, (images, dense, target) in enumerate(stream, start=1):
10 images = images.to(config.params['device'], non_blocking=True)
11 dense = dense.to(config.params['device'], non_blocking=True)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-6-1de534f94594>", line 89, in __getitem__
dense = self.dense_features[idx, :]
File "/usr/local/lib/python3.7/dist-packages/pandas/core/frame.py", line 2906, in __getitem__
indexer = self.columns.get_loc(key)
File "/usr/local/lib/python3.7/dist-packages/pandas/core/indexes/base.py", line 2898, in get_loc
return self._engine.get_loc(casted_key)
File "pandas/_libs/index.pyx", line 70, in pandas._libs.index.IndexEngine.get_loc
File "pandas/_libs/index.pyx", line 75, in pandas._libs.index.IndexEngine.get_loc
TypeError: '(0, slice(None, None, None))' is an invalid key
解决方案
推荐阅读
- java - java.lang.IllegalStateException:在父级或祖先中找不到方法 onClickForgot(View)
- list - 如何忽略列表到列表映射中的成员?
- python - 在 python 中复制文件时出错(在路径中添加 r 前缀)
- python - 当我将目录更改为 /var/www/html 时,Python 代码没有运行
- c - I'm finding the frequency of each letter, although it always returns 0 for all the letters
- python - 我已经通过 python 安装了一个新库,但我在 pythoncharm 中找不到
- ruby - scrapoxy 中的响应标头中缺少“x-cache-proxyname”
- android - 如何在静态方法android中重用相同的对象
- elasticsearch - 弹性搜索词计数
- ejabberd - How to send message in muc group using ejabberd API?