On the design of the Train() function

Problem description

I came across the following neural network implementation. I am confused by what model.train() does inside Train(): I cannot find a train method defined anywhere in the CNN_ForecastNet class.

import torch
import torch.nn as nn

class CNN_ForecastNet(nn.Module):
    def __init__(self):
        super(CNN_ForecastNet,self).__init__()
        self.conv1d = nn.Conv1d(3,64,kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc1 = nn.Linear(64*2,50)
        self.fc2 = nn.Linear(50,1)
        
    def forward(self,x):
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.view(-1)
        #print('x size',x.size())
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        
        return x

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CNN_ForecastNet().to(device)

def Train():
    
    running_loss = .0
    
    model.train()
    
    for idx, (inputs,labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        preds = model(inputs.float())
        loss = criterion(preds,labels.float())
        loss.backward()
        optimizer.step()
        # use .item() so the computation graph is not kept alive across iterations
        running_loss += loss.item()
        
    train_loss = running_loss/len(train_loader)
    train_losses.append(train_loss)
    
    print(f'train_loss {train_loss}')

Tags: pytorch

Solution


As you can see in the documentation, the module's train function simply sets a flag in the model to True (and you can use model.eval() to set that flag back to False).

This flag is consulted by layers whose behavior differs between training and evaluation mode, most notably dropout and batch-norm layers.
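A minimal standalone sketch (not part of the original post) showing what model.train() and model.eval() actually toggle, using a small Sequential model with a dropout layer as an example:

```python
import torch
import torch.nn as nn

# A tiny module containing dropout, to observe the effect of train()/eval().
model = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(4, 2))

model.train()             # sets self.training = True on the module...
print(model.training)     # True
print(model[0].training)  # True: the flag propagates recursively to children

model.eval()              # sets self.training = False everywhere
print(model.training)     # False

# In eval mode, dropout becomes the identity function:
d = nn.Dropout(p=0.5)
d.eval()
x = torch.ones(1, 4)
print(torch.equal(d(x), x))  # True: dropout is a no-op in eval mode
```

Note that neither call changes any weights; train()/eval() only flip the `training` flag that layers like Dropout and BatchNorm consult in their forward pass.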
