zip - 如何使用 Torch 将压缩的检查点文件加载到 Tacotron2 模型
问题描述
我正在尝试在 Tacotron2 上应用 MBP(基于幅度的修剪)算法。我成功了,我得到了许多基于 gzip 格式 K% 稀疏性的检查点。
之后,我尝试通过将 gzip 格式的新检查点加载到我的模型来运行模型推理(它使用 Pytorch)
问题是:
tacotron2 = torch.load("path.tar.gz") #dont work .
因此,请了解如何将 gzip chekpoint 加载到 torch 。
解决方案
我确实通过以下方式找到了解决方案:
import pickle
import tarfile
from torch.serialization import _load, _open_zipfile_reader
#Function to load file
def torch_load_targz(filep_ath):
tar = tarfile.open(filep_ath, "r:gz")
member = tar.getmembers()[0]
with tar.extractfile(member) as untar:
with _open_zipfile_reader(untar) as zipfile:
torch_loaded = _load(zipfile, None, pickle)
return torch_loaded
# Try the function
tacotron2 = torch_load_targz('path.tar.gz')
tacotron2.eval()