首页 > 解决方案 > 设计考虑 - 在一定数量的失败后关闭 ForkJoinPool

问题描述

我有一个 entities 列表。我正在对此列表执行并行 forEach 操作,并对每个 entity 执行某些操作。我在 ForkJoinPool 内运行并行 forEach,以便实现所需的并行度。

我现有代码的大纲如下:

// Dedicated pool: the parallel stream below runs on these 4 workers.
ForkJoinPool pool = new ForkJoinPool(4);

// Per-entity work; a failure of doSomething() is ignored in this outline.
Consumer<Entity> consumer = entity -> {
    try {
        doSomething(entity);
    } catch (Exception ignored) {
        // nothing is done with the failure here
    }
};

try {
    // Submit the whole pipeline as one task and block until it has finished.
    Runnable job = () -> entities.stream().parallel().forEach(consumer);
    pool.submit(job).get();
} finally {
    pool.shutdown();
}

由于 doSomething() 方法可能因某种原因抛出异常(例如网络连接失败),如果连续错误的数量达到错误阈值,我想停止并行处理。所以我想到了以下大纲:

// Number of consecutive failures after which the pool is shut down.
int errorThreshold = 200;
// Running count of consecutive failures; reset to 0 by every success.
AtomicInteger errorCount = new AtomicInteger(0);
ForkJoinPool pool = new ForkJoinPool(parallelism);

Consumer<Entity> consumer = entity -> {
    try {
        doSomething(entity);
    } catch (Exception ignored) {
        // Failure path: count it and shut the pool down when the threshold is hit.
        if (errorCount.incrementAndGet() == errorThreshold) {
            pool.shutdownNow();
        }
        return;
    }

    // Success path: the streak of failures is broken.
    errorCount.set(0);
};

try {
    // Submit the whole pipeline as one task and block until it has finished.
    Runnable job = () -> entities.stream().parallel().forEach(consumer);
    pool.submit(job).get();
} finally {
    pool.shutdown();
}

这是实现我想要的最佳方式吗?

PS:我用的是jdk8。

更新

示例代码:

import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Sample program: runs a parallel {@code forEach} over 1000 entities inside a dedicated
 * {@link ForkJoinPool} and calls {@code shutdownNow()} on that pool once the number of
 * consecutive failures reaches a threshold.
 */
public class Main {

    /** Work item identified by an integer id. */
    private static final class Entity {
        private final int id;

        public Entity(int id) {
            this.id = id;
        }
    }

    /**
     * Builds entities with ids 0..999 and processes them in parallel on a pool of 4 workers.
     * A cancellation of the submitted task is reported on standard output; any other failure
     * is propagated to the caller.
     *
     * @throws Exception any failure of the submitted task other than a cancellation
     */
    private void execute() throws Exception {
        List<Entity> entities = IntStream.range(0, 1000)
                .mapToObj(Entity::new)
                .collect(Collectors.toList());
        int errorThreshold = 5;
        AtomicInteger errorCount = new AtomicInteger(0);
        ForkJoinPool pool = new ForkJoinPool(4);
        Consumer<Entity> consumer = (Entity entity) -> {
            boolean success = false;

            try {
                doSomething(entity);
                // Mark the success; otherwise every entity would be counted as a failure
                // and the consecutive-error counter would never be reset.
                success = true;
            } catch (Exception cause) {
                System.err.println(cause.getMessage());
            }

            if (!success) {
                // incrementAndGet() returns each value exactly once, so shutdownNow()
                // is triggered by a single thread when the threshold is reached.
                if (errorCount.incrementAndGet() == errorThreshold) {
                    pool.shutdownNow();
                }
            } else {
                errorCount.set(0);
            }
        };

        try {
            pool.submit(() -> entities
                    .stream()
                    .parallel()
                    .forEach(consumer))
                    .get();
        } catch (CancellationException cancelled) {
            // Only a cancellation is handled here; every other exception propagates.
            System.out.println("ForkJoinPool stopped due to consecutive error");
        } finally {
            if (!pool.isTerminated()) {
                pool.shutdown();
            }
        }
    }

    /**
     * Simulated unit of work: fails for prime ids and prints the id otherwise.
     *
     * @param entity the entity to process
     * @throws RuntimeException if the entity's id is prime
     */
    private void doSomething(Entity entity) {
        if (isPrime(entity.id)) {
            throw new RuntimeException("Exception occurred for ID: " + entity.id);
        }

        System.out.println(entity.id);
    }

