首页 > 解决方案 > 如何为 torch.cat 初始化张量

问题描述

import torch

#Y_pred = ?

for xi in X_iter:
    y_pred = net(xi).argmax(dim=1)
    Y_pred = torch.cat([Y_pred, y_pred])

这个张量怎么初始化,有没有更好的写法?</p>

标签: pytorch

解决方案


你可以这样做:

Y_pred = torch.cat([net(xi).argmax(dim=1) for xi in X_iter])

推荐阅读