OpenAI CLIP (PyTorch) hangs when using multiprocessing

Problem description

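When I run CLIP inside a multiprocessing.Process, the program hangs as soon as it reaches the preprocessing step (I actually suspect this happens with any torch operation). A minimal example: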

import torch
import clip
from PIL import Image
import multiprocessing as mp
import sys

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)


def infer():
  print("PREPROCESSING")
  sys.stdout.flush()
  image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

  print("TOKENIZING")
  sys.stdout.flush()
  text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

  print("INFERRING")
  sys.stdout.flush()
  with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    print(f'{image_features.shape}')
    print(f'{text_features.shape}')
    sys.stdout.flush()

    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

  print(
      f"Label probs: {probs}")  # prints: [[0.9927937  0.00421068 0.00299572]]
  sys.stdout.flush()


p = mp.Process(target=infer, daemon=True)
p.start()
p.join()

This example is equivalent to the getting-started usage example in the CLIP README, except that model inference is wrapped in a Process.

The output of this code is:

/home/amol/code/soot/debugging/clip_tests/env/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)
  return torch._C._cuda_getDeviceCount() > 0
PREPROCESSING

After that it hangs. Any suggestions?
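
Because the only symptom is that output stops after PREPROCESSING, it can help to have the stuck child report where it is blocked instead of bisecting with print() calls. The sketch below is a suggestion, not part of the original question: it uses the standard-library faulthandler module, whose dump_traceback_later() writes the traceback of every thread from a separate watchdog thread once a timeout expires, so it still fires while the main thread is blocked inside native code. slow_step() is a hypothetical stand-in for the hanging call and simply sleeps; the sketch assumes a POSIX system where the "fork" start method is available.

```python
import faulthandler
import multiprocessing as mp
import sys
import time


def slow_step(seconds):
    # Hypothetical stand-in for the call that never returns in the question
    # (e.g. the preprocess(...) line); here it simply sleeps.
    time.sleep(seconds)


def watched_worker(timeout):
    # If this process is still running after `timeout` seconds, faulthandler
    # writes the Python traceback of every thread to stderr and then exits
    # the process with status 1 (exit=True).
    faulthandler.dump_traceback_later(timeout, exit=True, file=sys.__stderr__)
    slow_step(timeout + 30)
    # Only reached if the step finished in time: disarm the watchdog.
    faulthandler.cancel_dump_traceback_later()


def run_watched(timeout=1.0):
    # "fork" matches the default start method on Linux used in the question.
    ctx = mp.get_context("fork")
    p = ctx.Process(target=watched_worker, args=(timeout,), daemon=True)
    p.start()
    p.join(timeout + 10)  # the watchdog should fire long before this
    return p.exitcode


if __name__ == "__main__":
    # A traceback pointing into slow_step() is printed, then exit code 1.
    print("child exit code:", run_watched(1.0))
```

In the question's code, the equivalent would be to call faulthandler.dump_traceback_later(...) as the first statement of infer() or test(). The traceback only contains Python frames, so if the hang is inside PyTorch's C++ code it ends at the Python line that made the call (for example the preprocess(...) line).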

EDIT: By digging into CLIP to find where the problem is, I managed to make the example even smaller. This is what I ended up with:

import os
import urllib.request

from tqdm import tqdm

import torch
import clip
from PIL import Image
import multiprocessing as mp


def _download(url, root=os.path.expanduser("~/.cache/clip")):
  os.makedirs(root, exist_ok=True)
  download_target = os.path.join(root, os.path.basename(url))

  with urllib.request.urlopen(url) as source, open(download_target,
                                                   "wb") as output:
    with tqdm(total=int(source.info().get("Content-Length")),
              ncols=80,
              unit='iB',
              unit_scale=True) as loop:
      while True:
        buffer = source.read(8192)
        if not buffer:
          break

        output.write(buffer)
        loop.update(len(buffer))

  return download_target


def load(device="cpu"):
  model_path = _download(
      "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"  # noqa
  )
  model = torch.jit.load(model_path, map_location="cpu").eval()
  # Commenting out this line makes the hang go away.
  model = clip.model.build_model(model.state_dict()).to(device)


# Runs in the parent process at import time, before p.start();
# nothing it creates is used by test().
load()


def test():
  print("GETTING IMAGE")
  im = Image.open("CLIP.png")
  print("CONVERTING")
  im = im.convert('RGB')
  print("MADE TENSOR")
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(im.tobytes()))
  print("VIEW")
  img = img.view(im.size[1], im.size[0], len(im.getbands()))
  print("PERMUTING")
  img = img.permute((2, 0, 1)).contiguous()
  print("DIV")
  img = img.float().div(255)
  print("UNSQUEEZE")
  img = img.unsqueeze(0)


p = mp.Process(target=test, daemon=True)
p.start()
p.join()

Note that nothing created by calling load() or _download() is actually used. Also, if I comment out the build_model() line, everything works. What does clip.model.build_model() do that causes the problem?

Tags: python, pytorch, multiprocessing