    /**
     * Primality test by trial division with odd divisors up to {@code ceil(sqrt(n))}.
     *
     * @param n the number to test
     * @return {@code true} if {@code n} is prime; {@code false} otherwise, including for
     *         every value below 2
     */
    private boolean isPrime(int n) {
        if (n == 2) {
            return true;
        }

        // Values below 2 (including negatives) and even numbers are never prime.
        if (n < 2 || n % 2 == 0) {
            return false;
        }

        int limit = (int) Math.ceil(Math.sqrt(n));

        for (int i = 3; i <= limit; i += 2) {
            if (n % i == 0) {
                return false;
            }
        }

        return true;
    }

    public static void main(String[] args) throws Exception {
        new Main().execute();
    }
}

标签: java error-handling concurrency java-stream forkjoinpool

解决方案


这是我提出的解决方案:

import java.util.List;
import java.util.Objects;
import java.util.Spliterator;
import java.util.Spliterators.AbstractSpliterator;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Sample program: processes 1000 entities through a parallel stream running in a dedicated
 * {@link ForkJoinPool} and stops feeding the stream once a threshold of consecutive
 * failures has been reached.
 */
public class Main {

    /** Work item identified by an integer id. */
    private static final class Entity {
        private final int id;

        public Entity(int id) {
            this.id = id;
        }
    }

    /**
     * Builds entities with ids 0..999, maps each one through {@code doSomething} on a pool
     * of 4 workers and prints a message if processing was stopped because the failure
     * threshold was reached.
     *
     * @throws Exception if the submitted task fails or the wait for pool termination is
     *                   interrupted
     */
    private void execute() throws Exception {
        List<Entity> entities = IntStream.range(0, 1000)
                .mapToObj(Entity::new)
                .collect(Collectors.toList());
        int errorThreshold = 40;
        // AtomicInteger rather than LongAdder: the value returned by incrementAndGet()
        // is exact, so the threshold cannot be stepped over by a concurrent increment.
        AtomicInteger consecutiveErrors = new AtomicInteger();
        ForkJoinPool pool = new ForkJoinPool(4);
        // Returns a non-null "stopped" message only for the failure that reaches the threshold.
        Function<Entity, String> mapper = (Entity entity) -> {
            try {
                doSomething(entity);
            } catch (Exception cause) {
                System.err.println(cause.getMessage());

                int failures = consecutiveErrors.incrementAndGet();
                if (failures == errorThreshold) {
                    return String.format("Processing stopped due to %1$d consecutive error", failures);
                }
                return null;
            }

            // A success breaks the streak of failures.
            consecutiveErrors.set(0);
            return null;
        };

        try {
            Spliterator<Entity> originalSpliterator = entities.spliterator();
            int estimatedSplitSize = errorThreshold / pool.getParallelism();
            // No characteristics are advertised (0): the size passed in is only an estimate,
            // so SIZED/SUBSIZED must not be reported.
            Spliterator<Entity> stoppableSpliterator = new AbstractSpliterator<Entity>(estimatedSplitSize, 0) {

                @Override
                public boolean tryAdvance(Consumer<? super Entity> action) {
                    // ">=" so the source stays closed even if the counter has moved past
                    // the threshold by the time it is read here.
                    if (consecutiveErrors.get() >= errorThreshold) {
                        return false;
                    }
                    return originalSpliterator.tryAdvance(action);
                }
            };
            Stream<Entity> stream = StreamSupport.stream(stoppableSpliterator, true);
            //@formatter:off
            String message = pool.submit(
                        () -> stream.map(mapper)
                                .filter(Objects::nonNull)
                                .findAny()
                                .orElse(null)
                    )
                    .get();
            //@formatter:on
            if (Objects.nonNull(message)) {
                System.out.println(message);
            }
        } finally {
            pool.shutdown();
            // Block until the pool has terminated instead of spinning on isTerminated().
            while (!pool.awaitTermination(1, TimeUnit.SECONDS)) {
                // keep waiting
            }
        }
    }

    /**
     * Simulated unit of work: fails for ids that {@code isInvalid} rejects.
     *
     * @param entity the entity to process
     * @throws RuntimeException if the entity's id is invalid
     */
    private void doSomething(Entity entity) {
        if (isInvalid(entity.id)) {
            throw new RuntimeException("Exception occurred for ID: " + entity.id);
        }
    }

    /**
     * @param n the id to check
     * @return {@code true} if {@code n} is greater than 100
     */
    private boolean isInvalid(int n) {
        return n > 100;
    }

    public static void main(String[] args) throws Exception {
        new Main().execute();
    }
}

我观察到几件事:

  • 如果我调用 pool.shutdownNow(),Stream 操作不会立即停止
  • 如果我没有实现估计大小为 errorThreshold / parallelism 的自定义 Spliterator,即使我调用了终端操作 findAny(),流仍然会继续处理一段时间,然后才停止

推荐阅读