pytorch - Optuna on a PyTorch CNN
Problem description
import torch.nn as nn


class ConvolutionalNetwork(nn.Module):
    def __init__(self, in_features, trial):
        super().__init__()
        self.in_features = in_features
        self.trial = trial
        # this computes no of features outputted by 2 conv layers
        c1 = int(((self.in_features - 2)) / 64) # this is to account for the loss due to conversion to int type
        c2 = int((c1 - 2) / 64)
        # NOTE: "- 2" matches a stride-1 Conv1d with kernel_size=3, and the factor 16
        # assumes conv2 outputs 16 channels; neither follows the tuned values below
        self.n_conv = int(c2 * 16)
        # self.n_conv = int((( ( (self.in_features - 2)/4 ) - 2 )/4 ) * 16)
        num_filters1 = trial.suggest_int("num_filters1",16,64,step=16)
        num_filters2 = trial.suggest_int("num_filters2",16,64,step=16)
        #num_filters = 16
        kernel_size = trial.suggest_int('kernel_size', 2, 7)
        self.conv1 = nn.Conv1d(1, num_filters1, kernel_size, 1)
        self.conv1_bn = nn.BatchNorm1d(num_filters1)
        self.conv2 = nn.Conv1d(num_filters1, num_filters2, kernel_size, 1)
        self.conv2_bn = nn.BatchNorm1d(num_filters2)
        #Add in trial range for dropout to determine optimal dropout value
        # suggest_float replaces the deprecated suggest_uniform (same uniform range)
        self.dp = nn.Dropout(trial.suggest_float('dropout_rate', 0.0, 1.0))
        self.fc3 = nn.Linear(self.n_conv, 2)
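The hard-coded `n_conv` arithmetic above only matches the real flattened size for one combination of hyperparameters, while both `kernel_size` and `num_filters2` are tuned. A minimal sketch of computing the size from the actual values, using the standard `Conv1d`/`MaxPool1d` output-length formula from the PyTorch docs. The helper names and the `pool_size` argument are hypothetical: `pool_size` stands in for whatever pooling the (unshown) `forward` method applies between layers.

```python
def conv1d_out_len(length: int, kernel_size: int, stride: int = 1,
                   padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.Conv1d or nn.MaxPool1d (ceil_mode=False)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(in_features: int, kernel_size: int,
                       num_filters2: int, pool_size: int) -> int:
    """Flattened size after conv1 -> pool -> conv2 -> pool.

    pool_size is a hypothetical MaxPool1d kernel (stride = kernel), since the
    original forward() is not shown.
    """
    length = conv1d_out_len(in_features, kernel_size)             # conv1, stride 1
    length = conv1d_out_len(length, pool_size, stride=pool_size)  # pooling
    length = conv1d_out_len(length, kernel_size)                  # conv2, stride 1
    length = conv1d_out_len(length, pool_size, stride=pool_size)  # pooling
    return length * num_filters2  # channels x remaining length
```

For example, with `in_features=1000`, `kernel_size=3` and `pool_size=4`, changing `num_filters2` from 16 to 64 quadruples the result (976 vs. 3904), so a `fc3` sized for 16 channels no longer fits.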
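I tried adding the number of filters for conv layers 1 and 2 to the Optuna tuning trial, as shown below, but got the following error. Trial 0 completes, but trial 1 fails. I also tuned the batch size, but with different suggested batch sizes the same error pattern appears.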
num_filters1 = trial.suggest_int("num_filters1",16,64,step=16)
num_filters2 = trial.suggest_int("num_filters2",16,64,step=16)
[I 2020-12-07 09:22:53,512] Trial 0 finished with value: 0.6597743630409241 and parameters: {'num_filters1': 16, 'num_filters2': 16, 'kernel_size': 7, 'dropout_rate': 0.5225509182876455, 'optimizer': 'SGD', 'lr': 0.00020958259416674875, 'Weight_decay': 3.8111213220887506e-08}. Best is trial 0 with value: 0.6597743630409241.
[W 2020-12-07 09:22:55,502] Trial 1 failed because of the following error: ValueError('Expected input batch_size (8000) to match target batch_size (2000).')
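The factor of four in the error message can be reproduced with a little arithmetic. This assumes (the `forward` method is not shown) that the conv output is flattened with `x.view(-1, self.n_conv)`, and that trial 1 drew `num_filters2 = 64`; the log does not list trial 1's parameters. Let $B$ be the batch size, $C_2$ the value of `num_filters2`, and $L_2$ the length per channel after the second conv stage, taken to equal `c2`:

$$
\text{elements} = B \cdot C_2 \cdot L_2, \qquad n_{\text{conv}} = 16 \cdot L_2
$$

$$
B' = \frac{B \cdot C_2 \cdot L_2}{n_{\text{conv}}} = B \cdot \frac{C_2}{16} = 2000 \cdot \frac{64}{16} = 8000
$$

With $C_2 = 16$, as in trial 0, $B' = B$, which is consistent with trial 0 running without error.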
Solution
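Removed Optuna auto-tuning of num_filters2, because the number of channels produced by the last convolutional layer has knock-on effects on the layers after it: n_conv, the input size of fc3, is hard-coded as c2 * 16, which assumes conv2 always outputs 16 channels.
Kept num_filters1 under Optuna auto-tuning, and the model now runs without errors.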
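If you want to keep tuning `num_filters2`, one common alternative is to measure the flattened size with a dummy forward pass instead of hard-coding it. A minimal sketch, not the original author's code: the `MaxPool1d(pool_size)` steps, the ReLU activations and the dropout placement are assumptions, because the question's `forward` method is not shown, and the hyperparameters are passed in as plain values rather than via `trial`.

```python
import torch
import torch.nn as nn


class ConvolutionalNetwork(nn.Module):
    """Same conv layers as the question, but fc3's input size is measured
    rather than hard-coded, so num_filters2 and kernel_size can be tuned.
    pool_size and the layer order in _features are assumptions."""

    def __init__(self, in_features: int, num_filters1: int, num_filters2: int,
                 kernel_size: int, dropout_rate: float, pool_size: int = 4):
        super().__init__()
        self.conv1 = nn.Conv1d(1, num_filters1, kernel_size, 1)
        self.conv1_bn = nn.BatchNorm1d(num_filters1)
        self.conv2 = nn.Conv1d(num_filters1, num_filters2, kernel_size, 1)
        self.conv2_bn = nn.BatchNorm1d(num_filters2)
        self.pool = nn.MaxPool1d(pool_size)
        self.dp = nn.Dropout(dropout_rate)
        # Shape-only dummy pass: BatchNorm and ReLU do not change the shape,
        # and skipping BatchNorm keeps its running statistics untouched.
        with torch.no_grad():
            x = torch.zeros(1, 1, in_features)
            x = self.pool(self.conv2(self.pool(self.conv1(x))))
            self.n_conv = x.numel()  # channels x length for one sample
        self.fc3 = nn.Linear(self.n_conv, 2)

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(torch.relu(self.conv1_bn(self.conv1(x))))
        x = self.pool(torch.relu(self.conv2_bn(self.conv2(x))))
        # Flatten per sample so the batch dimension is never changed.
        return torch.flatten(x, start_dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc3(self.dp(self._features(x)))
```

In this version the `trial.suggest_*` calls from the question move into the Optuna objective function and their results are passed to the constructor, which also keeps the model usable outside Optuna. Flattening with `start_dim=1` means a wrong `n_conv` would surface as a shape error in `fc3` rather than as a silently changed batch size. On PyTorch 1.8 or later, `nn.LazyLinear(2)` is another way to avoid computing `n_conv` by hand, since it infers its input size on the first forward call.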
Thanks!