Model does not learn when trained on a lot of data

Problem description

I'm trying to build an image colorizer with a convolutional neural network. I took a model architecture from the internet that I thought would work, because I had tested it on a single image: given the grayscale version as input, the model memorized that image. The result didn't look exactly like the original, but the grayscale did come out in color, so I figured it was at least working.

Then I trained the model on a larger dataset for around 10 to 30 epochs, and now it outputs black-and-white images at inference time. I'm not sure why this happens. Please help.

from PIL import Image
import glob
import os
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from skimage import io, color

# Neural network model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim



if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("running on the GPU")
else:
    device = torch.device("cpu")
    print("running on the CPU")

label = 'training'   # folder that holds the training images
IMG_SIZE = 800
size = 5             # number of training images to use
REBUILD_DATA = True

class landscapes():

    def make_training_data(self):
        global training_data
        training_data = []

        # load every image in the training folder, resize, convert BGR -> RGB
        for f in tqdm(os.listdir(label)):
            path = os.path.join(label, f)
            img = cv2.imread(path)
            img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            training_data.append(img)

        training_data = np.asarray(training_data)

        # sanity check: show the first two images
        f, axarr = plt.subplots(2)
        axarr[0].imshow(training_data[0])
        axarr[1].imshow(training_data[1])
        
    def rgb2lab(self, training_data):
        global L
        global ab
        L = np.zeros((size, 800, 800))
        ab = np.zeros((size, 800, 800, 2))
        for i in tqdm(range(size)):
            lab = color.rgb2lab(1.0/255 * training_data[i])  # convert once per image
            L[i] = lab[:, :, 0]    # lightness channel (grayscale input)
            ab[i] = lab[:, :, 1:]  # colour channels (targets)
        # scale ab into roughly [-1, 1]; this was inside the loop before,
        # which divided the earlier images by 128 repeatedly
        ab = ab / 128
        # L = L/100

        L = torch.Tensor(L).reshape(size, 1, 800, 800)
        # permute, not reshape: a plain reshape would scramble the channel layout
        ab = torch.Tensor(ab).permute(0, 3, 1, 2)
        return L, ab
    
if REBUILD_DATA:
    td = landscapes()
    td.make_training_data()

# training_data = np.load("training_data.npy", allow_pickle=True)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # encoder: 3x3 convs; the stride-2 layers halve the spatial size
        self.conv1 = nn.Conv2d(1, 32, 3, stride=(1, 1), padding=(1, 1))
        self.conv2 = nn.Conv2d(32, 32, 3, stride=(2, 2), padding=(1, 1))
        self.conv3 = nn.Conv2d(32, 64, 3, stride=(1, 1), padding=(1, 1))
        self.conv3a = nn.Conv2d(64, 64, 3, stride=(2, 2), padding=(1, 1))
        self.conv4 = nn.Conv2d(64, 128, 3, stride=(1, 1), padding=(1, 1))
        self.conv5 = nn.Conv2d(128, 128, 3, stride=(2, 2), padding=(1, 1))
        self.conv6 = nn.Conv2d(128, 256, 3, stride=(1, 1), padding=(1, 1))
        # decoder: convs back down to the 2 ab channels, upsampled in forward()
        self.conv7 = nn.Conv2d(256, 128, 3, stride=(1, 1), padding=(1, 1))
        self.conv8 = nn.Conv2d(128, 64, 3, stride=(1, 1), padding=(1, 1))
        self.conv9 = nn.Conv2d(64, 32, 3, stride=(1, 1), padding=(1, 1))
        self.conv10 = nn.Conv2d(32, 16, 3, stride=(1, 1), padding=(1, 1))
        self.conv11 = nn.Conv2d(16, 2, 3, stride=(1, 1), padding=(1, 1))
        self.up = nn.Upsample(scale_factor=2)
       
       
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv3a(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.relu(self.conv8(x))
        x = self.up(x)
        x = F.relu(self.conv9(x))
        x = self.up(x)
        x = F.relu(self.conv10(x))
        x = torch.tanh(self.conv11(x))  # tanh keeps the ab output in [-1, 1]
        x = self.up(x)
        return x
net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_function = nn.MSELoss()

def fwd_pass(X, y, train=False):  # run one forward pass; backprop only when train=True

    if train:
        net.zero_grad()
    outputs = net(X)
    loss = loss_function(outputs, y)

    if train:
        loss.backward()
        optimizer.step()
    return loss

def train():

    EPOCHS = 10

    for epoch in range(EPOCHS):
        for i in tqdm(range(size)):
            loss = fwd_pass(L[i].view(1, 1, 800, 800).to(device),
                            ab[i].view(1, 2, 800, 800).to(device),
                            train=True)
        print(f"Epoch: {epoch}. Loss: {loss}")  # loss of the last sample only

def test():
    fname = "datasets_298806_1217826_00000000_(2).jpg"

    img = cv2.imread(fname)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (800, 800))
    lab = color.rgb2lab(1.0/255 * img)
    gs = lab[:, :, 0]          # L channel: the model input
    abt = lab[:, :, 1:] / 128  # scaled ab channels: the target
    gs = torch.Tensor(gs).reshape(1, 1, 800, 800).to(device)
    abt = torch.Tensor(abt).permute(2, 0, 1).reshape(1, 2, 800, 800).to(device)
    losst = fwd_pass(gs, abt, train=False)
    print(f"Loss: {losst}")
    out = net(gs)
    out = out * 128  # undo the /128 scaling of the ab channels

    canvas = np.zeros((800, 800, 3))
    canvas[:, :, 0] = gs.reshape(800, 800).cpu().detach().numpy()
    # permute back to (H, W, 2); a plain reshape would scramble the channels
    out = out.permute(0, 2, 3, 1).cpu().detach().numpy().reshape(800, 800, 2)
    canvas[:, :, 1:] = out
    canvas = color.lab2rgb(canvas)

    f, axarr = plt.subplots(2)
    axarr[0].imshow(canvas)
    axarr[1].imshow(img)
    plt.imsave('model.jpg', canvas)
    plt.imsave('orig.jpg', img)

L, ab = td.rgb2lab(training_data)
L = L.to(device)
ab = ab.to(device)
train()
test()

Tags: image-processing, deep-learning, conv-neural-network

Solution


I think the main problem is that you train on one image at a time and take a full optimizer step for each: you should use batching, or accumulate gradients, or reduce the LR.
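
For example, here is a minimal sketch of both options, assuming the L/ab tensors, net, loss_function and device from your code; the batch size, learning rate and ACC_STEPS values are illustrative, not tuned:

from torch.utils.data import TensorDataset, DataLoader

# Mini-batch training over the L (N,1,800,800) and ab (N,2,800,800) tensors.
dataset = TensorDataset(L, ab)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
optimizer = optim.Adam(net.parameters(), lr=1e-4)  # lower than the original 1e-3

for epoch in range(10):
    for X, y in loader:
        optimizer.zero_grad()
        loss = loss_function(net(X.to(device)), y.to(device))
        loss.backward()   # gradient is averaged over the whole batch
        optimizer.step()

# Alternative: gradient accumulation, stepping only every ACC_STEPS samples.
ACC_STEPS = 4
optimizer.zero_grad()
for i, (X, y) in enumerate(loader):
    loss = loss_function(net(X.to(device)), y.to(device)) / ACC_STEPS
    loss.backward()
    if (i + 1) % ACC_STEPS == 0:
        optimizer.step()
        optimizer.zero_grad()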

Also, I suggest plotting your training loss. If the training loss decreases normally, add a validation set and plot the validation loss as well. You may also need to add BatchNorm, use augmentation, reduce the image size, tune the optimizer...
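
For instance, a minimal sketch of recording both curves, assuming the fwd_pass() helper and the L/ab tensors above; with only size = 5 images the 80/20 split is purely illustrative:

# Track per-epoch train/val loss and plot both curves.
split = int(0.8 * size)
train_losses, val_losses = [], []

for epoch in range(10):
    running = 0.0
    for i in range(split):
        running += fwd_pass(L[i].view(1, 1, 800, 800).to(device),
                            ab[i].view(1, 2, 800, 800).to(device),
                            train=True).item()
    train_losses.append(running / split)

    with torch.no_grad():  # held-out images: no backprop
        running = sum(fwd_pass(L[i].view(1, 1, 800, 800).to(device),
                               ab[i].view(1, 2, 800, 800).to(device)).item()
                      for i in range(split, size))
    val_losses.append(running / (size - split))

plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.xlabel("epoch"); plt.ylabel("MSE loss"); plt.legend()
plt.show()

If the train curve falls while the val curve stalls or rises, the model is memorizing rather than learning to colorize.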

You can also find some inspiration in this article, or look for a good implementation of autoencoder colorization.
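
As a rough illustration of that direction (a sketch only, not the implementation from the article; the channel counts and depth are arbitrary), an autoencoder-style colorizer with BatchNorm could look like:

class ColorizerSketch(nn.Module):
    # Minimal encoder-decoder with BatchNorm after each conv.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # 800 -> 400
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 400 -> 200
            nn.BatchNorm2d(64), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2),                # 200 -> 400
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32), nn.ReLU(),
            nn.Upsample(scale_factor=2),                # 400 -> 800
            nn.Conv2d(32, 2, 3, padding=1),
            nn.Tanh(),                                  # ab output in [-1, 1]
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

Its (N, 2, 800, 800) output can be trained against the /128-scaled ab channels with the same MSELoss as in your code.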

