首页 > 解决方案 > 如何获取运行 client_connected_cb 的 asyncio.start_server 创建的任务

问题描述

以下是文档: https ://docs.python.org/3/library/asyncio-stream.html

asyncio.start_server接受 a ,一旦客户端连接到服务器client_connected_cb,它就会在 a 中运行。Task我的目标是编写一个teardown函数,等待一切完成:所有读取器完成,所有写入器完成,服务器创建的所有任务都完成。为了做到这一点,我需要知道Task正在运行的服务器创建的 s client_connected_cb,但我不知道该怎么做?

有没有办法得到这个Task?否则,我不得不求助于每次client_connected_cb创建时间的技巧,我将其存储在某种字典中,然后等待通过轮询或其他方式清除该字典。

标签: python-3.xpython-asyncio

解决方案


如果您想要细粒度的控制,您可以创建一个自定义task_factory来捕获所有新任务并对其执行某些操作,例如为您想要在完成后执行某些操作的所有新任务添加回调。 Task.add_done_callback(callback)

def end(task: asyncio.Task):
    print(f'Task: {task.get_name()} Finished')
tasks = []
def my_task_factory(loop, coro):
    task = asyncio.Task(coro, loop=loop)
    if task.get_coro().__name__ == 'handle_client':
        task.add_done_callback(end)
    tasks.append(task)
    return task  # type: asyncio.Task

Loop = asyncio.get_event_loop()
Loop.set_task_factory(my_task_factory)

还有 2 种用于访问任务的高级方法。
current_task() all_tasks()
asyncio.current_task(loop=None) # Returns _asyncio.Task
asyncio.all_tasks(loop=None) # 返回一组 _asyncio.Task

import _asyncio
import asyncio
from asyncio import StreamReader, StreamWriter
class Server:
    def __init__(self,  bind_ip='localhost', bind_port=15555):
        self.Loop = asyncio.get_event_loop()
        self.bind_port = bind_port
        self.bind_ip = bind_ip

    def startup(self):
        self.Server_Task = self.Loop.create_task(self._run_server(), name='Server')
        self.Loop.run_forever()

    async def _run_server(self):
        self.Server = await asyncio.start_server(self.handle_client, host=self.bind_ip, port=self.bind_port)
        async with self.Server:
            print(f'Listening on {self.bind_ip}:{self.bind_port}')
            await self.Server.serve_forever()

    async def handle_client(self, reader: StreamReader, writer: StreamWriter):
        peer = writer.get_extra_info('peername')  # type: Tuple[str, int]
        try:
            task = asyncio.current_task(self.Loop) # type: _asyncio.Task
            tasks = asyncio.all_tasks(self.Loop) # type: Set # of _asyncio.Task
            print(f'Tasks Type:{type(tasks)}\n------')
            for i in tasks:
                print(f'Task: Type:{type(i)}\nTask: {i.get_name()}\nCoro: {i.get_coro().__name__}')
            print(f'------\nhandle_client: Type:{type(task)}\nTask: {task.get_name()}\nCoro: {task.get_coro().__name__}')
        except Exception as err:
            print(err)

        finally:
            writer.close()
            await writer.wait_closed()
            print(f'Closed: {peer}')


if __name__ == '__main__':
    TCPServer = Server()
    TCPServer.startup()
    print('End')

输出


推荐阅读