首页 > 解决方案 > 这个张量流代码的 Pytorch sub 是什么?

问题描述

在将这行代码从 Tensor Flow 转换为 Pytorch 时,我遇到了麻烦

datagen =  ImageDataGenerator(
  shear_range=0.2,
  zoom_range=0.2,
)
def read_img(filename, size, path):
    img = image.load_img(os.path.join(path, filename), target_size=size)
    #convert image to array
    img = img_to_array(img) / 255
    return img

进而

corona_df = final_train_data[final_train_data['Label_2_Virus_category'] == 'COVID-19']
with_corona_augmented = []

#create a function for augmentation
def augment(name):
    img = read_img(name, (255,255), train_img_dir)
    i = 0
    for batch in tqdm(datagen.flow(tf.expand_dims(img, 0), batch_size=32)):
        with_corona_augmented.append(tf.squeeze(batch).numpy())
        if i == 20:
            break
        i =i+1

#apply the function
corona_df['X_ray_image_name'].apply(augment)

我试着做

transform = transforms.Compose([transforms.Resize(255*255)
                               ])
train_loader = torch.utils.data.DataLoader(os.path.join(train_dir,corona_df),transform = transform,batch_size =32)
def read_img(path):
    img  = train_loader()
    img = np.asarray(img,dtype='int32')
    img = img/255
    return img

我尝试继续,但被错误弄糊涂了。我欢迎任何反馈。告诉我如果我错过了什么,即使是一个小小的建议也会起作用,谢谢!

标签: tensorflowpytorch

解决方案


您可以创建自定义数据集来读取图像。如果您有一个充满图像的目录,则可以使用 ImageFolder 默认数据集。否则,如果您有不同的文件夹位置,您可以编写自己的自定义数据集类。您可以查看此链接以获取自定义数据集。dataloader 所做的是,它会自动从您的数据集中获取数据,并根据您的数据集__getitem__功能读取图像并应用转换。所以你不需要任何花哨的东西来应用增强。

transform = transforms.Compose([ transforms.RandomAffine(20,shear=20,scale=(-0.2,0.2)),
                                 transforms.Resize(255*255)
                               ])

dataset = torchvision.datasets.ImageFolder(train_img_dir, transform=transform)
loader = torch.utils.data.DataLoader(dataset,batch_size =32,shuffle=True)

for batch in loader:
    output = model(batch)

推荐阅读