Optimizing a program with threads in Java

Problem description

My goal is to compute the sum of the elements of a binary tree in Java, using an ExecutorService and a CompletionService.

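The user provides the height of the tree, the level at which parallelism should start, and the number of threads to use. I know that the ExecutorService should create exactly as many threads as the user specifies. I also know that the completion service should produce N tasks in the preProcess method, where N is 2^levelParallel, because a given level n contains 2^n nodes.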
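The task count follows from the shape of a balanced tree: level 0 has one node and every level doubles it, so level n has 2^n nodes. The sketch below counts those nodes directly and compares the result with `1 << n`. The class `LevelCountDemo` and its helpers `build` and `countAtLevel` are hypothetical and are not part of the program in this question. Note that `Math.pow(n, 2)` computes n squared, which is a different number.

```java
// Hypothetical demo class; not part of the original program.
final class LevelCountDemo {
    // Minimal node type used only by this sketch.
    static final class Node {
        final Node left;
        final Node right;

        Node(Node left, Node right) {
            this.left = left;
            this.right = right;
        }
    }

    // Builds a balanced tree with levels 0..height, shaped like Tree.genTree(height, height).
    static Node build(int height) {
        if (height < 0) {
            return null;
        }
        return new Node(build(height - 1), build(height - 1));
    }

    // Counts the nodes whose depth equals `level` (the root has depth 0).
    static long countAtLevel(Node node, int depth, int level) {
        if (node == null) {
            return 0;
        }
        if (depth == level) {
            return 1;
        }
        return countAtLevel(node.left, depth + 1, level)
            + countAtLevel(node.right, depth + 1, level);
    }

    public static void main(String[] args) {
        Node root = build(5);
        for (int level = 0; level <= 5; level++) {
            System.out.printf("level %d: counted=%d  1 << level=%d  Math.pow(level, 2)=%d%n",
                level, countAtLevel(root, 0, level), 1L << level, (long) Math.pow(level, 2));
        }
    }
}
```

For level 3, the counted value and `1 << 3` are both 8, while `Math.pow(3, 2)` is 9. By the same reasoning, a check on the total number of tasks should compare against `1 << levelParallel`, not `2 << levelParallel`.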

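My problem is that I don't know how to traverse the tree starting at a given level, or how to use the CompletionService to collect the results in the postProcess method. In addition, the total task count should increase by one every time a new task is spawned, and the open task count should decrease by one every time the CompletionService returns a result.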
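The counting described above can be handled entirely on the submitting thread. Increment the counters on every `submit`, then call `take()` once per open task and decrement after each one. The following is a minimal sketch that assumes every task simply returns one number. The class `CounterDemo` and the method `sumAll` are hypothetical names; only the field names mirror `TreeCalculation`.

```java
import java.util.concurrent.*;

// Hypothetical demo class; the field names mirror TreeCalculation, the rest is illustrative.
final class CounterDemo {
    long nTasks;     // submitted but not yet taken
    long totalTasks; // submitted in total

    // Submits one task per value and collects exactly one result per open task.
    long sumAll(long[] values, int threadCount) {
        ExecutorService exec = Executors.newFixedThreadPool(threadCount);
        CompletionService<Long> cs = new ExecutorCompletionService<>(exec);
        for (long v : values) {
            cs.submit(() -> v); // each task simply returns its value
            ++nTasks;
            ++totalTasks;
        }
        exec.shutdown(); // no more submissions; already-queued tasks still run
        long sum = 0;
        try {
            while (nTasks > 0) {
                Future<Long> f = cs.take(); // blocks until some task has finished
                --nTasks;                  // one fewer open task
                sum += f.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while collecting results", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a task failed", e.getCause());
        }
        return sum;
    }

    public static void main(String[] args) {
        CounterDemo demo = new CounterDemo();
        long sum = demo.sumAll(new long[] {1, 2, 3, 4, 5, 6, 7, 8}, 4);
        System.out.println(sum + " from " + demo.totalTasks + " tasks"); // 36 from 8 tasks
    }
}
```

The loop runs on the open-task counter instead of a precomputed 2^n. It therefore takes exactly as many results as were actually submitted, even if the number of tasks differs from the expected count.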

I was able to use the CompletionService in the processTreeParallel function, but I really don't understand how to use it in my postProcess method.

Here is my code:

import java.util.concurrent.*;

public class TreeCalculation {
    // tree level to go parallel
    int levelParallel;
    // total number of generated tasks
    long totalTasks;
    // current number of open tasks
    long nTasks;
    // total height of tree
    int height;
    // Executors
    ExecutorService exec;
    CompletionService<Long> cs;
    TreeCalculation(int height, int levelParallel) {
        this.height = height;
        this.levelParallel = levelParallel;
    }

    void incrementTasks() {
        ++nTasks;
        ++totalTasks;
    }

    void decrementTasks() {
        --nTasks;
    }

    long getNTasks() {
        return nTasks;
    }
    // Where the ExecutorService should be initialized 
    // with a specific threadCount
    void preProcess(int threadCount) {

        exec = Executors.newFixedThreadPool(threadCount);
        cs = new ExecutorCompletionService<Long>(exec);
        nTasks = 0;
        totalTasks = 0;
    }
    // Where the CompletionService should collect the results;
    long postProcess() {
        long result = 0;
        return result;
    }

    public static void main(String[] args) {
        if (args.length != 3) {
            System.out.println(
                "usage: java Tree treeHeight levelParallel nthreads\n");
            return;
        }
        int height = Integer.parseInt(args[0]);
        int levelParallel = Integer.parseInt(args[1]);
        int threadCount = Integer.parseInt(args[2]);

        TreeCalculation tc = new TreeCalculation(height, levelParallel);

        // generate balanced binary tree
        Tree t = Tree.genTree(height, height);

        //System.gc();

        // traverse sequential
        long t0 = System.nanoTime();
        long p1 = t.processTree();
        double t1 = (System.nanoTime() - t0) * 1e-9;

        t0 = System.nanoTime();
        tc.preProcess(threadCount);
        long p2 = t.processTreeParallel(tc);
        p2 += tc.postProcess();
        double t2 = (System.nanoTime() - t0) * 1e-9;

        long ref = (Tree.counter * (Tree.counter + 1)) / 2;
        if (p1 != ref)
            System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
        if (p1 != p2)
            System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
        if (tc.totalTasks != (2 << levelParallel)) {
            System.out.printf("ERROR: ntasks %d != %d\n", 
                2 << levelParallel, tc.totalTasks);
        }

        // print timing
        System.out.printf("tree height: %2d "
            + "sequential: %.6f "
            + "parallel with %3d threads and %6d tasks: %.6f  "
            + "speedup: %.3f count: %d\n",
            height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
    }
}

// ============================================================================

class Tree {

    static long counter; // counter for consecutive node numbering

    int level; // node level
    long value; // node value
    Tree left; // left child
    Tree right; // right child

    // constructor
    Tree(long value) {
        this.value = value;
    }

    // generate a balanced binary tree of depth k
    static Tree genTree(int k, int height) {
        if (k < 0) {
            return null;
        } else {
            Tree t = new Tree(++counter);
            t.level = height - k;
            t.left = genTree(k - 1, height);
            t.right = genTree(k - 1, height);
            return t;
        }
    }

    // ========================================================================
    // traverse a tree sequentially

    long processTree() {
        return value
            + ((left == null) ? 0 : left.processTree())
            + ((right == null) ? 0 : right.processTree());
    }

    // ========================================================================
    // traverse a tree parallel
    // This is where I was able to use the CompletionService
    long processTreeParallel(TreeCalculation tc) {

        tc.totalTasks = 0;
        for(long i =0; i<(long)Math.pow(tc.levelParallel, 2); i++)
        {
            tc.incrementTasks();
            tc.cs.submit(new Callable<Long>(){
                @Override
                public Long call() throws Exception {
                    return processTree();
                }

            });
        }
        Long result = Long.valueOf(0);
        for(int i=0; i<(long)Math.pow(2,tc.levelParallel); i++) {
            try{
                result += tc.cs.take().get();
                tc.decrementTasks();
            }catch(Exception e){}

        }
        return result;
    }
}

Tags: java, multithreading, binary-tree

Solution


The basic idea is to traverse the tree and compute the results just as in the processTree method. As soon as you reach the level at which the parallel computation is supposed to start (levelParallel), stop descending. Instead, spawn a task that calls processTree on the current node. That task takes care of the entire remaining subtree below this node.

processTreeParallel             0
                               / \    
                              /   \    
processTreeParallel          1     2
                            / \   / \    
processTreeParallel        3   4 5   6  <- levelParallel
                           |   | |   |
processTree call for each: v   v v   v
                          +---------------+
tasks for executor:       |T   T T   T    |
                          +---------------+
completion service         |
fetches tasks and          v
sums them up:              T+T+T+T  -> result

Finally, add up two parts. The first is the result of the sequential part of processTreeParallel, which covers the nodes above levelParallel. The second is the sum of the task results, which the completion service collects.

The processTreeParallel method could thus be implemented like this:

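long processTreeParallel(TreeCalculation tc)
{
    // Above levelParallel: keep descending and sum these nodes on the calling thread.
    // Assumes levelParallel <= height, so nodes above that level always have children.
    if (level < tc.levelParallel)
    {
        long leftResult = left.processTreeParallel(tc);
        long rightResult = right.processTreeParallel(tc);
        return value + leftResult + rightResult;
    }
    // At levelParallel: hand the whole subtree rooted here to the executor
    tc.incrementTasks();
    tc.cs.submit(new Callable<Long>()
    {
        @Override
        public Long call() throws Exception
        {
            // Calls Tree.this.processTree(): sequential sum of this subtree
            return processTree();
        }
    });
    // This subtree's sum is collected later through the completion service
    return 0;
}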
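To see how the sequential part and the task results add up, here is a small, self-contained sketch of the same technique on a hand-sized tree. The class `SplitSumDemo` and its members are hypothetical, not the answer's program. With height 2, nodes are numbered depth-first as in `genTree`: the root is 1, the left subtree is 2, 3, 4, and the right subtree is 5, 6, 7. With levelParallel 1, only the root (1) is summed sequentially. The two tasks return 2+3+4 = 9 and 5+6+7 = 18, giving 1 + 9 + 18 = 28 = 7 * 8 / 2.

```java
import java.util.concurrent.*;

// Hypothetical, condensed version of the approach above; not the original program.
final class SplitSumDemo {
    static long counter; // consecutive node numbering, as in Tree.genTree

    static final class Node {
        final long value;
        final int level;
        Node left;
        Node right;

        Node(long value, int level) {
            this.value = value;
            this.level = level;
        }

        // Sequential sum of this subtree (same idea as Tree.processTree).
        long sum() {
            return value
                + (left == null ? 0 : left.sum())
                + (right == null ? 0 : right.sum());
        }
    }

    // Same numbering scheme as Tree.genTree: depth-first, starting at 1.
    static Node gen(int k, int height) {
        if (k < 0) {
            return null;
        }
        Node n = new Node(++counter, height - k);
        n.left = gen(k - 1, height);
        n.right = gen(k - 1, height);
        return n;
    }

    long nTasks;     // open tasks
    long totalTasks; // tasks submitted in total
    CompletionService<Long> cs;

    // Sums nodes above levelParallel on this thread; submits one task per node at levelParallel.
    long sumTop(Node n, int levelParallel) {
        if (n == null) {
            return 0;
        }
        if (n.level < levelParallel) {
            return n.value + sumTop(n.left, levelParallel) + sumTop(n.right, levelParallel);
        }
        nTasks++;
        totalTasks++;
        cs.submit(n::sum);
        return 0;
    }

    // Sequential part plus every task result taken from the completion service.
    long run(Node root, int levelParallel, int threads) {
        ExecutorService exec = Executors.newFixedThreadPool(threads);
        cs = new ExecutorCompletionService<>(exec);
        long result = sumTop(root, levelParallel);
        exec.shutdown();
        try {
            while (nTasks > 0) {
                Future<Long> f = cs.take();
                nTasks--;
                result += f.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
        return result;
    }

    public static void main(String[] args) {
        counter = 0;
        Node root = gen(2, 2);
        SplitSumDemo demo = new SplitSumDemo();
        long parallel = demo.run(root, 1, 2);
        System.out.println("sequential=" + root.sum() + " parallel=" + parallel
            + " tasks=" + demo.totalTasks); // sequential=28 parallel=28 tasks=2
    }
}
```

With levelParallel 2 on the same tree, nodes 1, 2 and 5 are summed sequentially (8), and the four leaf tasks return 3, 4, 6 and 7 (20). The total is again 28.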

The complete program is shown here:

import java.util.concurrent.*;

public class TreeCalculation
{
    // tree level to go parallel
    int levelParallel;
    // total number of generated tasks
    long totalTasks;
    // current number of open tasks
    long nTasks;
    // total height of tree
    int height;
    // Executors
    ExecutorService exec;
    CompletionService<Long> cs;

    TreeCalculation(int height, int levelParallel)
    {
        this.height = height;
        this.levelParallel = levelParallel;
    }

    void incrementTasks()
    {
        ++nTasks;
        ++totalTasks;
    }

    void decrementTasks()
    {
        --nTasks;
    }

    long getNTasks()
    {
        return nTasks;
    }

    // Where the ExecutorService should be initialized
    // with a specific threadCount
    void preProcess(int threadCount)
    {
        exec = Executors.newFixedThreadPool(threadCount);
        cs = new ExecutorCompletionService<Long>(exec);
        nTasks = 0;
        totalTasks = 0;
    }

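    // Where the CompletionService collects the results
    long postProcess()
    {
        // No further tasks will be submitted; already-queued tasks still run
        exec.shutdown();
        long result = 0;
        // Take exactly one completed future for every task that was submitted
        while (getNTasks() > 0)
        {
            try
            {
                Future<Long> f = cs.take();
                // Count the task as finished before get(), so a failed task
                // cannot leave nTasks stuck above zero
                decrementTasks();
                result += f.get();
            }
            catch (InterruptedException e)
            {
                Thread.currentThread().interrupt();
                break;
            }
            catch (ExecutionException e)
            {
                e.printStackTrace();
            }
        }
        return result;
    }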

    public static void main(String[] args)
    {

        int height = 22;
        int levelParallel = 3;
        int threadCount = 4;
        if (args.length != 3)
        {
            System.out.println(
                "usage: java TreeCalculation treeHeight levelParallel nthreads");
            System.out.println("Using default values for test");
        }
        else
        {
            height = Integer.parseInt(args[0]);
            levelParallel = Integer.parseInt(args[1]);
            threadCount = Integer.parseInt(args[2]);

        }

        TreeCalculation tc = new TreeCalculation(height, levelParallel);

        // generate balanced binary tree
        Tree t = Tree.genTree(height, height);

        // traverse sequential
        long t0 = System.nanoTime();
        long p1 = t.processTree();
        double t1 = (System.nanoTime() - t0) * 1e-9;

        t0 = System.nanoTime();
        tc.preProcess(threadCount);
        long p2 = t.processTreeParallel(tc);
        p2 += tc.postProcess();
        double t2 = (System.nanoTime() - t0) * 1e-9;

        long ref = (Tree.counter * (Tree.counter + 1)) / 2;
        if (p1 != ref)
            System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
        if (p1 != p2)
            System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
        if (tc.totalTasks != (1 << levelParallel))
        {
            System.out.printf("ERROR: ntasks %d != %d\n", 1 << levelParallel,
                tc.totalTasks);
        }

        // print timing
        System.out.printf("tree height: %2d\n" 
            + "sequential: %.6f\n"
            + "parallel with %3d threads and %6d tasks: %.6f\n"
            + "speedup: %.3f count: %d\n",
            height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
    }
}

// ============================================================================

class Tree
{

    static long counter; // counter for consecutive node numbering

    int level; // node level
    long value; // node value
    Tree left; // left child
    Tree right; // right child

    // constructor
    Tree(long value)
    {
        this.value = value;
    }

    // generate a balanced binary tree of depth k
    static Tree genTree(int k, int height)
    {
        if (k < 0)
        {
            return null;
        }

        Tree t = new Tree(++counter);
        t.level = height - k;
        t.left = genTree(k - 1, height);
        t.right = genTree(k - 1, height);
        return t;
    }

    // ========================================================================
    // traverse a tree sequentially

    long processTree()
    {
        return value 
            + ((left == null) ? 0 : left.processTree())
            + ((right == null) ? 0 : right.processTree());
    }

    // ========================================================================
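    // traverse a tree in parallel
    long processTreeParallel(TreeCalculation tc)
    {
        // Above levelParallel: descend and sum these nodes on the calling thread.
        // Assumes levelParallel <= height, so these nodes always have children.
        if (level < tc.levelParallel)
        {
            long leftResult = left.processTreeParallel(tc);
            long rightResult = right.processTreeParallel(tc);
            return value + leftResult + rightResult;
        }
        // At levelParallel: submit the whole subtree rooted here as one task
        tc.incrementTasks();
        tc.cs.submit(new Callable<Long>()
        {
            @Override
            public Long call() throws Exception
            {
                // Tree.this.processTree(): sequential sum of this subtree
                return processTree();
            }
        });
        // The subtree's sum is added later in postProcess()
        return 0;
    }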
}
