diff --git a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/utils/SparkRowDeserializer.java b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/utils/SparkRowSerDe.java similarity index 91% rename from hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/utils/SparkRowDeserializer.java rename to hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/utils/SparkRowSerDe.java index 66b8b78b56920..dce2d2fb62f1f 100644 --- a/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/utils/SparkRowDeserializer.java +++ b/hudi-client/hudi-spark-client/src/main/java/org/apache/hudi/client/utils/SparkRowSerDe.java @@ -23,6 +23,8 @@ import java.io.Serializable; -public interface SparkRowDeserializer extends Serializable { +public interface SparkRowSerDe extends Serializable { Row deserializeRow(InternalRow internalRow); + + InternalRow serializeRow(Row row); } diff --git a/hudi-spark-datasource/hudi-spark/pom.xml b/hudi-spark-datasource/hudi-spark/pom.xml index 4f56c7e9391b7..d2c9485020c53 100644 --- a/hudi-spark-datasource/hudi-spark/pom.xml +++ b/hudi-spark-datasource/hudi-spark/pom.xml @@ -266,6 +266,25 @@ spark-sql_${scala.binary.version} + + org.apache.spark + spark-sql_${scala.binary.version} + tests + test + + + org.apache.spark + spark-core_${scala.binary.version} + tests + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + tests + test + + org.apache.spark diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala index d26390d635a6e..5cceb7878bcd5 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/DefaultSource.scala @@ -22,12 +22,13 @@ import org.apache.hudi.DataSourceReadOptions._ import org.apache.hudi.common.model.{HoodieRecord, HoodieTableType} import org.apache.hudi.DataSourceWriteOptions.{BOOTSTRAP_OPERATION_OPT_VAL, OPERATION_OPT_KEY} import org.apache.hudi.common.fs.FSUtils -import org.apache.hudi.common.table.HoodieTableMetaClient +import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver} import org.apache.hudi.exception.HoodieException import org.apache.hudi.hadoop.HoodieROTablePathFilter import org.apache.log4j.LogManager import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.Sink +import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.hudi.streaming.HoodieStreamSource import org.apache.spark.sql.sources._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -44,6 +45,7 @@ class DefaultSource extends RelationProvider with CreatableRelationProvider with DataSourceRegister with StreamSinkProvider + with StreamSourceProvider with Serializable { private val log = LogManager.getLogger(classOf[DefaultSource]) @@ -181,4 +183,35 @@ class DefaultSource extends RelationProvider .resolveRelation() } } + + override def sourceSchema(sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + val path = parameters.get("path") + if (path.isEmpty || path.get == null) { + throw new HoodieException(s"'path' must be specified.") + } + val metaClient = new HoodieTableMetaClient( + sqlContext.sparkSession.sessionState.newHadoopConf(), 
      path.get)
+    val schemaResolver = new TableSchemaResolver(metaClient)
+    val sqlSchema =
+      try {
+        val avroSchema = schemaResolver.getTableAvroSchema
+        AvroConversionUtils.convertAvroSchemaToStructType(avroSchema)
+      } catch {
+        case _: Exception =>
+          require(schema.isDefined, "Failed to resolve source schema")
+          schema.get
+      }
+    (shortName(), sqlSchema)
+  }
+
+  override def createSource(sqlContext: SQLContext,
+                            metadataPath: String,
+                            schema: Option[StructType],
+                            providerName: String,
+                            parameters: Map[String, String]): Source = {
+    new HoodieStreamSource(sqlContext, metadataPath, schema, parameters)
+  }
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
index 02880f22b93fc..bd55930d1a41d 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala
@@ -21,7 +21,7 @@ package org.apache.hudi
 import org.apache.avro.Schema
 import org.apache.avro.generic.GenericRecord
 import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hudi.client.utils.SparkRowDeserializer
+import org.apache.hudi.client.utils.SparkRowSerDe
 import org.apache.hudi.common.model.HoodieRecord
 import org.apache.spark.SPARK_VERSION
 import org.apache.spark.rdd.RDD
@@ -99,7 +99,7 @@ object HoodieSparkUtils {
     // Use the Avro schema to derive the StructType which has the correct nullability information
     val dataType = SchemaConverters.toSqlType(avroSchema).dataType.asInstanceOf[StructType]
     val encoder = RowEncoder.apply(dataType).resolveAndBind()
-    val deserializer = HoodieSparkUtils.createDeserializer(encoder)
+    val deserializer = HoodieSparkUtils.createRowSerDe(encoder)
     df.queryExecution.toRdd.map(row => deserializer.deserializeRow(row))
       .mapPartitions { records =>
         if (records.isEmpty) Iterator.empty
@@ -110,12 +110,12 @@ object HoodieSparkUtils {
     }
   }
 
-  def createDeserializer(encoder: ExpressionEncoder[Row]): SparkRowDeserializer = {
-    // TODO remove Spark2RowDeserializer if Spark 2.x support is dropped
+  def createRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = {
+    // TODO remove Spark2RowSerDe if Spark 2.x support is dropped
     if (SPARK_VERSION.startsWith("2.")) {
-      new Spark2RowDeserializer(encoder)
+      new Spark2RowSerDe(encoder)
     } else {
-      new Spark3RowDeserializer(encoder)
+      new Spark3RowSerDe(encoder)
     }
   }
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/streaming/HoodieSourceOffset.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/streaming/HoodieSourceOffset.scala
new file mode 100644
index 0000000000000..03e651ed83ab4
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/streaming/HoodieSourceOffset.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.streaming + +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import com.fasterxml.jackson.module.scala.experimental.ScalaObjectMapper +import org.apache.hudi.common.table.timeline.HoodieTimeline +import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} + +case class HoodieSourceOffset(commitTime: String) extends Offset { + + override def json(): String = { + HoodieSourceOffset.toJson(this) + } + + override def equals(obj: Any): Boolean = { + obj match { + case HoodieSourceOffset(otherCommitTime) => + otherCommitTime == commitTime + case _=> false + } + } + + override def hashCode(): Int = { + commitTime.hashCode + } +} + + +object HoodieSourceOffset { + val mapper = new ObjectMapper with ScalaObjectMapper + mapper.setSerializationInclusion(Include.NON_ABSENT) + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + mapper.registerModule(DefaultScalaModule) + + def toJson(offset: HoodieSourceOffset): String = { + mapper.writeValueAsString(offset) + } + + def fromJson(json: String): HoodieSourceOffset = { + mapper.readValue[HoodieSourceOffset](json) + } + + def apply(offset: Offset): HoodieSourceOffset = { + offset match { + case SerializedOffset(json) => fromJson(json) + case o: HoodieSourceOffset => o + } + } + + val INIT_OFFSET = HoodieSourceOffset(HoodieTimeline.INIT_INSTANT_TS) +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/streaming/HoodieStreamSource.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/streaming/HoodieStreamSource.scala new file mode 100644 index 0000000000000..c17598a58d350 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/streaming/HoodieStreamSource.scala @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.hudi.streaming
+
+import java.io.{BufferedWriter, InputStream, OutputStream, OutputStreamWriter}
+import java.nio.charset.StandardCharsets
+import java.util.Date
+
+import org.apache.hadoop.fs.Path
+import org.apache.hudi.{DataSourceReadOptions, HoodieSparkUtils, IncrementalRelation, MergeOnReadIncrementalRelation}
+import org.apache.hudi.common.model.HoodieTableType
+import org.apache.hudi.common.table.timeline.HoodieActiveTimeline
+import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver}
+import org.apache.hudi.common.util.{FileIOUtils, TablePathUtils}
+import org.apache.spark.sql.hudi.streaming.HoodieStreamSource.VERSION
+import org.apache.spark.sql.hudi.streaming.HoodieSourceOffset.INIT_OFFSET
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.avro.SchemaConverters
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, Offset, Source}
+import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+/**
+ * The Structured Streaming source for Hudi, which lets a streaming query consume data
+ * from a Hudi table incrementally.
+ * @param sqlContext the Spark SQL context
+ * @param metadataPath the streaming metadata path provided by Spark, used to persist the initial offset
+ * @param schemaOption an optional user-supplied schema; if absent, the table schema is resolved
+ * @param parameters the datasource options
+ */
+class HoodieStreamSource(
+    sqlContext: SQLContext,
+    metadataPath: String,
+    schemaOption: Option[StructType],
+    parameters: Map[String, String])
+  extends Source with Logging with Serializable {
+
+  @transient private val hadoopConf = sqlContext.sparkSession.sessionState.newHadoopConf()
+  private lazy val tablePath: Path = {
+    val path = new Path(parameters.getOrElse("path",
+      throw new IllegalArgumentException("Missing 'path' option")))
+    val fs = path.getFileSystem(hadoopConf)
+    TablePathUtils.getTablePath(fs, path).get()
+  }
+  private lazy val metaClient = new HoodieTableMetaClient(hadoopConf, tablePath.toString)
+  private lazy val tableType = metaClient.getTableType
+
+  @transient private var lastOffset: HoodieSourceOffset = _
+  @transient private lazy val initialOffsets = {
+    val metadataLog =
+      new HDFSMetadataLog[HoodieSourceOffset](sqlContext.sparkSession, metadataPath) {
+        override def serialize(metadata: HoodieSourceOffset, out: OutputStream): Unit = {
+          val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8))
+          writer.write("v" + VERSION + "\n")
+          writer.write(metadata.json)
+          writer.flush()
+        }
+
+        /**
+         * Deserialize the init offset from the metadata file.
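+         * Only the offset of the very first batch (batch 0) is persisted in this log;
+         * the streaming query's own checkpoint tracks all subsequent progress.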
+         * The format in the metadata file is like this:
+         * -----------------------------------------------
+         * v1         -- The version info in the first line
+         * offsetJson -- The JSON string of HoodieSourceOffset in the rest of the file
+         * -----------------------------------------------
+         * @param in the input stream of the metadata file
+         * @return the deserialized HoodieSourceOffset
+         */
+        override def deserialize(in: InputStream): HoodieSourceOffset = {
+          val content = FileIOUtils.readAsUTFString(in)
+          // Get the version from the first line
+          val firstLineEnd = content.indexOf("\n")
+          if (firstLineEnd > 0) {
+            val version = getVersion(content.substring(0, firstLineEnd))
+            if (version > VERSION) {
+              throw new IllegalStateException(s"Unsupported version: max supported version is: $VERSION" +
+                s", current version is: $version")
+            }
+            // Get the offset from the rest of the file
+            HoodieSourceOffset.fromJson(content.substring(firstLineEnd + 1))
+          } else {
+            throw new IllegalStateException(s"Bad metadata format, failed to find the version line.")
+          }
+        }
+      }
+    metadataLog.get(0).getOrElse {
+      metadataLog.add(0, INIT_OFFSET)
+      INIT_OFFSET
+    }
+  }
+
+  private def getVersion(versionLine: String): Int = {
+    if (versionLine.startsWith("v")) {
+      versionLine.substring(1).toInt
+    } else {
+      throw new IllegalStateException(s"Illegal version line: $versionLine " +
+        s"in the streaming metadata path")
+    }
+  }
+
+  override def schema: StructType = {
+    schemaOption.getOrElse {
+      val schemaUtil = new TableSchemaResolver(metaClient)
+      SchemaConverters.toSqlType(schemaUtil.getTableAvroSchema)
+        .dataType.asInstanceOf[StructType]
+    }
+  }
+
+  /**
+   * Get the latest offset from the Hudi table.
+   * @return the latest completed commit as an offset, or the initial offset if there are no completed commits
+   */
+  override def getOffset: Option[Offset] = {
+    metaClient.reloadActiveTimeline()
+    val activeInstants = metaClient.getActiveTimeline.getCommitsTimeline.filterCompletedInstants
+    if (!activeInstants.empty()) {
+      val currentLatestCommitTime = activeInstants.lastInstant().get().getTimestamp
+      if (lastOffset == null || currentLatestCommitTime > lastOffset.commitTime) {
+        lastOffset = HoodieSourceOffset(currentLatestCommitTime)
+      }
+    } else { // if there are no active commits, use the init offset
+      lastOffset = initialOffsets
+    }
+    Some(lastOffset)
+  }
+
+  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+    // Touch initialOffsets so the init offset is persisted before the first batch is built.
+    initialOffsets
+
+    val startOffset = start.map(HoodieSourceOffset(_))
+      .getOrElse(initialOffsets)
+    val endOffset = HoodieSourceOffset(end)
+
+    if (startOffset == endOffset) {
+      sqlContext.internalCreateDataFrame(
+        sqlContext.sparkContext.emptyRDD[InternalRow].setName("empty"), schema, isStreaming = true)
+    } else {
+      // Consume the data between (startCommitTime, endCommitTime]
+      val incParams = parameters ++ Map(
+        DataSourceReadOptions.BEGIN_INSTANTTIME_OPT_KEY -> startCommitTime(startOffset),
+        DataSourceReadOptions.END_INSTANTTIME_OPT_KEY -> endOffset.commitTime
+      )
+
+      val rdd = tableType match {
+        case HoodieTableType.COPY_ON_WRITE =>
+          val serDe = HoodieSparkUtils.createRowSerDe(RowEncoder(schema))
+          new IncrementalRelation(sqlContext, incParams, schema, metaClient)
+            .buildScan()
+            .map(serDe.serializeRow)
+        case HoodieTableType.MERGE_ON_READ =>
+          val requiredColumns = schema.fields.map(_.name)
+          new MergeOnReadIncrementalRelation(sqlContext, incParams, schema, metaClient)
+            .buildScan(requiredColumns, Array.empty[Filter])
+            .asInstanceOf[RDD[InternalRow]]
+        case _ => throw new IllegalArgumentException(s"Unsupported tableType: $tableType")
+      }
+      sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
+    }
+  }
+
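+  /**
+   * Convert the start offset into the begin instant time used for the incremental query.
+   * A batch consumes the commit range (start, end], so for a non-initial offset the start
+   * commit itself is skipped by shifting the time forward by one second.
+   */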
+  private def startCommitTime(startOffset: HoodieSourceOffset): String = {
+    startOffset match {
+      case INIT_OFFSET => startOffset.commitTime
+      case HoodieSourceOffset(commitTime) =>
+        val time = HoodieActiveTimeline.COMMIT_FORMATTER.parse(commitTime).getTime
+        // As we consume the data between (start, end], start is not included,
+        // so we add one second to the start commit time here.
+        HoodieActiveTimeline.COMMIT_FORMATTER.format(new Date(time + 1000))
+      case _ => throw new IllegalStateException("Unknown offset type.")
+    }
+  }
+
+  override def stop(): Unit = {
+
+  }
+}
+
+object HoodieStreamSource {
+  val VERSION = 1
+}
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStreamingSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStreamingSource.scala
new file mode 100644
index 0000000000000..a98152a5e20bc
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStreamingSource.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package org.apache.hudi.functional + +import org.apache.hudi.DataSourceWriteOptions +import org.apache.hudi.DataSourceWriteOptions.{PRECOMBINE_FIELD_OPT_KEY, RECORDKEY_FIELD_OPT_KEY} +import org.apache.hudi.common.model.HoodieTableType.{COPY_ON_WRITE, MERGE_ON_READ} +import org.apache.hudi.common.table.HoodieTableMetaClient +import org.apache.hudi.config.HoodieWriteConfig.{DELETE_PARALLELISM, INSERT_PARALLELISM, TABLE_NAME, UPSERT_PARALLELISM} +import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.{Row, SaveMode} + +class TestStreamingSource extends StreamTest { + + import testImplicits._ + private val commonOptions = Map( + RECORDKEY_FIELD_OPT_KEY -> "id", + PRECOMBINE_FIELD_OPT_KEY -> "ts", + INSERT_PARALLELISM -> "4", + UPSERT_PARALLELISM -> "4", + DELETE_PARALLELISM -> "4" + ) + private val columns = Seq("id", "name", "price", "ts") + + override protected def sparkConf = { + super.sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + } + + test("test cow stream source") { + withTempDir { inputDir => + val tablePath = s"${inputDir.getCanonicalPath}/test_cow_stream" + HoodieTableMetaClient.initTableType(spark.sessionState.newHadoopConf(), tablePath, + COPY_ON_WRITE, getTableName(tablePath), DataSourceWriteOptions.DEFAULT_PAYLOAD_OPT_VAL) + + addData(tablePath, Seq(("1", "a1", "10", "000"))) + val df = spark.readStream + .format("org.apache.hudi") + .load(tablePath) + .select("id", "name", "price", "ts") + + testStream(df)( + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows(Seq(Row("1", "a1", "10", "000")), lastOnly = true, isSorted = false), + StopStream, + + addDataToQuery(tablePath, Seq(("1", "a1", "12", "000"))), + StartStream(), + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows(Seq(Row("1", "a1", "12", "000")), lastOnly = true, isSorted = false), + + addDataToQuery(tablePath, + Seq(("2", "a2", "12", "000"), + ("3", "a3", "12", "000"), + ("4", "a4", "12", "000"))), + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows( + Seq(Row("2", "a2", "12", "000"), + Row("3", "a3", "12", "000"), + Row("4", "a4", "12", "000")), + lastOnly = true, isSorted = false), + StopStream, + + addDataToQuery(tablePath, Seq(("5", "a5", "12", "000"))), + addDataToQuery(tablePath, Seq(("6", "a6", "12", "000"))), + addDataToQuery(tablePath, Seq(("5", "a5", "15", "000"))), + StartStream(), + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows( + Seq(Row("6", "a6", "12", "000"), + Row("5", "a5", "15", "000")), + lastOnly = true, isSorted = false) + ) + } + } + + test("test mor stream source") { + withTempDir { inputDir => + val tablePath = s"${inputDir.getCanonicalPath}/test_mor_stream" + HoodieTableMetaClient.initTableType(spark.sessionState.newHadoopConf(), tablePath, + MERGE_ON_READ, getTableName(tablePath), DataSourceWriteOptions.DEFAULT_PAYLOAD_OPT_VAL) + + addData(tablePath, Seq(("1", "a1", "10", "000"))) + val df = spark.readStream + .format("org.apache.hudi") + .load(tablePath) + .select("id", "name", "price", "ts") + + testStream(df)( + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows(Seq(Row("1", "a1", "10", "000")), lastOnly = true, isSorted = false), + StopStream, + + addDataToQuery(tablePath, + Seq(("2", "a2", "12", "000"), + ("3", "a3", "12", "000"), + ("2", "a2", "10", "001"))), + StartStream(), + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows( + Seq(Row("3", "a3", "12", "000"), + Row("2", "a2", "10", "001")), + 
lastOnly = true, isSorted = false), + StopStream, + + addDataToQuery(tablePath, Seq(("5", "a5", "12", "000"))), + addDataToQuery(tablePath, Seq(("6", "a6", "12", "000"))), + StartStream(), + AssertOnQuery {q => q.processAllAvailable(); true }, + CheckAnswerRows( + Seq(Row("5", "a5", "12", "000"), + Row("6", "a6", "12", "000")), + lastOnly = true, isSorted = false) + ) + } + } + + private def addData(inputPath: String, rows: Seq[(String, String, String, String)]): Unit = { + rows.toDF(columns: _*) + .write + .format("org.apache.hudi") + .options(commonOptions) + .option(TABLE_NAME, getTableName(inputPath)) + .mode(SaveMode.Append) + .save(inputPath) + } + + private def addDataToQuery(inputPath: String, + rows: Seq[(String, String, String, String)]): AssertOnQuery = { + AssertOnQuery { _=> + addData(inputPath, rows) + true + } + } + + private def getTableName(inputPath: String): String = { + val start = inputPath.lastIndexOf('/') + inputPath.substring(start + 1) + } +} diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/hudi/Spark2RowDeserializer.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/hudi/Spark2RowSerDe.scala similarity index 83% rename from hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/hudi/Spark2RowDeserializer.scala rename to hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/hudi/Spark2RowSerDe.scala index 84fe4c3e8b28b..ca04470558bd7 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/hudi/Spark2RowDeserializer.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/hudi/Spark2RowSerDe.scala @@ -17,14 +17,17 @@ package org.apache.hudi -import org.apache.hudi.client.utils.SparkRowDeserializer - +import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -class Spark2RowDeserializer(val encoder: ExpressionEncoder[Row]) extends SparkRowDeserializer { +class Spark2RowSerDe(val encoder: ExpressionEncoder[Row]) extends SparkRowSerDe { def deserializeRow(internalRow: InternalRow): Row = { encoder.fromRow(internalRow) } + + override def serializeRow(row: Row): InternalRow = { + encoder.toRow(row) + } } diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/hudi/Spark3RowDeserializer.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/hudi/Spark3RowSerDe.scala similarity index 79% rename from hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/hudi/Spark3RowDeserializer.scala rename to hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/hudi/Spark3RowSerDe.scala index a0606553ff275..e3d809f0067c7 100644 --- a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/hudi/Spark3RowDeserializer.scala +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/hudi/Spark3RowSerDe.scala @@ -17,17 +17,21 @@ package org.apache.hudi -import org.apache.hudi.client.utils.SparkRowDeserializer - +import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -class Spark3RowDeserializer(val encoder: ExpressionEncoder[Row]) extends SparkRowDeserializer { +class Spark3RowSerDe(val encoder: ExpressionEncoder[Row]) extends SparkRowSerDe { private val deserializer: ExpressionEncoder.Deserializer[Row] = encoder.createDeserializer() + private val serializer: 
ExpressionEncoder.Serializer[Row] = encoder.createSerializer() def deserializeRow(internalRow: InternalRow): Row = { deserializer.apply(internalRow) } + + override def serializeRow(row: Row): InternalRow = { + serializer.apply(row) + } } diff --git a/pom.xml b/pom.xml index 91780dae894f9..cd66d65141aaf 100644 --- a/pom.xml +++ b/pom.xml @@ -527,6 +527,27 @@ ${spark.version} provided + + org.apache.spark + spark-sql_${scala.binary.version} + tests + ${spark.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + tests + ${spark.version} + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + tests + ${spark.version} + test +
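
Usage sketch for reviewers: with this change a Hudi table can be consumed as a streaming source in the same way the new TestStreamingSource does. The snippet below is illustrative only; the table path and checkpoint location are placeholders, and the options follow the conventions used in the tests above.

  // Read a Hudi table as a Structured Streaming source and echo new rows to the console.
  val stream = spark.readStream
    .format("org.apache.hudi")
    .load("/tmp/hudi/test_cow_stream")   // placeholder table path

  stream.writeStream
    .format("console")
    .option("checkpointLocation", "/tmp/hudi/checkpoints/test_cow_stream")   // placeholder
    .outputMode("append")
    .start()
    .awaitTermination()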