java - 在java中使用线程优化程序
问题描述
我的目标是用 Java 计算二叉树中所有元素的总和:先使用 ExecutorService 并行执行任务,然后使用 CompletionService 收集结果。
用户给出树的高度、开始并行的层级以及要使用的线程数。我知道 ExecutorService 创建的线程数应该与用户给定的线程数完全相同,并且 CompletionService 应该在 preProcess 方法中产生 N 个任务,其中 N 是 2^(并行层级),因为在第 n 层我们会有 2^n 个节点。
我的问题是:我不知道如何从给定的层级开始遍历树,也不知道如何在 postProcess 方法中使用 CompletionService 取得结果。此外,每产生一个新任务,任务总数应增加一;每当 CompletionService 返回一个结果,未完成的任务数应减少一。
我已经能够在 processTreeParallel 函数中使用 CompletionService,但我真的不明白如何在我的 postProcess 方法中使用它。
这是我的代码:
import java.util.concurrent.*;
/**
 * Driver class from the question: holds the executor, the completion service
 * and two task counters, and times a sequential traversal against the
 * (not yet working) parallel one.
 *
 * The counters are plain {@code long} fields without any synchronization.
 */
public class TreeCalculation {
// tree level to go parallel
int levelParallel;
// total number of generated tasks
long totalTasks;
// current number of open tasks
long nTasks;
// total height of tree
int height;
// Executors
ExecutorService exec;
CompletionService<Long> cs;
// Stores the two configuration values; exec and cs stay null until preProcess runs.
TreeCalculation(int height, int levelParallel) {
this.height = height;
this.levelParallel = levelParallel;
}
// Records one newly submitted task in both the open and the total counter.
void incrementTasks() {
++nTasks;
++totalTasks;
}
// Records that one task result has been consumed.
void decrementTasks() {
--nTasks;
}
// Returns the number of tasks that were submitted but not yet collected.
long getNTasks() {
return nTasks;
}
// Where the ExecutorService should be initialized
// with a specific threadCount
// Creates a fixed pool of threadCount threads, wraps it in a completion
// service and resets both counters.
void preProcess(int threadCount) {
exec = Executors.newFixedThreadPool(threadCount);
cs = new ExecutorCompletionService<Long>(exec);
nTasks = 0;
totalTasks = 0;
}
// Where the CompletionService should collect the results;
// NOTE(review): this is still a stub - it always returns 0 and never touches cs.
// NOTE(review): nothing in this class calls exec.shutdown(); the pool threads
// are presumably non-daemon, so the JVM may not exit after main - confirm.
long postProcess() {
long result = 0;
return result;
}
// Entry point: expects exactly three arguments (tree height, parallel level,
// thread count); prints a usage line and returns otherwise.
public static void main(String[] args) {
if (args.length != 3) {
System.out.println(
"usage: java Tree treeHeight levelParallel nthreads\n");
return;
}
int height = Integer.parseInt(args[0]);
int levelParallel = Integer.parseInt(args[1]);
int threadCount = Integer.parseInt(args[2]);
TreeCalculation tc = new TreeCalculation(height, levelParallel);
// generate balanced binary tree
Tree t = Tree.genTree(height, height);
//System.gc();
// traverse sequential
long t0 = System.nanoTime();
long p1 = t.processTree();
// elapsed time converted from nanoseconds to seconds
double t1 = (System.nanoTime() - t0) * 1e-9;
// traverse parallel: timing includes pool creation and result collection
t0 = System.nanoTime();
tc.preProcess(threadCount);
long p2 = t.processTreeParallel(tc);
p2 += tc.postProcess();
double t2 = (System.nanoTime() - t0) * 1e-9;
// reference value: genTree numbers the nodes 1..counter, so the expected
// sum is counter * (counter + 1) / 2
long ref = (Tree.counter * (Tree.counter + 1)) / 2;
if (p1 != ref)
System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
if (p1 != p2)
System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
// NOTE(review): 2 << levelParallel equals 2^(levelParallel + 1), while the
// question text expects 2^levelParallel tasks (that would be 1 << levelParallel).
// The message prints the expected count first and the actual count second.
if (tc.totalTasks != (2 << levelParallel)) {
System.out.printf("ERROR: ntasks %d != %d\n",
2 << levelParallel, tc.totalTasks);
}
// print timing
System.out.printf("tree height: %2d "
+ "sequential: %.6f "
+ "parallel with %3d threads and %6d tasks: %.6f "
+ "speedup: %.3f count: %d\n",
height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
}
}
// ============================================================================
/**
 * Binary tree node from the question. Nodes are numbered consecutively through
 * the static {@code counter}, and each node remembers its own level
 * (the root is level 0).
 */
class Tree {
static long counter; // counter for consecutive node numbering
int level; // node level
long value; // node value
Tree left; // left child
Tree right; // right child
// constructor
Tree(long value) {
this.value = value;
}
// generate a balanced binary tree of depth k
// k counts the levels still to be created below this node (k < 0 yields null);
// height is the total height and is used to derive each node's level.
static Tree genTree(int k, int height) {
if (k < 0) {
return null;
} else {
Tree t = new Tree(++counter);
t.level = height - k;
t.left = genTree(k - 1, height);
t.right = genTree(k - 1, height);
return t;
}
}
// ========================================================================
// traverse a tree sequentially
// Returns this node's value plus the sums of both subtrees; a missing child counts as 0.
long processTree() {
return value
+ ((left == null) ? 0 : left.processTree())
+ ((right == null) ? 0 : right.processTree());
}
// ========================================================================
// traverse a tree parallel
// This is where I was able to use the CompletionService
// NOTE(review): this version does not walk down to levelParallel at all. Every
// submitted task calls processTree() on this same node, so each task computes
// the same full sum of the subtree rooted here.
long processTreeParallel(TreeCalculation tc) {
// resets only the total counter; tc.nTasks is left untouched here
tc.totalTasks = 0;
// NOTE(review): Math.pow(tc.levelParallel, 2) is levelParallel squared, whereas
// the collecting loop below runs Math.pow(2, tc.levelParallel) times. The two
// counts differ for most inputs (e.g. levelParallel = 1 submits 1 task but
// tries to take 2).
for(long i =0; i<(long)Math.pow(tc.levelParallel, 2); i++)
{
tc.incrementTasks();
tc.cs.submit(new Callable<Long>(){
@Override
public Long call() throws Exception {
return processTree();
}
});
}
Long result = Long.valueOf(0);
// NOTE(review): take() presumably blocks until a completed task is available,
// so asking for more results than were submitted would wait forever - confirm.
for(int i=0; i<(long)Math.pow(2,tc.levelParallel); i++) {
try{
result += tc.cs.take().get();
tc.decrementTasks();
// NOTE(review): the empty catch silently discards any failure of take()/get()
}catch(Exception e){}
}
return result;
}
}
解决方案
The basic idea is to traverse the tree and compute the result just as the processTree method does. As soon as the level at which the parallel computation is supposed to start (levelParallel) is reached, you spawn a task that calls processTree internally; that task takes care of the remaining part of the tree below that node.
processTreeParallel 0
/ \
/ \
processTreeParallel 1 2
/ \ / \
processTreeParallel 3 4 5 6 <- levelParallel
| | | |
processTree call for each: v v v v
+---------------+
tasks for executor: |T T T T |
+---------------+
completion service |
fetches tasks and v
sums them up: T+T+T+T -> result
Finally, you add the result computed by the sequential part of the processTreeParallel method to the task results that are summed up by the completion service.
The processTreeParallel method could thus be implemented like this:
/**
 * Sums the subtree rooted at this node, handing the work over to the
 * completion service once the parallel level is reached.
 *
 * Nodes above {@code tc.levelParallel} are summed on the calling thread. For
 * every node at that level one task is submitted that runs
 * {@link #processTree()} on the node's subtree; such a node contributes 0 here
 * because its sum is delivered through {@code tc.cs}.
 *
 * If the tree is shallower than {@code tc.levelParallel}, the leaves are
 * reached first and no task is submitted at all. The collector must therefore
 * drain exactly as many results as were submitted ({@code tc.getNTasks()})
 * rather than a fixed 2^levelParallel.
 *
 * @param tc shared state: parallel level, task counters and completion service
 * @return the sum of all node values that were processed sequentially
 */
long processTreeParallel(TreeCalculation tc)
{
    if (level < tc.levelParallel)
    {
        // A missing child counts as 0, exactly as in processTree(); without this
        // guard a levelParallel beyond the tree height ends in a NullPointerException.
        long leftResult = (left == null) ? 0 : left.processTreeParallel(tc);
        long rightResult = (right == null) ? 0 : right.processTreeParallel(tc);
        return value + leftResult + rightResult;
    }
    // Parallel level reached: count the task first, then submit it.
    tc.incrementTasks();
    tc.cs.submit(new Callable<Long>()
    {
        @Override
        public Long call()
        {
            return processTree();
        }
    });
    return 0;
}
The complete program is shown here:
import java.util.concurrent.*;
/**
 * Sums the node values of a balanced binary tree twice - once sequentially and
 * once in parallel - and reports both timings.
 *
 * The parallel run works in three steps: {@link #preProcess(int)} creates the
 * thread pool, {@code Tree.processTreeParallel} submits one task per node at
 * {@code levelParallel}, and {@link #postProcess()} collects the task results.
 *
 * The task counters are plain fields. They are only modified by the thread that
 * submits the tasks and later collects the results, never by the pool threads.
 */
public class TreeCalculation
{
    // tree level to go parallel
    int levelParallel;
    // total number of generated tasks
    long totalTasks;
    // current number of open tasks
    long nTasks;
    // total height of tree
    int height;
    // Executors
    ExecutorService exec;
    CompletionService<Long> cs;

