Merged. Changes from 4 commits.
2 changes: 1 addition & 1 deletion .github/workflows/bot.yml
@@ -69,4 +69,4 @@ jobs:
FLINK_PROFILE: ${{ matrix.flinkProfile }}
if: ${{ !endsWith(env.SPARK_PROFILE, '2.4') }} # skip test spark 2.4 as it's covered by Azure CI
run:
-        mvn test -Punit-tests -D"$SCALA_PROFILE" -D"$SPARK_PROFILE" -D"$FLINK_PROFILE" '-Dtest=org.apache.spark.sql.hudi.Test*' -pl hudi-spark-datasource/hudi-spark
+        mvn test -Punit-tests -D"$SCALA_PROFILE" -D"$SPARK_PROFILE" -D"$FLINK_PROFILE" '-Dtest=Test*' -pl hudi-spark-datasource/hudi-spark
@@ -69,7 +69,7 @@ class TestConvertFilterToCatalystExpression {
private def checkConvertFilter(filter: Filter, expectExpression: String): Unit = {
// [SPARK-25769][SPARK-34636][SPARK-34626][SQL] sql method in UnresolvedAttribute,
// AttributeReference and Alias don't quote qualified names properly
-    val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.isSpark3_2) {
+    val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.gteqSpark3_2) {
expectExpression.replace("`", "")
} else {
expectExpression
@@ -86,7 +86,7 @@
private def checkConvertFilters(filters: Array[Filter], expectExpression: String): Unit = {
// [SPARK-25769][SPARK-34636][SPARK-34626][SQL] sql method in UnresolvedAttribute,
// AttributeReference and Alias don't quote qualified names properly
-    val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.isSpark3_2) {
+    val removeQuotesIfNeed = if (expectExpression != null && HoodieSparkUtils.gteqSpark3_2) {
expectExpression.replace("`", "")
} else {
expectExpression
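The change above is a semantic widening, not just a rename: isSpark3_2 matched only the 3.2.x line, while gteqSpark3_2 also matches later releases, so the backtick-stripping workaround for SPARK-25769/SPARK-34636 now applies on Spark 3.3+ as well. A minimal sketch of the difference, with hypothetical helpers standing in for HoodieSparkUtils (these are not Hudi's actual implementations):

object VersionGateSketch {
  // Hypothetical stand-in for the old HoodieSparkUtils.isSpark3_2 guard:
  // true only for the 3.2.x line.
  def isSpark3_2(version: String): Boolean = version.startsWith("3.2")

  // Hypothetical stand-in for the new HoodieSparkUtils.gteqSpark3_2 guard:
  // true for 3.2.x and every later release.
  def gteqSpark3_2(version: String): Boolean = {
    val Array(major, minor) = version.split("\\.").take(2).map(_.toInt)
    major > 3 || (major == 3 && minor >= 2)
  }

  def main(args: Array[String]): Unit = {
    assert(!isSpark3_2("3.3.0"))  // old guard skipped the workaround on Spark 3.3
    assert(gteqSpark3_2("3.3.0")) // new guard applies it there too
  }
}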
@@ -45,8 +45,13 @@ import java.util.TimeZone
* A serializer to serialize data in catalyst format to data in avro format.
*
* NOTE: This code is borrowed from Spark 3.2.1
 *       This code is borrowed, so that we can better control compatibility w/in Spark minor
 *       branches (3.2.x, 3.1.x, etc)
 *
 * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH THE MODIFICATION
 *       BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND PROPERLY ALL THE
 *       MODIFICATIONS.
 *
 * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
@@ -211,11 +216,20 @@ private[sql] class AvroSerializer(rootCatalystType: DataType,
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))

////////////////////////////////////////////////////////////////////////////////////////////
// Following section is amended to the original (Spark's) implementation
// >>> BEGINS
////////////////////////////////////////////////////////////////////////////////////////////

case (st: StructType, UNION) =>
val unionConverter = newUnionConverter(st, avroType, catalystPath, avroPath)
val numFields = st.length
(getter, ordinal) => unionConverter(getter.getStruct(ordinal, numFields))

////////////////////////////////////////////////////////////////////////////////////////////
// <<< ENDS
////////////////////////////////////////////////////////////////////////////////////////////

case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
val valueConverter = newConverter(
vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -293,6 +307,11 @@ private[sql] class AvroSerializer(rootCatalystType: DataType,
result
}

////////////////////////////////////////////////////////////////////////////////////////////
// Following section is amended to the original (Spark's) implementation
// >>> BEGINS
////////////////////////////////////////////////////////////////////////////////////////////

private def newUnionConverter(catalystStruct: StructType,
avroUnion: Schema,
catalystPath: Seq[String],
Expand Down Expand Up @@ -337,6 +356,10 @@ private[sql] class AvroSerializer(rootCatalystType: DataType,
avroStruct.getTypes.size() - 1 == catalystStruct.length) || avroStruct.getTypes.size() == catalystStruct.length
}

////////////////////////////////////////////////////////////////////////////////////////////
// <<< ENDS
////////////////////////////////////////////////////////////////////////////////////////////

/**
* Resolve a possibly nullable Avro Type.
*
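The amended sections above let a Catalyst struct whose fields mirror the branches of an Avro UNION be serialized into that union, skipping a leading NULL branch when the union is nullable. A rough, self-contained sketch of the schema shape and of the compatibility rule that canMapUnion encodes, using only the Avro library (names and assertions are illustrative, not the PR's code):

import org.apache.avro.Schema
import scala.collection.JavaConverters._

object UnionMappingSketch {
  // Mirrors the intent of canMapUnion: either a leading NULL branch plus one
  // branch per struct field, or no NULL and exactly one branch per field.
  def canMap(numStructFields: Int, union: Schema): Boolean = {
    val branches = union.getTypes
    (branches.size() > 0 && branches.get(0).getType == Schema.Type.NULL &&
      branches.size() - 1 == numStructFields) || branches.size() == numStructFields
  }

  def main(args: Array[String]): Unit = {
    // A nullable union with two real branches, e.g. for struct(member0: Int, member1: String).
    val union = Schema.createUnion(List(
      Schema.create(Schema.Type.NULL),
      Schema.create(Schema.Type.INT),
      Schema.create(Schema.Type.STRING)
    ).asJava)

    assert(canMap(2, union))  // two-field struct pairs with the two non-null branches
    assert(!canMap(4, union)) // field count must line up with the branches
  }
}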
@@ -48,17 +48,15 @@ import java.util.TimeZone
*
 * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
-private[sql] class AvroDeserializer(
-    rootAvroType: Schema,
-    rootCatalystType: DataType,
-    positionalFieldMatch: Boolean,
-    datetimeRebaseSpec: RebaseSpec,
-    filters: StructFilters) {
-
-  def this(
-      rootAvroType: Schema,
-      rootCatalystType: DataType,
-      datetimeRebaseMode: String) = {
+private[sql] class AvroDeserializer(rootAvroType: Schema,
+                                    rootCatalystType: DataType,
+                                    positionalFieldMatch: Boolean,
+                                    datetimeRebaseSpec: RebaseSpec,
+                                    filters: StructFilters) {
+
+  def this(rootAvroType: Schema,
+           rootCatalystType: DataType,
+           datetimeRebaseMode: String) = {
this(
rootAvroType,
rootCatalystType,
@@ -69,11 +67,9 @@ private[sql] class AvroDeserializer(

private lazy val decimalConversions = new DecimalConversion()

-  private val dateRebaseFunc = createDateRebaseFuncInRead(
-    datetimeRebaseSpec.mode, "Avro")
+  private val dateRebaseFunc = createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro")

-  private val timestampRebaseFunc = createTimestampRebaseFuncInRead(
-    datetimeRebaseSpec, "Avro")
+  private val timestampRebaseFunc = createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro")

private val converter: Any => Option[Any] = try {
rootCatalystType match {
@@ -112,11 +108,10 @@ private[sql] class AvroDeserializer(
* Creates a writer to write avro values to Catalyst values at the given ordinal with the given
* updater.
*/
-  private def newWriter(
-      avroType: Schema,
-      catalystType: DataType,
-      avroPath: Seq[String],
-      catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
+  private def newWriter(avroType: Schema,
+                        catalystType: DataType,
+                        avroPath: Seq[String],
+                        catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = {
val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " +
s"SQL ${toFieldStr(catalystPath)} because "
val incompatibleMsg = errorPrefix +
@@ -29,6 +29,7 @@ import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed}
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8
import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, createTimestampRebaseFuncInWrite}
import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
@@ -44,17 +45,20 @@ import java.util.TimeZone
* A serializer to serialize data in catalyst format to data in avro format.
*
* NOTE: This code is borrowed from Spark 3.3.0
 *       This code is borrowed, so that we can better control compatibility w/in Spark minor
 *       branches (3.2.x, 3.1.x, etc)
 *
 * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH THE MODIFICATION
 *       BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND PROPERLY ALL THE
 *       MODIFICATIONS.
 *
 * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY
*/
-private[sql] class AvroSerializer(
-    rootCatalystType: DataType,
-    rootAvroType: Schema,
-    nullable: Boolean,
-    positionalFieldMatch: Boolean,
-    datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {
+private[sql] class AvroSerializer(rootCatalystType: DataType,
+                                  rootAvroType: Schema,
+                                  nullable: Boolean,
+                                  positionalFieldMatch: Boolean,
+                                  datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging {

def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = {
this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false,
@@ -65,10 +69,10 @@ private[sql] class AvroSerializer(
converter.apply(catalystData)
}

-  private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInWrite(
+  private val dateRebaseFunc = createDateRebaseFuncInWrite(
     datetimeRebaseMode, "Avro")

-  private val timestampRebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite(
+  private val timestampRebaseFunc = createTimestampRebaseFuncInWrite(
     datetimeRebaseMode, "Avro")

private val converter: Any => Any = {
@@ -104,11 +108,10 @@

private lazy val decimalConversions = new DecimalConversion()

-  private def newConverter(
-      catalystType: DataType,
-      avroType: Schema,
-      catalystPath: Seq[String],
-      avroPath: Seq[String]): Converter = {
+  private def newConverter(catalystType: DataType,
+                           avroType: Schema,
+                           catalystPath: Seq[String],
+                           avroPath: Seq[String]): Converter = {
val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " +
s"to Avro ${toFieldStr(avroPath)} because "
(catalystType, avroType.getType) match {
@@ -162,6 +165,7 @@ private[sql] class AvroSerializer(
val data: Array[Byte] = getter.getBinary(ordinal)
if (data.length != size) {
def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else "byte"}"

throw new IncompatibleSchemaException(errorPrefix + len2str(data.length) +
" of binary data cannot be written into FIXED type with size of " + len2str(size))
}
@@ -223,6 +227,20 @@
val numFields = st.length
(getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields))

////////////////////////////////////////////////////////////////////////////////////////////
// Following section is amended to the original (Spark's) implementation
// >>> BEGINS
////////////////////////////////////////////////////////////////////////////////////////////

case (st: StructType, UNION) =>
val unionConverter = newUnionConverter(st, avroType, catalystPath, avroPath)
val numFields = st.length
(getter, ordinal) => unionConverter(getter.getStruct(ordinal, numFields))

////////////////////////////////////////////////////////////////////////////////////////////
// <<< ENDS
////////////////////////////////////////////////////////////////////////////////////////////

case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType =>
val valueConverter = newConverter(
vt, resolveNullableType(avroType.getValueType, valueContainsNull),
@@ -257,11 +275,10 @@
}
}

-  private def newStructConverter(
-      catalystStruct: StructType,
-      avroStruct: Schema,
-      catalystPath: Seq[String],
-      avroPath: Seq[String]): InternalRow => Record = {
+  private def newStructConverter(catalystStruct: StructType,
+                                 avroStruct: Schema,
+                                 catalystPath: Seq[String],
+                                 avroPath: Seq[String]): InternalRow => Record = {

val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch)
@@ -292,6 +309,60 @@
result
}

////////////////////////////////////////////////////////////////////////////////////////////
// Following section is amended to the original (Spark's) implementation
// >>> BEGINS
////////////////////////////////////////////////////////////////////////////////////////////

Review thread:
  Member: add UTs to cover this?
  Contributor (author): There are UTs that cover this.
  Member: Are you referring to AvroSerializer being used throughout the Spark tests? I don't see a specific UT for this union-converter logic.
  Contributor (author): It's TestAvroSerDe.

private def newUnionConverter(catalystStruct: StructType,
avroUnion: Schema,
catalystPath: Seq[String],
avroPath: Seq[String]): InternalRow => Any = {
if (avroUnion.getType != UNION || !canMapUnion(catalystStruct, avroUnion)) {
throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " +
s"Avro type $avroUnion.")
}
val nullable = avroUnion.getTypes.size() > 0 && avroUnion.getTypes.get(0).getType == Type.NULL
val avroInnerTypes = if (nullable) {
avroUnion.getTypes.asScala.tail
} else {
avroUnion.getTypes.asScala
}
val fieldConverters = catalystStruct.zip(avroInnerTypes).map {
case (f1, f2) => newConverter(f1.dataType, f2, catalystPath, avroPath)
}
val numFields = catalystStruct.length
(row: InternalRow) =>
var i = 0
var result: Any = null
while (i < numFields) {
if (!row.isNullAt(i)) {
if (result != null) {
throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " +
s"Avro union $avroUnion. Record has more than one optional values set")
}
result = fieldConverters(i).apply(row, i)
}
i += 1
}
if (!nullable && result == null) {
throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " +
s"Avro union $avroUnion. Record has no values set, while should have exactly one")
}
result
}

private def canMapUnion(catalystStruct: StructType, avroStruct: Schema): Boolean = {
(avroStruct.getTypes.size() > 0 &&
avroStruct.getTypes.get(0).getType == Type.NULL &&
avroStruct.getTypes.size() - 1 == catalystStruct.length) || avroStruct.getTypes.size() == catalystStruct.length
}

////////////////////////////////////////////////////////////////////////////////////////////
// <<< ENDS
////////////////////////////////////////////////////////////////////////////////////////////


/**
* Resolve a possibly nullable Avro Type.
*
@@ -319,12 +390,12 @@
if (avroType.getType == Type.UNION) {
val fields = avroType.getTypes.asScala
val actualType = fields.filter(_.getType != Type.NULL)
-      if (fields.length != 2 || actualType.length != 1) {
-        throw new UnsupportedAvroTypeException(
-          s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " +
-            "type is supported")
-      }
-      (true, actualType.head)
+      if (fields.length == 2 && actualType.length == 1) {
+        (true, actualType.head)
+      } else {
+        // This is just a normal union, not used to designate nullability
+        (false, avroType)
+      }
} else {
(false, avroType)
}
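To make the write-side contract of the new union converter concrete: at most one struct field may be non-null, and for a non-nullable union exactly one must be set, while a two-branch [null, T] union (per the last hunk) still resolves to a nullable T rather than taking the struct-per-branch path. The sketch below emulates that rule on plain Scala values, in the spirit of the TestAvroSerDe coverage cited in the review thread; it is illustrative only, not the PR's test code.

object UnionWriteRuleSketch {
  // Option fields stand in for a Catalyst struct row mapped onto union branches.
  def serializeUnion(fields: Seq[Option[Any]], nullable: Boolean): Any = {
    val set = fields.flatten
    if (set.size > 1)
      throw new IllegalArgumentException("more than one optional value set")
    if (!nullable && set.isEmpty)
      throw new IllegalArgumentException("no value set, while exactly one is required")
    set.headOption.orNull
  }

  def main(args: Array[String]): Unit = {
    // struct(member0: Int, member1: String) onto union [null, int, string]:
    assert(serializeUnion(Seq(Some(1), None), nullable = true) == 1)
    assert(serializeUnion(Seq(None, None), nullable = true) == null)

    // Setting both branches is rejected, mirroring the IncompatibleSchemaException above.
    val rejected =
      try { serializeUnion(Seq(Some(1), Some("a")), nullable = true); false }
      catch { case _: IllegalArgumentException => true }
    assert(rejected)
  }
}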