python - 使用多处理时 OpenAI CLIP (PyTorch) 挂起
问题描述
在 multiprocessing.Process 中运行 CLIP 时,系统一到达预处理步骤就会挂起(实际上我假设这实际上是任何炬管操作)。一个最小的例子:
import torch
import clip
from PIL import Image
import multiprocessing as mp
import sys
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
def infer():
print("PREPROCESSING")
sys.stdout.flush()
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
print("TOKENIZING")
sys.stdout.flush()
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
print("INFERRING")
sys.stdout.flush()
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
print(f'{image_features.shape}')
print(f'{text_features.shape}')
sys.stdout.flush()
logits_per_image, logits_per_text = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print(
f"Label probs: {probs}") # prints: [[0.9927937 0.00421068 0.00299572]]
sys.stdout.flush()
p = mp.Process(target=infer, daemon=True)
p.start()
p.join()
此示例等效于CLIP 自述文件中的入门使用方法示例,但将模型推理包装在 Process 中。
这段代码的输出是:
/home/amol/code/soot/debugging/clip_tests/env/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check th
at you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
return torch._C._cuda_getDeviceCount() > 0
PREPROCESSING
之后它挂起。有什么建议么?
编辑 我通过深入研究 CLIP 以查看问题所在,设法使这个变得更小。我去那里:
import os
import urllib
from tqdm import tqdm
import torch
import clip
from PIL import Image
import multiprocessing as mp
def _download(url, root=os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True)
download_target = os.path.join(root, os.path.basename(url))
with urllib.request.urlopen(url) as source, open(download_target,
"wb") as output:
with tqdm(total=int(source.info().get("Content-Length")),
ncols=80,
unit='iB',
unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
return download_target
def load(device="cpu"):
model_path = _download(
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" # noqa
)
model = torch.jit.load(model_path, map_location="cpu").eval()
model = clip.model.build_model(model.state_dict()).to(device)
load()
def test():
print("GETTING IMAGE")
im = Image.open("CLIP.png")
print("CONVERTING")
im = im.convert('RGB')
print("MADE TENSOR")
img = torch.ByteTensor(torch.ByteStorage.from_buffer(im.tobytes()))
print("VIEW")
img = img.view(im.size[1], im.size[0], len(im.getbands()))
print("PERMUTING")
img = img.permute((2, 0, 1)).contiguous()
print("DIV")
img = img.float().div(255)
print("UNSQUEEZE")
img = img.unsqueeze(0)
p = mp.Process(target=test, daemon=True)
p.start()
p.join()
请注意,调用 load() 或 download() 创建的任何内容都不会被实际使用。此外,如果我注释掉 build_model() 行,一切正常。什么是导致问题的 clip.model.build_model() 做什么?
解决方案
推荐阅读
- laravel - 在 Laravel 中搜索
- unit-testing - 带有任何参数的 Mockito when() 无法按预期工作
- javascript - 由于切换功能,div 被隐藏
- grails - 在 Grails 的 json 视图中访问控制器参数
- angular - 仅查询参数更改时重新加载 Angular 组件
- mongodb - Mongodb每月创建自动创建索引任务
- docker - 在 docker 中使用 gitlab 私有 repo 作为 golang 依赖项
- javascript - 谷歌地图,标记位置正确但中心不正确
- c# - _() 是什么意思?
- python - 在python中以同质方式表示来自两个列表的单词之间的单词相似度