首页 > 解决方案 > 如何使用 Torch 将压缩的检查点文件加载到 Tacotron2 模型

问题描述

我正在尝试在 Tacotron2 上应用 MBP(基于幅度的修剪)算法。我成功了,我得到了许多基于 gzip 格式 K% 稀疏性的检查点。

之后,我尝试通过将 gzip 格式的新检查点加载到我的模型来运行模型推理(它使用 Pytorch)

问题是:

tacotron2 = torch.load("path.tar.gz") #dont work .

因此,请了解如何将 gzip chekpoint 加载到 torch 。

标签: zipgziptorchrar

解决方案


我确实通过以下方式找到了解决方案:

import pickle

import tarfile

from torch.serialization import _load, _open_zipfile_reader

#Function to load file
def torch_load_targz(filep_ath):
  tar = tarfile.open(filep_ath, "r:gz")
  member = tar.getmembers()[0]
  with tar.extractfile(member) as untar:
    with _open_zipfile_reader(untar) as zipfile:
        torch_loaded = _load(zipfile, None, pickle)
  return torch_loaded

# Try the function
tacotron2 = torch_load_targz('path.tar.gz')
tacotron2.eval()

推荐阅读