首页 > 解决方案 > Pytorch如何增加批量大小

问题描述

我目前有一个 torch.Size([1, 3, 256, 224]) 的张量,但我需要它作为输入形状 [32, 3, 256, 224]。我正在实时捕获数据,因此数据加载器似乎不是一个好的选择。有没有简单的方法来获取 32 个大小的 torch.Size([1, 3, 256, 224]) 并将它们组合起来以创建 1 个大小为 [32, 3, 256, 224] 的张量?

标签: pythonnumpyopencvpytorchonnx

解决方案


您很可能使用 jit 模型,并且批量大小必须与模型训练时的大小完全相同。

t = torch.rand(1, 3, 256, 224)
t.size() # torch.Size([1, 3, 256, 224])
t2= t.expand(32, -1,-1,-1)
t2.size() # torch.Size([32, 3, 256, 224])

扩展张量不会分配新的内存,而只会在现有张量上创建一个新视图,并且您会得到所需的内容。只有张量步幅发生了变化。


推荐阅读