    /**
     * @param height        total height of the tree (the leaves are on this level)
     * @param levelParallel level at which subtrees are handed to the executor
     */
    TreeCalculation(int height, int levelParallel)
    {
        this.height = height;
        this.levelParallel = levelParallel;
    }

    /** Records one newly submitted task. */
    void incrementTasks()
    {
        ++nTasks;
        ++totalTasks;
    }

    /** Records that the result of one task has been collected. */
    void decrementTasks()
    {
        --nTasks;
    }

    /** @return number of tasks submitted but not yet collected */
    long getNTasks()
    {
        return nTasks;
    }

    /**
     * Creates a fixed pool with exactly {@code threadCount} threads, wraps it
     * in a completion service and resets the task counters.
     *
     * @param threadCount number of worker threads
     */
    void preProcess(int threadCount)
    {
        exec = Executors.newFixedThreadPool(threadCount);
        cs = new ExecutorCompletionService<Long>(exec);
        nTasks = 0;
        totalTasks = 0;
    }

    /**
     * Collects the results of all open tasks and shuts the executor down.
     *
     * The loop is driven by the open-task counter instead of a fixed
     * 2^levelParallel, so it drains exactly what was submitted and cannot block
     * on a result that will never arrive.
     *
     * @return the sum of all task results
     * @throws IllegalStateException if a task failed or the wait was interrupted
     */
    long postProcess()
    {
        // Nothing is submitted after the traversal; tasks already queued still run.
        exec.shutdown();
        long result = 0;
        try
        {
            while (nTasks > 0)
            {
                Future<Long> finished = cs.take();
                decrementTasks();
                result += finished.get();
            }
        }
        catch (InterruptedException e)
        {
            exec.shutdownNow();
            // Keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
            throw new IllegalStateException(
                "Interrupted while " + nTasks + " task(s) were still open", e);
        }
        catch (ExecutionException e)
        {
            exec.shutdownNow();
            // A partial sum would be silently wrong, so fail loudly instead.
            throw new IllegalStateException("A subtree task failed", e.getCause());
        }
        return result;
    }

