pytorch - 关于 Train() 函数的设计
问题描述
我曾经看过以下神经网络的实现。我对model.train()
in 的功能感到困惑Train()
。在 CNN_ForecastNet 类中,我没有找到 train 的方法,
class CNN_ForecastNet(nn.Module):
def __init__(self):
super(CNN_ForecastNet,self).__init__()
self.conv1d = nn.Conv1d(3,64,kernel_size=1)
self.relu = nn.ReLU(inplace=True)
self.fc1 = nn.Linear(64*2,50)
self.fc2 = nn.Linear(50,1)
def forward(self,x):
x = self.conv1d(x)
x = self.relu(x)
x = x.view(-1)
#print('x size',x.size())
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = CNN_ForecastNet().to(device)
def Train():
running_loss = .0
model.train()
for idx, (inputs,labels) in enumerate(train_loader):
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
preds = model(inputs.float())
loss = criterion(preds,labels.float())
loss.backward()
optimizer.step()
running_loss += loss
train_loss = running_loss/len(train_loader)
train_losses.append(train_loss.detach().numpy())
print(f'train_loss {train_loss}')
解决方案
正如您在文档中所见,模块 train函数只是将模型中的标志设置为True
(您可以使用model.eval()
将标志设置为False
)。
这个标志被一些在 eval 模式下行为改变的层使用,最显着的是 dropout 和 batchnorm 层。
推荐阅读
- python - 如何在 tkinter 主循环旁边运行“While Loop” - Python
- python - 将python中生成的输出导出到excel文件
- ssms - SSMS 未将模板用于新存储过程
- c# - 修复不稳定的 nunit 测试
- spring-integration - 用于异步流和任务执行器的 Spring Integration MDC
- mysql - SUBSTRING_INDEX 用符号和字符对数字进行排序
- angular - 为什么 jasmin-karma-istanbul 报告中没有包含这个 if 语句
- java - MySQL 连接器 netbeans
- windows - 我无法在没有 Internet 的情况下运行 Windows 应用程序
- python - 如何在python中获得表之间的相关性?