python - 如何定义在神经网络各层之间共享的单个参数?
问题描述
假设我们有以下全连接网络:
class FC(nn.Module):
def __init__(self, imgsz=28, num_classes=10):
super(FC, self).__init__()
self.imgsz = imgsz
self.fc1 = nn.Linear(imgsz*imgsz, 300)
self.fc2 = nn.Linear(300, 100)
self.fc3 = nn.Linear(100, num_classes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = x.view(-1, self.imgsz*self.imgsz)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
nn.Parameter()
我们如何使用或任何其他在所有层之间共享的方法来定义单个参数?
解决方案
如果我正确理解您的问题,这应该可以解决您的问题。只需在初始化时添加参数。
class FC(nn.Module):
def __init__(self, imgsz=28, num_classes=10, hidden_layer_size_1= 300, hidden_layer_size_2= 100):
super(FC, self).__init__()
self.imgsz = imgsz
self.fc1 = nn.Linear(imgsz*imgsz, hidden_layer_size_1)
self.fc2 = nn.Linear(hidden_layer_size_1, hidden_layer_size_2)
self.fc3 = nn.Linear(hidden_layer_size_2, num_classes)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = x.view(-1, self.imgsz*self.imgsz)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
推荐阅读
- c - 为什么会出现浮点异常 8?
- python - Python tkinter:使用一个按钮选择一个txt文件,另一个按钮打开并读取内容到字符串
- android - 对于 Exoplayer 的 AdaptiveTrackSelection,我应该切换到具有多个比特率的单个轨道而不是具有单独比特率的四个轨道吗?
- php - Collection.php 第 1477 行中的 ErrorException:未定义的偏移量:0
- apache - Apache 反向代理后端身份验证
- python - 类型错误:>> 不支持的操作数类型:“builtin_function_or_method”和“_io.TextIOWrapper”。
- javascript - node - 将 jest 与 esm 包一起使用
- assembly - 使用立即数时汇编代码中的分段错误
- javascript - 如何通过php中的函数在浏览器中打印插入查询
- python - 按索引(列)编号选择熊猫数据框中的列