diff --git a/hudi-spark/src/main/scala/org/apache/hudi/AvroConversionUtils.scala b/hudi-spark/src/main/scala/org/apache/hudi/AvroConversionUtils.scala
index 70a135624ffd3..8ac95d78199cd 100644
--- a/hudi-spark/src/main/scala/org/apache/hudi/AvroConversionUtils.scala
+++ b/hudi-spark/src/main/scala/org/apache/hudi/AvroConversionUtils.scala
@@ -17,13 +17,15 @@
 
 package org.apache.hudi
 
+import org.apache.spark.SPARK_VERSION
 import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder, IndexedRecord}
 import org.apache.hudi.common.model.HoodieKey
 import org.apache.avro.Schema
 import org.apache.hudi.avro.HoodieAvroUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.avro.SchemaConverters
-import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 
@@ -41,7 +43,7 @@ object AvroConversionUtils {
     // Use the Avro schema to derive the StructType which has the correct nullability information
     val dataType = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
     val encoder = RowEncoder.apply(dataType).resolveAndBind()
-    df.queryExecution.toRdd.map(encoder.fromRow)
+    df.queryExecution.toRdd.map[Row](internalRow => deserializeRow(encoder, internalRow))
       .mapPartitions { records =>
         if (records.isEmpty) Iterator.empty
         else {
@@ -96,4 +98,15 @@ object AvroConversionUtils {
     val name = HoodieAvroUtils.sanitizeName(tableName)
     (s"${name}_record", s"hoodie.${name}")
   }
+
+  private def deserializeRow(encoder: ExpressionEncoder[Row], internalRow: InternalRow): Row = {
+    if (SPARK_VERSION.startsWith("2.")) {
+      val spark2method = encoder.getClass.getMethod("fromRow", classOf[InternalRow])
+      spark2method.invoke(encoder, internalRow).asInstanceOf[Row]
+    } else {
+      val deserializer = encoder.getClass.getMethod("createDeserializer").invoke(encoder)
+      val aboveSpark2method = deserializer.getClass.getMethod("apply", classOf[InternalRow])
+      aboveSpark2method.invoke(deserializer, internalRow).asInstanceOf[Row]
+    }
+  }
 }
diff --git a/hudi-spark/src/test/java/HoodieJavaStreamingApp.java b/hudi-spark/src/test/java/HoodieJavaStreamingApp.java
index 500189d1cf211..60e151a9d94b0 100644
--- a/hudi-spark/src/test/java/HoodieJavaStreamingApp.java
+++ b/hudi-spark/src/test/java/HoodieJavaStreamingApp.java
@@ -43,13 +43,12 @@
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.streaming.DataStreamWriter;
 import org.apache.spark.sql.streaming.OutputMode;
-import org.apache.spark.sql.streaming.ProcessingTime;
+import org.apache.spark.sql.streaming.Trigger;
 
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
-import org.apache.spark.sql.streaming.StreamingQuery;
 
 import static org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings;
 
@@ -363,8 +362,7 @@ public void stream(Dataset<Row> streamingInput, String operationType, String che
         .outputMode(OutputMode.Append());
 
     updateHiveSyncConfig(writer);
-    StreamingQuery query = writer.trigger(new ProcessingTime(500)).start(tablePath);
-    query.awaitTermination(streamingDurationInMs);
+    writer.trigger(Trigger.ProcessingTime(500)).start(tablePath).awaitTermination(streamingDurationInMs);
   }
 
   /**
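Two notes on the hunks above.

First, the reflection in `deserializeRow` bridges a binary-incompatible API change: Spark 2.x exposes `ExpressionEncoder.fromRow(InternalRow): Row` directly, while Spark 3.x removed it in favour of `createDeserializer()`, which returns an `InternalRow => Row` function, so neither call can be compiled directly against both versions. One caveat: as written, the `getMethod` lookup runs once per record. A possible refinement, sketched below under the same version check (`toRows` is a hypothetical helper, not part of this patch), resolves the deserializer once per partition:

```scala
import org.apache.spark.SPARK_VERSION
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Hypothetical helper: convert RDD[InternalRow] to RDD[Row], doing the reflective
// method lookup once per partition instead of once per record.
def toRows(rdd: RDD[InternalRow], encoder: ExpressionEncoder[Row]): RDD[Row] =
  rdd.mapPartitions { rows =>
    val fromInternalRow: InternalRow => Row =
      if (SPARK_VERSION.startsWith("2.")) {
        // Spark 2.x: fromRow(InternalRow) is a method on the encoder itself.
        val m = encoder.getClass.getMethod("fromRow", classOf[InternalRow])
        row => m.invoke(encoder, row).asInstanceOf[Row]
      } else {
        // Spark 3.x: createDeserializer() yields an InternalRow => Row function.
        val deserializer = encoder.getClass.getMethod("createDeserializer").invoke(encoder)
        val m = deserializer.getClass.getMethod("apply", classOf[InternalRow])
        row => m.invoke(deserializer, row).asInstanceOf[Row]
      }
    rows.map(fromInternalRow)
  }
```

Second, `org.apache.spark.sql.streaming.ProcessingTime` was deprecated in Spark 2.2 and removed in Spark 3, which is why the streaming test now goes through the `Trigger.ProcessingTime` factory available on both lines. A minimal, self-contained illustration of the replacement (Scala here for consistency; app name and timeouts are arbitrary):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().appName("trigger-demo").master("local[2]").getOrCreate()

// The built-in `rate` source emits incrementing longs; handy for a quick demo.
val query = spark.readStream.format("rate").load()
  .writeStream
  .format("console")
  .outputMode("append")
  .trigger(Trigger.ProcessingTime(500)) // replaces `new ProcessingTime(500)`
  .start()

// Block for roughly five seconds, mirroring awaitTermination(streamingDurationInMs).
query.awaitTermination(5000)
query.stop()
```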
diff --git a/hudi-utilities/src/main/java/org/apache/hudi/utilities/UtilHelpers.java b/hudi-utilities/src/main/java/org/apache/hudi/utilities/UtilHelpers.java
index 05311964618d8..1422519ff23e5 100644
--- a/hudi-utilities/src/main/java/org/apache/hudi/utilities/UtilHelpers.java
+++ b/hudi-utilities/src/main/java/org/apache/hudi/utilities/UtilHelpers.java
@@ -47,7 +47,6 @@
 import org.apache.hadoop.fs.Path;
 import org.apache.log4j.LogManager;
 import org.apache.log4j.Logger;
-import org.apache.spark.Accumulator;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -60,6 +59,7 @@
 import org.apache.spark.sql.jdbc.JdbcDialect;
 import org.apache.spark.sql.jdbc.JdbcDialects;
 import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.LongAccumulator;
 
 import java.io.BufferedReader;
 import java.io.IOException;
@@ -254,7 +254,7 @@ public static HoodieWriteClient createHoodieClient(JavaSparkContext jsc, String
   }
 
   public static int handleErrors(JavaSparkContext jsc, String instantTime, JavaRDD<WriteStatus> writeResponse) {
-    Accumulator errors = jsc.accumulator(0);
+    LongAccumulator errors = jsc.sc().longAccumulator();
     writeResponse.foreach(writeStatus -> {
       if (writeStatus.hasErrors()) {
         errors.add(1);
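The `UtilHelpers` change follows the same migration pattern: the untyped `Accumulator` and `JavaSparkContext.accumulator(...)` were deprecated in Spark 2.0 and removed in Spark 3.0, and `longAccumulator()` on the underlying `SparkContext` is the replacement that works on both. A small sketch of the new API (accumulator name and data are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.util.LongAccumulator

val spark = SparkSession.builder().appName("accumulator-demo").master("local[2]").getOrCreate()
val sc = spark.sparkContext

// longAccumulator replaces the Accumulator API that Spark 3 removed.
val errors: LongAccumulator = sc.longAccumulator("errors")

// Stand-in for the per-WriteStatus hasErrors() check in handleErrors.
sc.parallelize(Seq(1, 2, 3, 4)).foreach { n =>
  if (n % 2 == 0) errors.add(1)
}
assert(errors.sum == 2)
```

As with the old API, `add` runs on the executors, and the driver should read `sum`/`value` only after the triggering action completes.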