machine-learning - Band RNN 的计算精度
问题描述
所以我想弄清楚如何计算 BandRNN 的准确性。
BandRnn 是一个对角RNN 模型,每个神经元的连接数不同。例如: 这里 C 是每个神经元的连接数。
我目前的模型训练如下:
model = ModelLSTM(m, k).to(device)
model.train()
opt = torch.optim.Adam(model.parameters(), lr=args.lr)
best_test = 1e7
best_validation = 1e7
for ep in range(1, args.epochs + 1):
init_time = datetime.now()
processed = 0
step = 1
for batch_idx, (batch_x, batch_y, len_batch) in enumerate(train_loader):
batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)
opt.zero_grad()
logits = model(batch_x)
loss = model.loss(logits, batch_y, len_batch)
acc = sum(logits == batch_y) * 1.0 / len(logits)
print(acc)
loss.backward()
if args.clip > 0:
nn.utils.clip_grad_norm_(model.parameters(), args.clip)
opt.step()
processed += len(batch_x)
step += 1
print(" batch_idx {}\tLoss: {:.2f} ".format(batch_idx, loss))
print("Epoch {}, LR {:.5f} \tLoss: {:.2f} ".format(ep, opt.param_groups[0]['lr'], loss))
我的模型测试如下:
model.eval()
with torch.no_grad():
for batch_x, batch_y, len_batch in test_loader:
batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)
logits = model(batch_x)
loss_test = model.loss(logits, batch_y, len_batch)
acc = sum(logits == batch_y) * 1.0 / len(logits)
for batch_x, batch_y, len_batch in val_loader:
batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(device), len_batch.to(device)
logits = model(batch_x)
loss_val = model.loss(logits, batch_y, len_batch)
if loss_val < best_validation:
best_validation = loss_val.item()
best_test = loss_test.item()
print()
print("Val: Loss: {:.2f}\tBest: {:.2f}".format(loss_val, best_validation))
print("Test: Loss: {:.2f}\tBest: {:.2f}".format(loss_test, best_test))
print()
model.train()
我正在努力思考一种计算该模型准确性的方法,我想收到一些关于这样做的建议。谢谢你。
解决方案
我相信您代码中的这一行已经在尝试计算准确性:
acc = sum(logits == batch_y) * 1.0 / len(logits)
尽管您可能希望在与标签进行比较之前对 logits 进行 argmax:
preds = logits.argmax(dim=-1)
acc = sum(preds == batch_y) * 1.0 / len(logits)
推荐阅读
- python - 如何用sympy获得偏微分结果的系数
- node.js - 有没有办法使用 Cheerio 库捕获/抓取整个表格而不是逐个单元格?
- javascript - 如何更改 material-ui 中选定 MenuItem 的文本颜色?
- angular - 角度 9 keyPress 事件
- azure - Azure 工作簿 - 自动刷新和导入/导出选项
- flutter - 如何从 Flutter 中的 AssetImage 获取图像文件名?
- php - 在另一个 isset 函数中调用 isset 函数
- html - 我使用 html 和 css 制作表单,但标签和文本框不垂直相等
- mercurial - 在 Mercurial 中更新多个提交消息
- javascript - Jquery - JQuery 可能设置带有污染数据的属性