首页 > 技术文章 > 手撸一个ThreadPool线程池(源码简化版)

fantongxue 2022-02-10 15:56 原文

一,线程池定义

线程池由任务队列和工作线程组成,它可以重用线程来避免线程创建的开销,在任务过多时通过排队避免创建过多线程来减少系统资源消耗和竞争,确保任务有序完成。

更多介绍参考文章:https://www.imooc.com/article/51147
B站视频地址:https://space.bilibili.com/512437822

二,线程池执行原理

两个步骤

  1. 线程池创建任务线程
  2. 创建一个工作线程并启动工作线程(超过线程池的核心线程最大数则放在任务线程Runnable队列中等待执行)

【线程池的最大线程数指的就是工作线程数,而非任务线程】

工作线程的任务

执行自己的任务线程,自己的干完以后,再到任务线程队列中拿,拿到任务线程继续执行,直到队列清空为止

简单一句话描述工作线程

线程池核心线程最大数=5 ======== 有五个工人

线程池执行1000个线程 ========== 有1000个任务

刚开始,每个工人做一个任务,五个工人五个任务,那么剩下的995个任务就排好队,哪个工人的任务做完了,就取出来队里的第一个任务继续做,直到1000个任务全部做完。

三,开始撸它

1,核心代码

public class SimpleExecutor {
    //核心线程数量
    private int corePoolSize;
    //最大线程数量
    private int maxPoolSize;
    //保持活跃时间(非核心线程在获取任务的时候会可能阻塞)
    private long keepAliveTime;
    //线程池的运行状态
    private AtomicInteger State = new AtomicInteger();
    //worker的数量
    private AtomicInteger WorkerCount = new AtomicInteger();
    //线程池完成任务的数量
    private int finishedTaskCount = 0;
    //当前线程池状态
    private static final int RUNNING = 0;
    private static final int STOPPED = 1;
    //存放任务的阻塞队列
    private BlockingQueue<Runnable> taskQueue = new LinkedBlockingQueue<>();
    private HashSet<Worker> workers = new HashSet<>();

    public SimpleExecutor(int corePoolSize, int maxPoolSize, long keepAliveTime) {
        this.corePoolSize = corePoolSize;
        this.maxPoolSize = maxPoolSize;
        this.keepAliveTime = keepAliveTime;
    }

    /**
     * 得到已经完成的任务数量
     */
    public int getFinishedTaskCount() {
        int already = finishedTaskCount;
        for (Worker worker : workers) {
            already += worker.finishedTask;
        }
        return already;
    }


    /**
     * 核心方法
     */
    public void execute(Runnable task) {
        if (task == null) {
            throw new RuntimeException("任务为空,不接收空任务");
        }
        //检查当前的线程池是否为运行状态
        if (State.get() == RUNNING) {
            //接收任务 1,直接创建一个worker,当前的task作为worker的firstTask即可
            //2,worker数量已经达到最大限制了,就不能再次创建新的worker,需要往任务队列里面去放
            if (WorkerCount.get() < corePoolSize && addWorker(task, true)) {
                return;
            }
            if (WorkerCount.get() < maxPoolSize && addWorker(task, false)) {
                return;
            }
            //只有往任务队列里面放了
            if (State.get() == RUNNING) {
                if (!taskQueue.offer(task)) {
                    throw new RuntimeException("添加任务到队列失败");
                }
            }

        } else {
            //说明当前线程池处于停止状态了,拒绝任务
            throw new RuntimeException("线程池已经停止,拒绝任务");
        }
    }

    /**
     * 添加worker
     *
     * @param task 任务
     * @param core 是否是核心线程
     * @return
     */
    private boolean addWorker(Runnable task, boolean core) {
        if (State.get() == STOPPED) {
            return false;
        }
        out:
        while (true) {
            if (State.get() == STOPPED) {
                return false;
            }
            while (true) {
                //当前线程池中的worker数量是否达到了阈值
                if (WorkerCount.get() > (core ? corePoolSize : maxPoolSize)) {
                    //线程池不再允许创建新的worker了
                    return false;
                }
                //可以创建新的worker,worker数量原子性加1
                if (!casIncreaseWorkerCount()) {
                    continue out;
                }
                break out;
            }
        }
        //------ 实际添加worker的操作 --------
        if (State.get() == STOPPED) return false;
        Worker worker = new Worker(task);
        //拿到worker的工作线程
        Thread wt = worker.thread;
        if (wt != null) {
            if (wt.isAlive()) {
                throw new RuntimeException("这个工作线程不是线程池创建的,不归线程池管");
            }
            //启动线程
            wt.start();
            workers.add(worker);
        }
        return true;
    }

