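python - "The truth value of an array with more than one element is ambiguous" when applying a blur transform
Problem description
I'm using albumentations to apply transforms for a PyTorch model, but I get the error below and have no clue what causes it. I only know that it happens because of the transform being applied, but I'm not sure what is wrong with it.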
ValueError: Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "<ipython-input-23-119ea6bc360e>", line 24, in __getitem__
    image = self.transform(image)
  File "/opt/conda/lib/python3.6/site-packages/albumentations/core/composition.py", line 164, in __call__
    need_to_run = force_apply or random.random() < self.p
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Here is the code snippet from the dataset's __getitem__() method:
image = cv2.imread(p_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = crop_image_from_gray(image)
image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), 10), -4, 128)
print(image.shape)
image = self.transform(image)
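The cv2.addWeighted line computes saturate(4 * image - 4 * blurred + 128) per pixel: it subtracts the Gaussian-blurred (local average) image, amplifies the difference, and recentres the result at grey 128. A minimal NumPy sketch of that arithmetic, assuming the blurred image has already been computed; the function name add_weighted_sketch is hypothetical, and OpenCV's own rounding may differ in edge cases:

```python
import numpy as np


def add_weighted_sketch(image, blurred, alpha=4.0, beta=-4.0, gamma=128.0):
    """Mimics cv2.addWeighted(image, alpha, blurred, beta, gamma) for uint8 input:
    dst = saturate(alpha * image + beta * blurred + gamma)."""
    # Compute in float so negative intermediates do not wrap around
    dst = alpha * image.astype(np.float64) + beta * blurred.astype(np.float64) + gamma
    # Round and clamp to [0, 255], like a saturating cast back to uint8
    return np.clip(np.rint(dst), 0, 255).astype(np.uint8)


image = np.array([[100, 0, 255]], dtype=np.uint8)
blurred = np.array([[90, 200, 0]], dtype=np.uint8)
print(add_weighted_sketch(image, blurred))  # [[168   0 255]]
```

Note that the output is still a plain uint8 NumPy array, which is exactly what gets handed to self.transform on the next line.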
The transform being applied:
val_transform = albumentations.Compose([
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensor(),
])
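Assuming albumentations' documented default for Normalize (max_pixel_value=255.0), the Normalize step computes (pixel / 255 - mean) / std per channel, and ToTensor then moves the channel axis first (HWC to CHW). A NumPy-only sketch of that arithmetic; the function name normalize_to_chw is hypothetical, and it returns an ndarray rather than a torch tensor:

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def normalize_to_chw(image, mean=MEAN, std=STD, max_pixel_value=255.0):
    """Sketch of Normalize followed by an HWC -> CHW move, for an RGB uint8 image."""
    img = image.astype(np.float32)
    # Per channel: (pixel - mean * max) / (std * max) == (pixel / max - mean) / std
    img = (img - mean * max_pixel_value) / (std * max_pixel_value)
    # Put channels first, the (C, H, W) layout PyTorch models expect
    return np.moveaxis(img, -1, 0)


image = np.full((2, 3, 3), 255, dtype=np.uint8)  # H=2, W=3, 3 channels, all white
out = normalize_to_chw(image)
print(out.shape)  # (3, 2, 3)
```

Because the channel move already happens inside the transform, any further transpose applied to its output would scramble the axes.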
The dataset class is instantiated with:
valset = MyDataset(val_df, transform=val_transform)
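Solution
The problem is the call self.transform(image). An albumentations Compose expects its inputs as keyword arguments (image=...) and returns a dict. When the image is passed positionally, the array is bound to the force_apply parameter, so the check force_apply or random.random() < self.p shown in the traceback asks for the truth value of a NumPy array, which raises the ValueError. Following the official albumentations documentation, call the transform as self.transform(image=image) and read the result from the 'image' key of the returned dict: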
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader
from albumentations import Compose, Normalize
from albumentations.pytorch import ToTensor
class AlbumentationsDataset(Dataset):
    """__init__ and __len__ functions are the same as in TorchvisionDataset"""
    def __init__(self, file_paths, labels, transform=None):
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        file_path = self.file_paths[idx]

        # Read an image with OpenCV
        image = cv2.imread(file_path)

        # By default OpenCV uses BGR color space for color images,
        # so we need to convert the image to RGB color space.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # crop_image_from_gray and IMG_SIZE come from the question's own code
        image = crop_image_from_gray(image)
        image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
        image = cv2.addWeighted(image, 4, cv2.GaussianBlur(image, (0, 0), 10), -4, 128)

        if self.transform:
            # Pass the image as a keyword argument; the result is a dict.
            # ToTensor already returns a CHW tensor, so no np.transpose is needed.
            augmented = self.transform(image=image)
            image = augmented['image']
        return image, label
albumentations_transform = Compose([
    Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensor(),
])
albumentations_dataset = AlbumentationsDataset(
    file_paths=['./images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg'],
    labels=[1, 2, 3],
    transform=albumentations_transform,
)
test_loader = DataLoader(dataset=albumentations_dataset, batch_size=4, drop_last=False, shuffle=False)
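To see why the positional call fails while the keyword call works, here is a toy reproduction. ToyCompose is a hypothetical stand-in, not the real albumentations class; it only mimics the signature implied by the traceback, __call__(force_apply=False, **data), and the same force_apply or random.random() < self.p check:

```python
import random

import numpy as np


class ToyCompose:
    """Hypothetical stand-in for albumentations.Compose (not the real class)."""

    def __init__(self, transforms, p=1.0):
        self.transforms = transforms
        self.p = p

    def __call__(self, force_apply=False, **data):
        # Same check as composition.py line 164 in the traceback
        need_to_run = force_apply or random.random() < self.p
        if need_to_run:
            for t in self.transforms:
                data["image"] = t(data["image"])
        return data


def to_float(img):
    return img.astype(np.float32) / 255.0


pipeline = ToyCompose([to_float])
image = np.zeros((4, 4, 3), dtype=np.uint8)

# Keyword call, as in the fixed __getitem__: returns a dict
out = pipeline(image=image)["image"]
print(out.dtype, out.shape)  # float32 (4, 4, 3)

# Positional call, as in the question: the array is bound to force_apply,
# and `force_apply or ...` needs bool(array), which is ambiguous
raised = False
try:
    pipeline(image)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

The positional call fails before any transform runs, which matches the traceback stopping at the need_to_run line.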