    /**
     * Number of tasks the parallel traversal is expected to submit for the
     * configured height and parallel level.
     */
    long expectedTasks()
    {
        if (levelParallel > height)
        {
            // The leaves are reached before the parallel level: nothing is submitted.
            return 0;
        }
        if (levelParallel <= 0)
        {
            // The root itself is the only task.
            return 1;
        }
        return 1L << levelParallel;
    }

    /**
     * Runs the comparison. Expects {@code treeHeight levelParallel nthreads};
     * with any other argument count the usage line is printed and the default
     * values are used.
     */
    public static void main(String[] args)
    {
        int height = 22;
        int levelParallel = 3;
        int threadCount = 4;
        if (args.length != 3)
        {
            System.out.println(
                "usage: java Tree treeHeight levelParallel nthreads\n");
            System.out.println("Using default values for test");
        }
        else
        {
            height = Integer.parseInt(args[0]);
            levelParallel = Integer.parseInt(args[1]);
            threadCount = Integer.parseInt(args[2]);
        }
        TreeCalculation tc = new TreeCalculation(height, levelParallel);
        // generate balanced binary tree
        Tree t = Tree.genTree(height, height);
        // traverse sequential
        long t0 = System.nanoTime();
        long p1 = t.processTree();
        double t1 = (System.nanoTime() - t0) * 1e-9;
        // traverse parallel (timing includes pool creation and result collection)
        t0 = System.nanoTime();
        tc.preProcess(threadCount);
        long p2 = t.processTreeParallel(tc);
        p2 += tc.postProcess();
        double t2 = (System.nanoTime() - t0) * 1e-9;
        // nodes are numbered 1..counter, so the reference is the Gauss sum
        long ref = (Tree.counter * (Tree.counter + 1)) / 2;
        if (p1 != ref)
        {
            System.out.printf("ERROR: sum %d != reference %d\n", p1, ref);
        }
        if (p1 != p2)
        {
            System.out.printf("ERROR: sum %d != parallel %d\n", p1, p2);
        }
        long expected = tc.expectedTasks();
        if (tc.totalTasks != expected)
        {
            System.out.printf("ERROR: ntasks %d != %d\n", expected,
                tc.totalTasks);
        }
        // print timing
        System.out.printf("tree height: %2d\n"
            + "sequential: %.6f\n"
            + "parallel with %3d threads and %6d tasks: %.6f\n"
            + "speedup: %.3f count: %d\n",
            height, t1, threadCount, tc.totalTasks, t2, t1 / t2, ref);
    }
}

// ============================================================================
/**
 * Node of a balanced binary tree. Node values are consecutive numbers taken
 * from the static {@code counter}; {@code level} is 0 for the root.
 */
class Tree
{
    static long counter; // counter for consecutive node numbering
    int level; // node level
    long value; // node value
    Tree left; // left child
    Tree right; // right child

