首页 > 解决方案 > Pytorch 使用 HTTPS 加载自定义数据非常慢

问题描述

我尝试实现一个自定义数据加载器,它将发出一个 Web 请求并返回一个样本。我的程序的目的是看看这个想法是否会比原始数据加载器更快。我的网络服务器代码运行

srun -n24 --mem = 12g python web.py

然后将创建 24 个在集群中运行的“工作人员”。然后每个工作人员将其端口名写入文件,以使数据加载器知道他的存在。因此,当在训练循环中调用数据加载器时。数据加载器从文件中选择一个随机服务器,并向它们发送一个带有索引的 Web 请求。然后,Web 服务器将加载样本并进行扩充并通过 http 响应返回。在我看来,我认为它会比原始数据加载器更快,因为每个数据加载器工作人员都会向网络服务器发送请求并获取样本。因此,将数据分发到不同的服务器,以便它们更快地加载图像。

但是,当我使用 COCO 数据集与原始数据进行比较时。原始数据加载器需要 743.820 秒才能完成加载一个纪元,而我的自定义数据加载器需要 1503.26 秒才能完成。我无法弄清楚我的代码的哪一部分需要很长时间,所以我想寻求帮助。如果我的解释不好/不好,请告诉我。任何帮助表示赞赏。谢谢你。

以下是启动网络服务器的代码:

class PytorchDataHandler(BaseHTTPRequestHandler):
    def do_GET(self):

        self.send_response(200)

        self.end_headers()
        

        get_param = self.path
        get_param = parse_qs(urlparse(get_param).query)
        batch_list = [[],[]]
        c_batches = []
        index = get_param['index']
        if index :
            for data in index:
                result = imagenet_data[int(data)]
                batch_list[0].append(result[0])
                batch_list[1].append(result[1])
            c_batches.append(batch_list)
            torch.save(batch_list, self.wfile)



        else:
            write_log('Empty Parameter')


def main():
    sock = socket.socket(socket.AF_INET,socket.SOCK_DGRAM)
    hostname = socket.gethostname()
    n_hostname = hostname.split(".")

    # Bind to random port
    sock.bind(('0.0.0.0', 0))
    # Get Port Number
    PORT = int(sock.getsockname()[1])
    
    current_dir = os.getcwd()
    create_dir = os.path.join(current_dir, r'worker_file')

    #filename = create_dir + '/' +  str(n_hostname[0]) + '.cvl-tengig:' + str(PORT)
    filename = create_dir + '/' +  str(n_hostname[0])  + ':' + str(PORT)
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    open_file = open(filename, 'w')
    open_file.write(str(n_hostname[0]) + ':' + str(PORT))   

    open_file.close()
    try :
        SERVER = HTTPServer(('', PORT), PytorchDataHandler)
        SERVER.serve_forever()
    except KeyboardInterrupt:
        print('Shutting down server, ^C')
        os.remove(filename)
        SERVER.socket.close()

if __name__ == '__main__':
    main()

自定义数据加载器的代码:

class DistData(Dataset):
    def __init__(self, data, transform = None):
        self.data = data
        # Get file path
        current_dir = os.getcwd()
        create_dir = os.path.join(current_dir, r'worker_file')

        # Get all item in file
        self.arr = os.listdir(create_dir)
        self.selected = []


    def __getitem__(self, index):
        # Select a random server
        
        random_server = random.choice(self.arr)
        
        # Remove selected server from the server list
        self.arr.remove(random_server)

        # Append selected server to the selected list
        self.selected.append(random_server)

        
        return self.post_request(index, random_server)

    def __len__(self):
        return len(self.data)

    def post_request(self, index, random_server):
        params = {'index': index}
        url = 'http://' + random_server + '/get'

    
        r = requests.get(url , params = params)
    

        print("Response Time : {:<10} , worker : {:<10} ".format(r.elapsed.total_seconds(), torch.utils.data.get_worker_info().id ))

        # Remove server from selected once there's a response
        self.selected.remove(random_server)
        # Add back to main server list after response
        self.arr.append(random_server)

        buffer = io.BytesIO(r.content)
        response = torch.load(buffer)

        return response

def train(net, device, trainloader, criterion, optimizer):
    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        print('Epoch : {}'.format(epoch + 1))
        print('----------------------------')
        start_time = time.time()
        total_time = 0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data   
            print("Train: Time taken to load batch {} is {}".format(i+1,time.time() - start_time))
            total_time += time.time() - start_time
            start_time = time.time()
        print('Epoch : {} , Total Time Taken : {}'.format(epoch + 1, total_time))
    print('Finished Training')

    imagenet_data =torchvision.datasets.CocoCaptions('/db/shared/detection+classification/coco/train2017/' , 
                            '/db/shared/detection+classification/coco/annotations/captions_train2017.json')

    training_set = DistData(imagenet_data)
    trainloader = DataLoader(training_set, sampler = BatchSampler(RandomSampler(training_set), batch_size = 24, drop_last = False),
                num_workers = 4)


train(trainloader)

标签: pythonhttpspytorchbasehttpserver

解决方案


PyTorch 分发方式

首先,您应该熟悉torch.nn.parallel.DistributedDataParallel以查看如何以有效方式分发数据的示例。

您可以查看我的这个答案并附上丹尼尔的答案,以了解可能的策略。PyTorch介绍也是一个很好的资源。

简而言之,它的作用:

  • 主要工作人员加载数据(大批量)
  • 该批次均匀分布在其他工作人员中(与神经网络一起)
  • forward& 使用部分数据对每个工作人员进行后向传递
  • gradient从每个工人发送到main worker平均和优化器改进的地方model
  • model分布在工人之间(连同新批次)

在这种情况下,通过网络(或设备)发送三件事:

  • 一批数据(最好是大的)
  • 模型
  • 模型梯度

这使我们对您的解决方案提出了主要警告;应该尽量减少网络上的数据传输,因为它们真的很慢

您的分销方式

您正在对每个样品提出要求。它确实效率低下(由于网络传输),您应该追求的是请求整个批次的数据。此外,每个服务器都应该预先加载这些数据(假设提前四批),因此可以在需要时随时发送。

当您使用14工作人员host并且每个工作人员都向其他服务器发送数据请求时,其中一些可能会请求同一服务器。在这种情况下,每个人都必须等待另一个人。最好将每个工作人员指向每个服务器。

尽管如此,这种方法并不是很有效,因为模型host必须等待data服务器。

如果可能,您可以将整个COCO数据集拆分为多个工作人员。此外,在每个工人身上,应该有一个模块在做forwardbackward. 这将类似于上述PyTorch 的分发方式,除了batch跨设备传输。不利的一面是,培训的随机性会降低。

问题

但是,我对请求批次而不是样品有点不清楚。我是否将索引列表发送到服务器并返回批处理作为响应?

是的,看看torch.utils.data.DataLoaderbatch_sampler论点。

当您一开始没有收到任何请求时,您如何预加载数据?

您可以在一个请求中发送多个索引(例如三个批次的索引)。您在工作人员上准备第batch一个并将其发送到主机并在网络通信期间准备另一个(第二批)。因此,当您进行交流时,总会准备一批。


推荐阅读