graph - Data type error when training a GNN model
Problem description
I created a dataset containing 500 graphs, as follows:
import numpy as np
import torch
from random import randint
from torch_geometric.data import Data

data_list = []
ngraphs = 500
for i in range(ngraphs):
    num_nodes = randint(10, 500)
    num_edges = randint(10, num_nodes * (num_nodes - 1))
    f1 = np.random.randint(10, size=(num_nodes))
    f2 = np.random.randint(10, 20, size=(num_nodes))
    f3 = np.random.randint(20, 30, size=(num_nodes))
    f_final = np.stack((f1, f2, f3), axis=1)
    capital = 2*f1 + f2 - f3
    f1_t = torch.from_numpy(f1)
    f2_t = torch.from_numpy(f2)
    f3_t = torch.from_numpy(f3)
    capital_t = torch.from_numpy(capital)
    capital_t = capital_t.type(torch.FloatTensor)
    x = torch.from_numpy(f_final)
    x = x.type(torch.FloatTensor)
    #edge_index = torch.randint(low=0, high=num_nodes, size=(2, num_edges), dtype=torch.long)
    edge_index = torch.randint(low=0, high=num_nodes, size=(num_edges, 2), dtype=torch.long)
    edge_attr = torch.randint(low=0, high=50, size=(num_edges, 1), dtype=torch.long)
    data = Data(x=x, edge_index=edge_index.t().contiguous(), y=capital_t, edge_attr=edge_attr)
    data_list.append(data)
I am trying to write the graph neural network model:
import torch
from tqdm import tqdm
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm, ReLU
from torch_scatter import scatter
from torch_geometric.nn import GENConv, DeepGCNLayer
from torch_geometric.data import RandomNodeSampler
class DeeperGCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers):
        super(DeeperGCN, self).__init__()
        self.node_encoder = Linear(data.x.size(-1), hidden_channels)
        self.edge_encoder = Linear(data.edge_attr.size(-1), hidden_channels)
        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',
                           t=1.0, learn_t=True, num_layers=2, norm='layer')
            norm = LayerNorm(hidden_channels, elementwise_affine=True)
            act = ReLU(inplace=True)
            layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.5,
                                 ckpt_grad=i % 3)
            self.layers.append(layer)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, edge_attr):
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        x = self.layers[0].conv(x, edge_index, edge_attr)
        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)
        x = self.layers[0].act(self.layers[0].norm(x))
        return self.lin(x)
I have a basic training function:
def train():
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    out = model(data.x, data.edge_index, data.edge_attr)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss

## Train the model
for epoch in range(1, 500):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
When I run the above, I keep getting the following error:
RuntimeError: expected scalar type Float but found Long
I tried to fix all the data type issues while building the dataset, so I am not sure what is going on. Any help would be appreciated.
Solution
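The original solution text is missing here; the following is a likely diagnosis based on the error message, not a confirmed answer. In the dataset code, edge_attr is created with dtype=torch.long, but in forward() it is passed to self.edge_encoder, a torch.nn.Linear layer whose weights are Float, and PyTorch refuses to multiply a Long input by Float weights. (edge_index, by contrast, must remain Long because it holds node indices.) The minimal sketch below, using only plain torch with an illustrative 1-in, 8-out Linear layer, reproduces the mismatch and shows the fix, which is to cast edge_attr to float:

```python
import torch

# A Linear layer stands in for the model's edge_encoder; the sizes
# (1 input feature, 8 hidden channels) are illustrative only.
lin = torch.nn.Linear(1, 8)

# edge_attr built the way the question builds it: integer dtype.
edge_attr_long = torch.randint(low=0, high=50, size=(5, 1), dtype=torch.long)

# Feeding the Long tensor into the Float layer reproduces the error.
try:
    lin(edge_attr_long)
    raised = False
except RuntimeError:
    raised = True  # "expected scalar type ... but found ..."

# The fix: cast edge attributes to float before encoding them.
edge_attr = edge_attr_long.float()
out = lin(edge_attr)  # now succeeds, shape (5, 8)
```

Applied to the question's dataset loop, that means converting edge_attr the same way capital_t and x are converted, e.g. `edge_attr = edge_attr.type(torch.FloatTensor)` (or building it with `torch.randint(...).float()`) before constructing the Data object. Note, separately, that the training loop indexes `data.train_mask`, but the Data objects built above never define a train_mask, so that attribute would also need to be added (or the indexing dropped) before training runs end to end.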