pytorch - 如何在 Pytorch 的神经网络中实现一个时代相关参数?
问题描述
例如,假设我想在一层之后使用 Softmax 和温度,并且我还想在 5 个 epoch 之后降低温度。
for epoch in range(EPOCHS):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(torch_device), data[1].to(torch_device)
optimizer.zero_grad()
outputs = my_model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
my_model.adjust_softmax_temperature(epoch)
# ...
我应该如何实现 adjust_softmax_temperature 方法以使我的代码正常工作?
解决方案
据我所知,模型是在训练之前加载的,并且前向步骤通过已经定义的模型传递数据。我在这里错过了什么吗?
即使在定义模型之后,您也可以改变模型的状态。在这种情况下,我们有兴趣在my_model
训练期间更改 的属性。
以下是有关如何实现此功能的广泛概述。
首先,根据提供的纪元数公开一个函数来更新模型的温度:
class MyModel(nn.Module):
def __init__(self, t0):
self.t = t0
def forward(self, x):
# compute forward pass using self.t
pass
def adjust_softmax_temperature(self, epoch):
if epoch % 5:
self.t /= 10 # update policy of temperature
然后,在定义好模型后,您可以在每个 epoch 结束时调用更新函数:它将根据epoch
是否更新来决定self.t
。
my_model = MyModel(t0=100)
for epoch in range(EPOCHS):
for i, data in enumerate(trainloader, 0):
# train iteration
my_model.adjust_softmax_temperature(epoch)
推荐阅读
- java - 枚举的通用过滤器
- javascript - JS Chrome 插件 - 查找位于当前活动选项卡中的文本
- google-sheets - 在 Google 表格中过滤和比较数据
- javascript - React Router:使用 setState 更新的状态变量未正确传递给 Route
- ganymede - Ganymed SSH2 JAVA,提示:找不到命令
- java - 使用参数时找不到符号错误
- php - Symfony 3.4 和 Doctrine 事件 - 如果抛出异常,是否可以重定向用户?
- java - ROME API 从 RSS 源解析 CDATA 中的图像 URL
- javascript - 将客户端验证代码添加到 .jss 文件中
- sql - 试图计算投诉、原因、更正的数量