    private boolean casIncreaseWorkerCount() {
        return WorkerCount.compareAndSet(WorkerCount.get(), WorkerCount.get() + 1);
    }

    private boolean casDecreaseWorkerCount() {
        return WorkerCount.compareAndSet(WorkerCount.get(), WorkerCount.get() - 1);
    }

    /**
     * 从阻塞队列中拿任务
     */
    public Runnable getTask() {
        if (State.get() == STOPPED) return null;
        Runnable task = null;
        //自旋一直去队列里拿任务,拿不到任务就一直等,直到拿到任务
        while (true) {
            if (State.get() == STOPPED) return null;
            //当前线程池中的线程数量 < corePoolSize  => 核心线程 ,否则就是非核心线程
            //以核心线程的角色去获取任务
            if (WorkerCount.get() <= corePoolSize) {
                //核心线程
                /**
                 * take():获取并移除此队列的头部,在元素变得可用之前一直等待 。queue的长度 == 0 的时候,一直阻塞
                 */
                try {
                    task = taskQueue.take();
                } catch (InterruptedException e) {
                    return null;
                }
            } else {
                //非核心线程
                try {
                    task = taskQueue.poll(keepAliveTime, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }

            }
            if (task != null) {
                return task;
            }
        }

    }

    /**
     * 关闭线程池
     */
    public void stop() {
        setState(STOPPED);
        interruptAllWorkers();
    }

    /**
     * 立即关闭线程池
     *
     * @param
     * @return 返回阻塞队列里没有执行完成的任务
     */
    public List<Runnable> stopNow() {
        List<Runnable> remains = new ArrayList<>();
        if (!taskQueue.isEmpty()) {
            //把队列中的所有任务放在reamins中
            taskQueue.drainTo(remains);
        }
        //避免同时其他线程仍然还继续给队列里添加任务,则处理干净
        while (!taskQueue.isEmpty()) {
            remains.add(taskQueue.poll());
        }
        return remains;
    }

    private void setState(int state) {
        while (true) {
            if (State.get() == state) {
                break;
            }
            if (State.compareAndSet(State.get(), state)) {
                break;
            }
        }
    }

    private void interruptAllWorkers() {
        for (Worker worker : workers) {
            if (!worker.thread.isInterrupted()) {
                worker.thread.interrupt();
            }
        }
    }

    /**
     * 任务线程
     */
    private final class Worker implements Runnable {

        //创建Worker的时候会带一个线程(传进来的任务线程)
        Runnable firstTask;
        //工作线程(当前worker线程)
        Thread thread;
        //完成任务数量
        int finishedTask = 0;

        Worker(Runnable firstTask) {
            this.firstTask = firstTask;
            this.thread = new Thread(this);
        }

        @Override
        public void run() {
            //从阻塞队列中不断的拿任务
            runWorker(this);
        }

        private void runWorker(Worker worker) {
            if (worker == null) throw new RuntimeException("worker不能为空");
            //拿到工作线程
            Thread wt = worker.thread;
            //worker的首要任务,这个首要任务在后续被执行
            Runnable task = worker.firstTask;
            worker.firstTask = null;
            try {
                /**
                 * task != null 任务线程
                 * task = getTask()) != null  从任务队列中取任务
                 *
                 * 先执行自己的任务,任务执行完之后再去看队列中是否有任务,有就继续执行,没有就跳出循环
                 */
                while (task != null || (task = getTask()) != null) {
                    if (wt.isInterrupted()) {
                        System.out.println("此工作线程已经被终止");
                        return;
                    }
                    if (State.get() == STOPPED) {
                        System.out.println("线程池已经关闭");
                        return;
                    }
                    task.run();
                    task = null;
                    worker.finishedTask++;
                }
            } finally {
                //worker的正常退出逻辑
                workers.remove(worker);
                if (casDecreaseWorkerCount()) {
                    finishedTaskCount += worker.finishedTask;
                }
            }
        }
    }
}

2,测试代码

public static void main(String[] args) {
        SimpleExecutor simpleExecutor = new SimpleExecutor(2, 2, 0);
        for(int i=0;i<10;i++){
            simpleExecutor.execute(new Runnable() {
                @Override
                public void run() {
                    System.out.println("线程名称:" + Thread.currentThread().getName());
                }
            });
        }
        try {
            Thread.sleep(2000);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        simpleExecutor.stop();
    }

3,效果展示

推荐阅读