    // constructor
    Tree(long value)
    {
        this.value = value;
    }

    /**
     * Generates a balanced binary tree.
     *
     * @param k      number of levels still to create below the new node;
     *               a negative value yields {@code null}
     * @param height total height, used to derive each node's level
     * @return the root of the generated (sub)tree, or {@code null}
     */
    static Tree genTree(int k, int height)
    {
        if (k < 0)
        {
            return null;
        }
        Tree t = new Tree(++counter);
        t.level = height - k;
        t.left = genTree(k - 1, height);
        t.right = genTree(k - 1, height);
        return t;
    }

    // ========================================================================
    /** Traverses the tree sequentially and returns the sum of all node values. */
    long processTree()
    {
        return value
            + ((left == null) ? 0 : left.processTree())
            + ((right == null) ? 0 : right.processTree());
    }

    // ========================================================================
    /**
     * Traverses the tree in parallel. Nodes above {@code tc.levelParallel} are
     * summed on the calling thread. For each node at that level one task is
     * submitted that sums the node's subtree via {@link #processTree()}.
     *
     * @param tc shared state: parallel level, task counters, completion service
     * @return the sum of the sequentially processed nodes; the task results
     *         are obtained from {@code tc.postProcess()}
     */
    long processTreeParallel(TreeCalculation tc)
    {
        if (level < tc.levelParallel)
        {
            // A missing child counts as 0, so a parallel level deeper than the
            // tree degrades to a sequential sum instead of a NullPointerException.
            long leftResult = (left == null) ? 0 : left.processTreeParallel(tc);
            long rightResult = (right == null) ? 0 : right.processTreeParallel(tc);
            return value + leftResult + rightResult;
        }
        tc.incrementTasks();
        tc.cs.submit(new Callable<Long>()
        {
            @Override
            public Long call()
            {
                return processTree();
            }
        });
        return 0;
    }
}
推荐阅读
- bash - 大文件中的重复数据删除行因 sort 和 uniq 而失败
- java - 如何在eclipse中使用java从不同的包中访问文件?
- nlog - 发送到非异步目标的日志事件会立即被记录吗?
- f# - 如何使用 F# 制作一个空的 catch 块?
- node.js - 无法让谷歌 STT API 在嵌入式机器上工作 (AArch64)
- javascript - 更改变量值时数据属性不会更改
- java - 每小时用新数据更新一次 JLabel
- logging - 是否可以在 .NET Core ConsoleLogger 和 DebugLogger 中禁用类别输出?
- file - 如何更改分配给 IPA 文件的 Apple ID?
- python - 如果在 html 中检查变量的 else 条件为空