From 95aad6ec5813f6789eb57d9466e3b5f30115258d Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 11:07:18 +0100 Subject: [PATCH 1/7] Added map normalization on creation --- .../catalyst/util/ArrayBasedMapBuilder.scala | 24 +++++++++++-- .../util/ArrayBasedMapBuilderSuite.scala | 36 ++++++++++++++++++- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index d358c92dd62c..d938926d6654 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -21,6 +21,8 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -52,18 +54,34 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) + def normalize(value: Any, dataType: DataType): Any = dataType match { + case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER(value) + case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER(value) + case ArrayType(dt, _) => + new GenericArrayData(value.asInstanceOf[GenericArrayData].array.map { element => + normalize(element, dt) + }) + case StructType(sf) => + new GenericInternalRow( + value.asInstanceOf[GenericInternalRow].values.zipWithIndex.map { element => + normalize(element._1, sf(element._2).dataType) + }) + case _ => value + } + def put(key: Any, value: Any): Unit = { if (key == null) { throw QueryExecutionErrors.nullAsMapKeyNotAllowedError() } - val index = keyToIndex.getOrDefault(key, -1) + val keyNormalized = normalize(key, keyType) + val index = keyToIndex.getOrDefault(keyNormalized, -1) if (index == -1) { if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.exceedMapSizeLimitError(size) } - keyToIndex.put(key, values.length) - keys.append(key) + keyToIndex.put(keyNormalized, values.length) + keys.append(keyNormalized) values.append(value) } else { if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index 5811f4cd4c85..d0aa478a9357 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DoubleType, IntegerType, StructType} import org.apache.spark.unsafe.Platform class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { @@ -60,6 +60,40 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { ) } + test("apply key normalization when creating") { + val builderDouble = new ArrayBasedMapBuilder(DoubleType, IntegerType) + builderDouble.put(-0.0, 1) + checkError( + exception = intercept[SparkRuntimeException](builderDouble.put(0.0, 2)), + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "0.0", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") + ) + + val builderArray = new ArrayBasedMapBuilder(ArrayType(DoubleType), IntegerType) + builderArray.put(new GenericArrayData(Seq(-0.0)), 1) + checkError( + exception = intercept[SparkRuntimeException]( + builderArray.put(new GenericArrayData(Seq(0.0)), 1)), + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "[0.0]", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") + ) + + val builderStruct = new ArrayBasedMapBuilder(new StructType().add("i", "double"), IntegerType) + builderStruct.put(InternalRow(-0.0), 1) + // By default duplicated map key fails the query. + checkError( + exception = intercept[SparkRuntimeException](builderStruct.put(InternalRow(0.0), 3)), + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "[0.0]", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") + ) + } + test("remove duplicated keys with last wins policy") { withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) { val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) From 92315474d92694629ff51c54620a03e3d15a72d5 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 12:06:30 +0100 Subject: [PATCH 2/7] Added a check if a key needs to be normalized --- .../sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 2 +- .../apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index f946fe76bde4..0b8edcaee75e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -94,7 +94,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case _ => needNormalize(expr.dataType) } - private def needNormalize(dt: DataType): Boolean = dt match { + def needNormalize(dt: DataType): Boolean = dt match { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index d938926d6654..fd475980a917 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -54,6 +54,8 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) + private lazy val keyNeedNormalize = NormalizeFloatingNumbers.needNormalize(keyType) + def normalize(value: Any, dataType: DataType): Any = dataType match { case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER(value) case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER(value) @@ -74,7 +76,7 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria throw QueryExecutionErrors.nullAsMapKeyNotAllowedError() } - val keyNormalized = normalize(key, keyType) + val keyNormalized = if (keyNeedNormalize) normalize(key, keyType) else key val index = keyToIndex.getOrDefault(keyNormalized, -1) if (index == -1) { if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { From 9649c5d97d8847c0cf63949a9998521113ba31ea Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 12:19:22 +0100 Subject: [PATCH 3/7] new test with struct of arrays --- .../catalyst/util/ArrayBasedMapBuilderSuite.scala | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index d0aa478a9357..b6d93165c3d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -92,6 +92,19 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { "key" -> "[0.0]", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") ) + + val builderStructWithArray = new ArrayBasedMapBuilder( + new StructType().add("array", ArrayType(DoubleType) ), IntegerType) + builderStructWithArray.put(InternalRow(new GenericArrayData(Seq(-0.0))), 1) + checkError( + exception = intercept[SparkRuntimeException]( + builderStructWithArray.put(InternalRow(new GenericArrayData(Seq(0.0))), 1)), + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "[[0.0]]", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") + ) + } test("remove duplicated keys with last wins policy") { From 5f008daed7a090cb81e523d638956f939ab50b92 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 12:25:03 +0100 Subject: [PATCH 4/7] Added test for successful map normalization --- .../sql/catalyst/util/ArrayBasedMapBuilderSuite.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index b6d93165c3d9..a4eaa88e923c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -104,7 +104,14 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { "key" -> "[[0.0]]", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") ) + } + test("successful map normalization on build") { + val builder = new ArrayBasedMapBuilder(DoubleType, IntegerType) + builder.put(-0.0, 1) + val map = builder.build() + assert(map.numElements() == 1) + assert(ArrayBasedMapData.toScalaMap(map) == Map(0.0 -> 1)) } test("remove duplicated keys with last wins policy") { From cd4bb3d1c6ceeb8e3a89abee0dc575431ab1700c Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 16:39:01 +0100 Subject: [PATCH 5/7] Changed normalize function to return a lambda function based on data type --- .../catalyst/util/ArrayBasedMapBuilder.scala | 24 ++++--------- .../util/ArrayBasedMapBuilderSuite.scala | 34 ------------------- 2 files changed, 7 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index fd475980a917..e9f5116e1f28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -54,21 +53,12 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) - private lazy val keyNeedNormalize = NormalizeFloatingNumbers.needNormalize(keyType) - - def normalize(value: Any, dataType: DataType): Any = dataType match { - case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER(value) - case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER(value) - case ArrayType(dt, _) => - new GenericArrayData(value.asInstanceOf[GenericArrayData].array.map { element => - normalize(element, dt) - }) - case StructType(sf) => - new GenericInternalRow( - value.asInstanceOf[GenericInternalRow].values.zipWithIndex.map { element => - normalize(element._1, sf(element._2).dataType) - }) - case _ => value + private lazy val keyNeedNormalize = + keyType.isInstanceOf[FloatType] || keyType.isInstanceOf[DoubleType] + + def normalize(dataType: DataType): Any => Any = dataType match { + case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER + case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER } def put(key: Any, value: Any): Unit = { @@ -76,7 +66,7 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria throw QueryExecutionErrors.nullAsMapKeyNotAllowedError() } - val keyNormalized = if (keyNeedNormalize) normalize(key, keyType) else key + val keyNormalized = if (keyNeedNormalize) normalize(keyType)(key) else key val index = keyToIndex.getOrDefault(keyNormalized, -1) if (index == -1) { if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index a4eaa88e923c..3c8c49ee7fec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -70,40 +70,6 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { "key" -> "0.0", "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") ) - - val builderArray = new ArrayBasedMapBuilder(ArrayType(DoubleType), IntegerType) - builderArray.put(new GenericArrayData(Seq(-0.0)), 1) - checkError( - exception = intercept[SparkRuntimeException]( - builderArray.put(new GenericArrayData(Seq(0.0)), 1)), - errorClass = "DUPLICATED_MAP_KEY", - parameters = Map( - "key" -> "[0.0]", - "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") - ) - - val builderStruct = new ArrayBasedMapBuilder(new StructType().add("i", "double"), IntegerType) - builderStruct.put(InternalRow(-0.0), 1) - // By default duplicated map key fails the query. - checkError( - exception = intercept[SparkRuntimeException](builderStruct.put(InternalRow(0.0), 3)), - errorClass = "DUPLICATED_MAP_KEY", - parameters = Map( - "key" -> "[0.0]", - "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") - ) - - val builderStructWithArray = new ArrayBasedMapBuilder( - new StructType().add("array", ArrayType(DoubleType) ), IntegerType) - builderStructWithArray.put(InternalRow(new GenericArrayData(Seq(-0.0))), 1) - checkError( - exception = intercept[SparkRuntimeException]( - builderStructWithArray.put(InternalRow(new GenericArrayData(Seq(0.0))), 1)), - errorClass = "DUPLICATED_MAP_KEY", - parameters = Map( - "key" -> "[[0.0]]", - "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") - ) } test("successful map normalization on build") { From 15065a1fd2446bbc072c7a6eb7f30bed29a5b184 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Tue, 26 Mar 2024 16:40:16 +0100 Subject: [PATCH 6/7] reverted needNormalize to private scope --- .../spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 0b8edcaee75e..f946fe76bde4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -94,7 +94,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case _ => needNormalize(expr.dataType) } - def needNormalize(dt: DataType): Boolean = dt match { + private def needNormalize(dt: DataType): Boolean = dt match { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) From e9322d35ae1d26c489f50bdc8bea6356f8278d30 Mon Sep 17 00:00:00 2001 From: Stevo Mitric Date: Wed, 27 Mar 2024 09:35:42 +0100 Subject: [PATCH 7/7] refactored normalizer function --- .../spark/sql/catalyst/util/ArrayBasedMapBuilder.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index e9f5116e1f28..d13c3c6026a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -53,12 +53,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) - private lazy val keyNeedNormalize = - keyType.isInstanceOf[FloatType] || keyType.isInstanceOf[DoubleType] - - def normalize(dataType: DataType): Any => Any = dataType match { + private lazy val keyNormalizer: Any => Any = keyType match { case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER + case _ => identity } def put(key: Any, value: Any): Unit = { @@ -66,7 +64,7 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria throw QueryExecutionErrors.nullAsMapKeyNotAllowedError() } - val keyNormalized = if (keyNeedNormalize) normalize(keyType)(key) else key + val keyNormalized = keyNormalizer(key) val index = keyToIndex.getOrDefault(keyNormalized, -1) if (index == -1) { if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {