Beam Java SDK with TFRecord and GZIP compression

Problem description

We regularly use the Beam Java SDK (with Google Cloud Dataflow to run batch jobs), and we noticed something odd (possibly a bug?) when we tried using TFRecordIO with Compression.GZIP. We were able to put together some sample code that reproduces the error we are facing.

To be clear, we are using Beam Java SDK 2.4.

Say we have a PCollection<byte[]>, where each element is a proto message in byte[] form. We typically write it to GCS (Google Cloud Storage) either Base64-encoded (as newline-delimited strings) or via TFRecordIO without compression. We have had no problem reading data back from GCS this way for a very long time (2.5+ years for the former, about 1.5 years for the latter).
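
For context, here is a condensed sketch of the two write paths described above (the bucket path is a placeholder; the complete pipeline, with all imports, is in the code at the end of this post):

PCollection<byte[]> data = ...; // each element is a proto message in byte[] form

// (a) Base64-encode each message and write as newline-delimited text.
data.apply(ParDo.of(new DoFn<byte[], String>() {
  @ProcessElement
  public void processElement(ProcessContext c) {
    c.output(new String(Base64.encodeBase64(c.element())));
  }
})).apply(TextIO.write().to("gs://some-bucket/plain-base64"));

// (b) Write the raw byte[] directly as uncompressed TFRecord.
data.apply(TFRecordIO.write().to("gs://some-bucket/tfrecord-uncompressed"));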

Recently, we tried TFRecordIO with the Compression.GZIP option, and we sometimes get an exception saying the data is invalid while it is being read. The data itself (the gzip files) is not corrupted; we tested various things and reached the following conclusion.

When a byte[] compressed via TFRecordIO exceeds a certain threshold (I would say 8192 or above), TFRecordIO.read().withCompression(Compression.GZIP) does not work. Specifically, it throws the following exception:

Exception in thread "main" java.lang.IllegalStateException: Invalid data
    at org.apache.beam.sdk.repackaged.com.google.common.base.Preconditions.checkState(Preconditions.java:444)
    at org.apache.beam.sdk.io.TFRecordIO$TFRecordCodec.read(TFRecordIO.java:642)
    at org.apache.beam.sdk.io.TFRecordIO$TFRecordSource$TFRecordReader.readNextRecord(TFRecordIO.java:526)
    at org.apache.beam.sdk.io.CompressedSource$CompressedReader.readNextRecord(CompressedSource.java:426)
    at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.advanceImpl(FileBasedSource.java:473)
    at org.apache.beam.sdk.io.FileBasedSource$FileBasedReader.startImpl(FileBasedSource.java:468)
    at org.apache.beam.sdk.io.OffsetBasedSource$OffsetBasedReader.start(OffsetBasedSource.java:261)
    at org.apache.beam.runners.direct.BoundedReadEvaluatorFactory$BoundedReadEvaluator.processElement(BoundedReadEvaluatorFactory.java:141)
    at org.apache.beam.runners.direct.DirectTransformExecutor.processElements(DirectTransformExecutor.java:161)
    at org.apache.beam.runners.direct.DirectTransformExecutor.run(DirectTransformExecutor.java:125)
    at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
    at java.util.concurrent.FutureTask.run(FutureTask.java:266)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
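
For reference, the read that triggers this trace is nothing more than the following (the path is a placeholder; CountDoFn and the full pipeline are in the code at the end of this post):

pipeline
    .apply(TFRecordIO.read()
        .withCompression(Compression.GZIP)
        .from("gs://some-bucket/tfrecord-gzip*"))
    .apply(ParDo.of(new CountDoFn("tfrecord_gz")));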

This is easy to reproduce, so you can refer to the code at the end. You will also see comments about the length of the byte array (from testing with various sizes, I concluded that 8192 is the magic number).

So I am wondering whether this is a bug or a known issue. I could not find anything related to it on Apache Beam's issue tracker; if there is another forum/site I should check, please let me know! If this really is a bug, what is the right channel for reporting it?


The following code reproduces the error we are running into.

A successful run (with parameters 1, 39, 100) prints the following message at the end:

------------ counter metrics from CountDoFn
[counter]             plain_base64_proto_array_len: 8126
[counter]                    plain_base64_proto_in:   1
[counter]               plain_base64_proto_val_cnt:  39
[counter]              tfrecord_gz_proto_array_len: 8126
[counter]                     tfrecord_gz_proto_in:   1
[counter]                tfrecord_gz_proto_val_cnt:  39
[counter]          tfrecord_uncomp_proto_array_len: 8126
[counter]                 tfrecord_uncomp_proto_in:   1
[counter]            tfrecord_uncomp_proto_val_cnt:  39

If the parameters (1, 40, 100) push the byte array length above 8192, it throws the exception shown above.
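
As a rough back-of-the-envelope estimate of why these two parameter sets straddle the threshold (assuming standard proto3 wire-format overheads, and noting that getRandomString() in the code below actually produces about 2 * STRING_LEN characters per string; see the comment on it):

    each 'values' entry : 1 tag byte + 2 length bytes + ~200 chars ≈ 203 bytes
    name field          : ≈ 203 bytes
    count field         : ≈ 6 bytes
    with 39 values      : 39 * 203 + 203 + 6 = 8126 bytes  (below 8192, read succeeds)
    with 40 values      : 40 * 203 + 203 + 6 = 8329 bytes  (above 8192, read fails)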

You can adjust the parameters (inside the CreateRandomProtoData DoFn) to see why the length of the compressed byte[] matters. It may also help to have the protoc-generated Java class (for TestProto) used in the main code below. Here it is: gist link

Main code:

package exp.moloco.dataflow2.compression; // NOTE: Change appropriately.

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.TreeMap;

import org.apache.beam.runners.direct.DirectRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.MetricResult;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.metrics.MetricsFilter;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.commons.codec.binary.Base64;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.protobuf.InvalidProtocolBufferException;

import com.moloco.dataflow.test.StackOverflow.TestProto;
import com.moloco.dataflow2.Main;

// @formatter:off
// This code uses TestProto (java class) that is generated by protoc.
// The message definition is as follows (in proto3, but it shouldn't matter):
// message TestProto {
//   int64 count = 1;
//   string name = 2;
//   repeated string values = 3;
// }
// Note that this code does not depend on whether this proto is used,
// or any other byte[] is used (see the CreateRandomProtoData DoFn below, which generates the data used in this code).
// We tested both, but are presenting this as a concrete example of how (our) code in production can be affected.
// @formatter:on

public class CompressionTester {
  private static final Logger LOG = LoggerFactory.getLogger(CompressionTester.class);

  static final List<String> lines = Arrays.asList("some dummy string that will not be used in this job.");

  // Some GCS buckets where data will be written to.
  // %s will be replaced by some timestamped String for easy debugging.
  static final String PATH_TO_GCS_PLAIN_BASE64 = Main.SOME_BUCKET + "/comp-test/%s/output-plain-base64";
  static final String PATH_TO_GCS_TFRECORD_UNCOMP = Main.SOME_BUCKET + "/comp-test/%s/output-tfrecord-uncompressed";
  static final String PATH_TO_GCS_TFRECORD_GZ = Main.SOME_BUCKET + "/comp-test/%s/output-tfrecord-gzip";

  // This DoFn reads byte[] which represents a proto message (TestProto).
  // It simply counts the number of proto objects it processes
  // as well as the number of Strings each proto object contains.
  // When the pipeline terminates, the values of the Counters will be printed out.
  static class CountDoFn extends DoFn<byte[], TestProto> {

    private final Counter protoIn;
    private final Counter protoValuesCnt;
    private final Counter protoByteArrayLength;

    public CountDoFn(String name) {
      protoIn = Metrics.counter(this.getClass(), name + "_proto_in");
      protoValuesCnt = Metrics.counter(this.getClass(), name + "_proto_val_cnt");
      protoByteArrayLength = Metrics.counter(this.getClass(), name + "_proto_array_len");
    }

    @ProcessElement
    public void processElement(ProcessContext c) throws InvalidProtocolBufferException {
      protoIn.inc();
      TestProto tp = TestProto.parseFrom(c.element());
      protoValuesCnt.inc(tp.getValuesCount());
      protoByteArrayLength.inc(c.element().length);
    }
  }

  // This DoFn emits a number of TestProto objects as byte[].
  // Input to this DoFn is ignored (not used).
  // Each TestProto object contains three fields: count (int64), name (string), and values (repeated string).
  // The three parameters in this DoFn determine
  // (1) the number of proto objects to be generated,
  // (2) the number of (repeated) strings to be added to each proto object, and
  // (3) the length of (each) string.
  // TFRecord with Compression (when reading) fails when the parameters are 1, 40, 100, for instance.
  // TFRecord with Compression (when reading) succeeds when the parameters are 1, 39, 100, for instance.
  static class CreateRandomProtoData extends DoFn<String, byte[]> {

    static final int NUM_PROTOS = 1; // Total number of TestProto objects to be emitted by this DoFn.
    static final int NUM_STRINGS = 40; // Total number of strings in each TestProto object ('repeated string').
    static final int STRING_LEN = 100; // Length of each string object.

    // Returns a pseudo-random string built from len append() calls.
    // Note: 'A' + rd.nextInt(26) is an int, so StringBuffer.append(int) appends its decimal
    // digits (a value between 65 and 90, i.e. two characters per iteration) rather than a letter.
    // The resulting string is therefore roughly 2 * len characters long, which is why the
    // byte[] sizes reported above land just around the 8192 threshold.
    static String getRandomString(Random rd, int len) {
      StringBuffer sb = new StringBuffer();
      for (int i = 0; i < len; i++) {
        sb.append('A' + (rd.nextInt(26)));
      }
      return sb.toString();
    }

    // Returns a randomly generated TestProto object.
    // Each string is generated randomly using getRandomString().
    static TestProto getRandomProto(Random rd) {
      TestProto.Builder tpBuilder = TestProto.newBuilder();

      tpBuilder.setCount(rd.nextInt());
      tpBuilder.setName(getRandomString(rd, STRING_LEN));
      for (int i = 0; i < NUM_STRINGS; i++) {
        tpBuilder.addValues(getRandomString(rd, STRING_LEN));
      }

      return tpBuilder.build();
    }

    // Emits TestProto objects as byte[].
    @ProcessElement
    public void processElement(ProcessContext c) {
      // For debugging purposes, we set the seed here.
      Random rd = new Random();
      rd.setSeed(132475);

      for (int n = 0; n < NUM_PROTOS; n++) {
        byte[] data = getRandomProto(rd).toByteArray();
        c.output(data);
        // With parameters (1, 39, 100), the array length is 8126. It works fine.
        // With parameters (1, 40, 100), the array length is 8329. It breaks TFRecord with GZIP.
        System.out.println("\n--------------------------\n");
        System.out.println("byte array length = " + data.length);
        System.out.println("\n--------------------------\n");
      }
    }
  }

  public static void execute() {
    PipelineOptions options = PipelineOptionsFactory.create();
    options.setJobName("compression-tester");
    options.setRunner(DirectRunner.class);

    // For debugging purposes, write files under 'gcsSubDir' so we can easily distinguish.
    final String gcsSubDir =
        String.format("%s-%d", DateTime.now(DateTimeZone.UTC), DateTime.now(DateTimeZone.UTC).getMillis());

    // Write PCollection<TestProto> in 3 different ways to GCS.
    {
      Pipeline pipeline = Pipeline.create(options);

      // Create dummy data which is a PCollection of byte arrays (each array representing a proto message).
      PCollection<byte[]> data = pipeline.apply(Create.of(lines)).apply(ParDo.of(new CreateRandomProtoData()));

      // 1. Write as plain-text with base64 encoding.
      data.apply(ParDo.of(new DoFn<byte[], String>() {
        @ProcessElement
        public void processElement(ProcessContext c) {
          c.output(new String(Base64.encodeBase64(c.element())));
        }
      })).apply(TextIO.write().to(String.format(PATH_TO_GCS_PLAIN_BASE64, gcsSubDir)).withNumShards(1));

      // 2. Write as TFRecord.
      data.apply(TFRecordIO.write().to(String.format(PATH_TO_GCS_TFRECORD_UNCOMP, gcsSubDir)).withNumShards(1));

      // 3. Write as TFRecord-gzip.
      data.apply(TFRecordIO.write().withCompression(Compression.GZIP)
          .to(String.format(PATH_TO_GCS_TFRECORD_GZ, gcsSubDir)).withNumShards(1));

      pipeline.run().waitUntilFinish();
    }

    LOG.info("-------------------------------------------");
    LOG.info("               READ TEST BEGINS ");
    LOG.info("-------------------------------------------");

    // Read PCollection<TestProto> in 3 different ways from GCS.
    {
      Pipeline pipeline = Pipeline.create(options);

      // 1. Read as plain-text.
      pipeline.apply(TextIO.read().from(String.format(PATH_TO_GCS_PLAIN_BASE64, gcsSubDir) + "*"))
          .apply(ParDo.of(new DoFn<String, byte[]>() {
            @ProcessElement
            public void processElement(ProcessContext c) {
              c.output(Base64.decodeBase64(c.element()));
            }
          })).apply("plain-base64", ParDo.of(new CountDoFn("plain_base64")));

      // 2. Read as TFRecord -> byte array.
      pipeline.apply(TFRecordIO.read().from(String.format(PATH_TO_GCS_TFRECORD_UNCOMP, gcsSubDir) + "*"))
          .apply("tfrecord-uncomp", ParDo.of(new CountDoFn("tfrecord_uncomp")));

      // 3. Read as TFRecord-gz -> byte array.
      // This seems to fail when 'data size' becomes large.
      pipeline
          .apply(TFRecordIO.read().withCompression(Compression.GZIP)
              .from(String.format(PATH_TO_GCS_TFRECORD_GZ, gcsSubDir) + "*"))
          .apply("tfrecord_gz", ParDo.of(new CountDoFn("tfrecord_gz")));

      // 4. Run pipeline.
      PipelineResult res = pipeline.run();
      res.waitUntilFinish();

      // Check CountDoFn's metrics.
      // The numbers should match.
      Map<String, Long> counterValues = new TreeMap<String, Long>();
      for (MetricResult<Long> counter : res.metrics().queryMetrics(MetricsFilter.builder().build()).counters()) {
        counterValues.put(counter.name().name(), counter.committed());
      }
      StringBuffer sb = new StringBuffer();
      sb.append("\n------------ counter metrics from CountDoFn\n");
      for (Entry<String, Long> entry : counterValues.entrySet()) {
        sb.append(String.format("[counter] %40s: %5d\n", entry.getKey(), entry.getValue()));
      }
      LOG.info(sb.toString());
    }
  }
}

Tags: gzip, google-cloud-dataflow, apache-beam, tfrecord

Solution


This looks very much like a bug in TFRecordIO. Channel.read() can read fewer bytes than the capacity of the input buffer, and 8192 appears to be the buffer size of GzipCompressorInputStream. I filed https://issues.apache.org/jira/browse/BEAM-5412
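
To spell out the failure mode: the TFRecord codec apparently assumes that a single Channel.read() call fills its ByteBuffer, while a decompressing channel backed by GzipCompressorInputStream can return at most one internal buffer's worth of data (~8192 bytes) per call, which matches the threshold observed in the question. Below is a minimal sketch of the "read fully" loop a reader needs in order to tolerate such short reads (illustrative names, assuming a blocking channel; this is not the actual Beam patch):

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;

class ReadFullyExample {
  // Keep calling read() until the buffer is full or the channel signals end-of-stream.
  static void readFully(ReadableByteChannel in, ByteBuffer buf) throws IOException {
    while (buf.hasRemaining()) {
      int n = in.read(buf); // may legally return fewer bytes than buf.remaining()
      if (n < 0) {
        throw new EOFException("channel exhausted before the buffer was filled");
      }
    }
  }
}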

