diff --git a/CHANGES.md b/CHANGES.md index d5f9fdd89e07..7a57db430614 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -58,6 +58,7 @@ * ReadFromMongoDB/WriteToMongoDB will mask password in display_data (Python) ([BEAM-11444](https://issues.apache.org/jira/browse/BEAM-11444).) * Support for X source added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)). * There is a new transform `ReadAllFromBigQuery` that can receive multiple requests to read data from BigQuery at pipeline runtime. See [PR 13170](https://github.com/apache/beam/pull/13170), and [BEAM-9650](https://issues.apache.org/jira/browse/BEAM-9650). +* ParquetIO can now read files with an unknown schema. See [PR-13554](https://github.com/apache/beam/pull/13554) and ([BEAM-11460](https://issues.apache.org/jira/browse/BEAM-11460)) ## New Features / Improvements diff --git a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java index 5dd9a11d76d7..74099cff3875 100644 --- a/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java +++ b/sdks/java/io/parquet/src/main/java/org/apache/beam/sdk/io/parquet/ParquetIO.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.parquet; import static java.lang.String.format; +import static org.apache.parquet.Preconditions.checkArgument; import static org.apache.parquet.Preconditions.checkNotNull; import static org.apache.parquet.hadoop.ParquetFileWriter.Mode.OVERWRITE; @@ -38,21 +39,30 @@ import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; import org.apache.beam.sdk.coders.AvroCoder; +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.io.FileIO.ReadableFile; import org.apache.beam.sdk.io.fs.ResourceId; import org.apache.beam.sdk.io.hadoop.SerializableConfiguration; +import org.apache.beam.sdk.io.parquet.ParquetIO.ReadFiles.ReadFn; +import org.apache.beam.sdk.io.parquet.ParquetIO.ReadFiles.SplitReadFn; import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.options.ValueProvider; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.hadoop.conf.Configuration; @@ -152,6 +162,44 @@ * * * } * + *

+ * <h3>Reading records of an unknown schema</h3>
+ *
+ * <p>To read records from files whose schema is unknown at pipeline construction time or differs
+ * between files, use {@link #parseGenericRecords(SerializableFunction)} - in this case, you will
+ * need to specify a parsing function for converting each {@link GenericRecord} into a value of your
+ * custom type.
+ *
+ * <p>For example:
+ *
+ * <pre>{@code
+ * Pipeline p = ...;
+ *
+ * PCollection<Foo> records =
+ *     p.apply(
+ *       ParquetIO.parseGenericRecords(
+ *           new SerializableFunction<GenericRecord, Foo>() {
+ *               public Foo apply(GenericRecord record) {
+ *                   // If needed, access the schema of the record using record.getSchema()
+ *                   return ...;
+ *               }
+ *           })
+ *           .from(...));
+ *
+ * // For reading from files
+ *  PCollection<FileIO.ReadableFile> files = p.apply(...);
+ *
+ *  PCollection<Foo> records =
+ *     files
+ *       .apply(
+ *           ParquetIO.parseFilesGenericRecords(
+ *               new SerializableFunction<GenericRecord, Foo>() {
+ *                   public Foo apply(GenericRecord record) {
+ *                       // If needed, access the schema of the record using record.getSchema()
+ *                       return ...;
+ *                   }
+ *           }));
+ * }</pre>
+ *
 * <h3>Writing Parquet files</h3>
 *
 * <p>
{@link ParquetIO.Sink} allows you to write a {@link PCollection} of {@link GenericRecord} into @@ -202,7 +250,30 @@ public static Read read(Schema schema) { */ public static ReadFiles readFiles(Schema schema) { return new AutoValue_ParquetIO_ReadFiles.Builder() + .setSplittable(false) .setSchema(schema) + .build(); + } + + /** + * Reads {@link GenericRecord} from a Parquet file (or multiple Parquet files matching the + * pattern) and converts to user defined type using provided parseFn. + */ + public static Parse parseGenericRecords(SerializableFunction parseFn) { + return new AutoValue_ParquetIO_Parse.Builder() + .setParseFn(parseFn) + .setSplittable(false) + .build(); + } + + /** + * Reads {@link GenericRecord} from Parquet files and converts to user defined type using provided + * {@code parseFn}. + */ + public static ParseFiles parseFilesGenericRecords( + SerializableFunction parseFn) { + return new AutoValue_ParquetIO_ParseFiles.Builder() + .setParseFn(parseFn) .setSplittable(false) .build(); } @@ -300,6 +371,121 @@ public void populateDisplayData(DisplayData.Builder builder) { } } + /** Implementation of {@link #parseGenericRecords(SerializableFunction)}. */ + @AutoValue + public abstract static class Parse extends PTransform> { + abstract @Nullable ValueProvider getFilepattern(); + + abstract SerializableFunction getParseFn(); + + abstract boolean isSplittable(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setFilepattern(ValueProvider inputFiles); + + abstract Builder setParseFn(SerializableFunction parseFn); + + abstract Builder setSplittable(boolean splittable); + + abstract Parse build(); + } + + public Parse from(ValueProvider inputFiles) { + return toBuilder().setFilepattern(inputFiles).build(); + } + + public Parse from(String inputFiles) { + return from(ValueProvider.StaticValueProvider.of(inputFiles)); + } + + public Parse withSplit() { + return toBuilder().setSplittable(true).build(); + } + + @Override + public PCollection expand(PBegin input) { + checkNotNull(getFilepattern(), "Filepattern cannot be null."); + return input + .apply("Create filepattern", Create.ofProvider(getFilepattern(), StringUtf8Coder.of())) + .apply(FileIO.matchAll()) + .apply(FileIO.readMatches()) + .apply( + parseFilesGenericRecords(getParseFn()) + .toBuilder() + .setSplittable(isSplittable()) + .build()); + } + } + + /** Implementation of {@link #parseFilesGenericRecords(SerializableFunction)}. */ + @AutoValue + public abstract static class ParseFiles + extends PTransform, PCollection> { + + abstract SerializableFunction getParseFn(); + + abstract boolean isSplittable(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setParseFn(SerializableFunction parseFn); + + abstract Builder setSplittable(boolean split); + + abstract ParseFiles build(); + } + + public ParseFiles withSplit() { + return toBuilder().setSplittable(true).build(); + } + + @Override + public PCollection expand(PCollection input) { + checkArgument(!isGenericRecordOutput(), "Parse can't be used for reading as GenericRecord."); + + PCollection parsedRecords = + isSplittable() + ? input.apply(ParDo.of(new SplitReadFn<>(null, null, getParseFn()))) + : input.apply(ParDo.of(new ReadFn<>(null, getParseFn()))); + + return parsedRecords.setCoder(inferCoder(input.getPipeline().getCoderRegistry())); + } + + /** Returns true if expected output is {@code PCollection}. 
*/ + private boolean isGenericRecordOutput() { + String outputType = TypeDescriptors.outputOf(getParseFn()).getType().getTypeName(); + return outputType.equals(GenericRecord.class.getTypeName()); + } + + /** + * Identifies the {@code Coder} to be used for the output PCollection. + * + *

Returns {@link AvroCoder} if expected output is {@link GenericRecord}. + * + * @param coderRegistry the {@link org.apache.beam.sdk.Pipeline}'s CoderRegistry to identify + * Coder for expected output type of {@link #getParseFn()} + */ + private Coder inferCoder(CoderRegistry coderRegistry) { + if (isGenericRecordOutput()) { + throw new IllegalArgumentException("Parse can't be used for reading as GenericRecord."); + } + + // If not GenericRecord infer it from ParseFn. + try { + return coderRegistry.getCoder(TypeDescriptors.outputOf(getParseFn())); + } catch (CannotProvideCoderException e) { + throw new IllegalArgumentException( + "Unable to infer coder for output of parseFn. Specify it explicitly using withCoder().", + e); + } + } + } + /** Implementation of {@link #readFiles(Schema)}. */ @AutoValue public abstract static class ReadFiles @@ -357,26 +543,35 @@ public PCollection expand(PCollection input) if (isSplittable()) { Schema coderSchema = getProjectionSchema() == null ? getSchema() : getEncoderSchema(); return input - .apply(ParDo.of(new SplitReadFn(getAvroDataModel(), getProjectionSchema()))) + .apply( + ParDo.of( + new SplitReadFn<>( + getAvroDataModel(), + getProjectionSchema(), + GenericRecordPassthroughFn.create()))) .setCoder(AvroCoder.of(coderSchema)); } return input - .apply(ParDo.of(new ReadFn(getAvroDataModel()))) + .apply(ParDo.of(new ReadFn<>(getAvroDataModel(), GenericRecordPassthroughFn.create()))) .setCoder(AvroCoder.of(getSchema())); } @DoFn.BoundedPerElement - static class SplitReadFn extends DoFn { + static class SplitReadFn extends DoFn { private Class modelClass; private static final Logger LOG = LoggerFactory.getLogger(SplitReadFn.class); private String requestSchemaString; // Default initial splitting the file into blocks of 64MB. Unit of SPLIT_LIMIT is byte. private static final long SPLIT_LIMIT = 64000000; - SplitReadFn(GenericData model, Schema requestSchema) { + private final SerializableFunction parseFn; + + SplitReadFn( + GenericData model, Schema requestSchema, SerializableFunction parseFn) { this.modelClass = model != null ? model.getClass() : null; this.requestSchemaString = requestSchema != null ? requestSchema.toString() : null; + this.parseFn = checkNotNull(parseFn, "GenericRecord parse function can't be null"); } ParquetFileReader getParquetFileReader(FileIO.ReadableFile file) throws Exception { @@ -388,7 +583,7 @@ ParquetFileReader getParquetFileReader(FileIO.ReadableFile file) throws Exceptio public void processElement( @Element FileIO.ReadableFile file, RestrictionTracker tracker, - OutputReceiver outputReceiver) + OutputReceiver outputReceiver) throws Exception { LOG.debug( "start " @@ -468,7 +663,7 @@ record = recordReader.read(); file.toString()); continue; } - outputReceiver.output(record); + outputReceiver.output(parseFn.apply(record)); } catch (RuntimeException e) { throw new ParquetDecodingException( @@ -618,12 +813,15 @@ public Progress getProgress() { } } - static class ReadFn extends DoFn { + static class ReadFn extends DoFn { private Class modelClass; - ReadFn(GenericData model) { + private final SerializableFunction parseFn; + + ReadFn(GenericData model, SerializableFunction parseFn) { this.modelClass = model != null ? 
model.getClass() : null; + this.parseFn = checkNotNull(parseFn, "GenericRecord parse function is null"); } @ProcessElement @@ -647,7 +845,7 @@ public void processElement(ProcessContext processContext) throws Exception { try (ParquetReader reader = builder.build()) { GenericRecord read; while ((read = reader.read()) != null) { - processContext.output(read); + processContext.output(parseFn.apply(read)); } } } @@ -838,6 +1036,25 @@ public void close() throws IOException { } } + /** + * Passthrough function to provide seamless backward compatibility to ParquetIO's functionality. + */ + @VisibleForTesting + static class GenericRecordPassthroughFn + implements SerializableFunction { + + private static final GenericRecordPassthroughFn singleton = new GenericRecordPassthroughFn(); + + static GenericRecordPassthroughFn create() { + return singleton; + } + + @Override + public GenericRecord apply(GenericRecord input) { + return input; + } + } + /** Disallow construction of utility class. */ private ParquetIO() {} } diff --git a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java index 684ff5f8be10..76831d13425d 100644 --- a/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java +++ b/sdks/java/io/parquet/src/test/java/org/apache/beam/sdk/io/parquet/ParquetIOTest.java @@ -17,26 +17,35 @@ */ package org.apache.beam.sdk.io.parquet; +import static java.util.stream.Collectors.toList; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.List; import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.io.EncoderFactory; +import org.apache.avro.io.JsonEncoder; import org.apache.avro.reflect.ReflectData; import org.apache.beam.sdk.coders.AvroCoder; import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.io.parquet.ParquetIO.GenericRecordPassthroughFn; import org.apache.beam.sdk.io.range.OffsetRange; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.Values; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PCollection; @@ -136,8 +145,10 @@ public void testBlockTracker() throws Exception { @Test public void testSplitBlockWithLimit() { - ParquetIO.ReadFiles.SplitReadFn testFn = new ParquetIO.ReadFiles.SplitReadFn(null, null); - ArrayList blockList = new ArrayList(); + ParquetIO.ReadFiles.SplitReadFn testFn = + new ParquetIO.ReadFiles.SplitReadFn<>( + null, null, ParquetIO.GenericRecordPassthroughFn.create()); + ArrayList blockList = new ArrayList<>(); ArrayList rangeList; BlockMetaData testBlock = mock(BlockMetaData.class); when(testBlock.getTotalByteSize()).thenReturn((long) 60); @@ -194,6 +205,28 @@ public void testWriteAndReadWithSplit() { 
readPipeline.run().waitUntilFinish(); } + @Test + public void testWriteAndReadFilesAsJsonForWithSplitForUnknownSchema() { + List records = generateGenericRecords(1000); + + mainPipeline + .apply(Create.of(records).withCoder(AvroCoder.of(SCHEMA))) + .apply( + FileIO.write() + .via(ParquetIO.sink(SCHEMA)) + .to(temporaryFolder.getRoot().getAbsolutePath())); + mainPipeline.run().waitUntilFinish(); + + PCollection readBackAsJsonWithSplit = + readPipeline.apply( + ParquetIO.parseGenericRecords(ParseGenericRecordAsJsonFn.create()) + .from(temporaryFolder.getRoot().getAbsolutePath() + "/*") + .withSplit()); + + PAssert.that(readBackAsJsonWithSplit).containsInAnyOrder(convertRecordsToJson(records)); + readPipeline.run().waitUntilFinish(); + } + @Test public void testWriteAndReadFiles() { List records = generateGenericRecords(1000); @@ -216,6 +249,43 @@ public void testWriteAndReadFiles() { mainPipeline.run().waitUntilFinish(); } + @Test + public void testReadFilesAsJsonForUnknownSchemaFiles() { + List records = generateGenericRecords(1000); + List expectedJsonRecords = convertRecordsToJson(records); + + PCollection writeThenRead = + mainPipeline + .apply(Create.of(records).withCoder(AvroCoder.of(SCHEMA))) + .apply( + FileIO.write() + .via(ParquetIO.sink(SCHEMA)) + .to(temporaryFolder.getRoot().getAbsolutePath())) + .getPerDestinationOutputFilenames() + .apply(Values.create()) + .apply(FileIO.matchAll()) + .apply(FileIO.readMatches()) + .apply(ParquetIO.parseFilesGenericRecords(ParseGenericRecordAsJsonFn.create())); + + assertEquals(1000, expectedJsonRecords.size()); + PAssert.that(writeThenRead).containsInAnyOrder(expectedJsonRecords); + + mainPipeline.run().waitUntilFinish(); + } + + @Test + public void testReadFilesUnknownSchemaFilesForGenericRecordThrowException() { + IllegalArgumentException illegalArgumentException = + assertThrows( + IllegalArgumentException.class, + () -> + ParquetIO.parseFilesGenericRecords(GenericRecordPassthroughFn.create()) + .expand(null)); + + assertEquals( + "Parse can't be used for reading as GenericRecord.", illegalArgumentException.getMessage()); + } + private List generateGenericRecords(long count) { ArrayList data = new ArrayList<>(); GenericRecordBuilder builder = new GenericRecordBuilder(SCHEMA); @@ -345,4 +415,32 @@ public void testWriteAndReadwithSplitUsingReflectDataSchemaWithDataModel() { PAssert.that(readBack).containsInAnyOrder(records); readPipeline.run().waitUntilFinish(); } + + /** Returns list of JSON representation of GenericRecords. */ + private static List convertRecordsToJson(List records) { + return records.stream().map(ParseGenericRecordAsJsonFn.create()::apply).collect(toList()); + } + + /** Sample Parse function that converts GenericRecord as JSON. for testing. */ + private static class ParseGenericRecordAsJsonFn + implements SerializableFunction { + + public static ParseGenericRecordAsJsonFn create() { + return new ParseGenericRecordAsJsonFn(); + } + + @Override + public String apply(GenericRecord input) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + try { + JsonEncoder jsonEncoder = EncoderFactory.get().jsonEncoder(input.getSchema(), baos, true); + new GenericDatumWriter(input.getSchema()).write(input, jsonEncoder); + jsonEncoder.flush(); + } catch (IOException ioException) { + throw new RuntimeException("error converting record to JSON", ioException); + } + return baos.toString(); + } + } }
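
Not part of the patch above: a minimal usage sketch of the new parse API, under stated assumptions. The class name and file pattern are hypothetical, and the parse function simply converts each record to its String representation so that the output coder can be inferred from the CoderRegistry (StringUtf8Coder), mirroring the approach used in ParquetIOTest:

import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.parquet.ParquetIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;

public class ParseUnknownSchemaExample {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    PCollection<String> records =
        p.apply(
            ParquetIO.parseGenericRecords(
                    new SerializableFunction<GenericRecord, String>() {
                      @Override
                      public String apply(GenericRecord record) {
                        // The record's schema is available at runtime via record.getSchema();
                        // here each record is emitted as its String (JSON-like) representation.
                        return record.toString();
                      }
                    })
                .from("/path/to/input/*.parquet") // hypothetical file pattern
                .withSplit()); // optional: read each file with the splittable DoFn

    p.run().waitUntilFinish();
  }
}

Because the parse function's output type is String, the coder inference in ParseFiles resolves it from the pipeline's CoderRegistry; a custom output type would need a coder that the registry can provide.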