diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 8c34e47314db..64c4aabd4cdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
+import org.scalatest.exceptions.TestFailedException
+
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.sql.catalyst.ScroogeLikeExample
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
@@ -67,6 +69,41 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       data: _*)
   }
 
+  test("toDS should compare map with byte array keys correctly") {
+    // Choose the order of the arrays in such a way that sorting the keys of different maps by
+    // _.toString will not accidentally put equal keys together.
+    val arrays = (1 to 5).map(_ => Array[Byte](0.toByte, 0.toByte)).sortBy(_.toString).toArray
+    arrays(0)(1) = 1.toByte
+    arrays(1)(1) = 2.toByte
+    arrays(2)(1) = 2.toByte
+    arrays(3)(1) = 1.toByte
+
+    val mapA = Map(arrays(0) -> "one", arrays(2) -> "two")
+    val subsetOfA = Map(arrays(0) -> "one")
+    val equalToA = Map(arrays(1) -> "two", arrays(3) -> "one")
+    val notEqualToA1 = Map(arrays(1) -> "two", arrays(3) -> "not one")
+    val notEqualToA2 = Map(arrays(1) -> "two", arrays(4) -> "one")
+
+    // Comparing a map with itself
+    checkDataset(Seq(mapA).toDS(), mapA)
+
+    // Comparing a map with an equivalent map
+    checkDataset(Seq(equalToA).toDS(), mapA)
+    checkDataset(Seq(mapA).toDS(), equalToA)
+
+    // Comparing a map with its subset
+    intercept[TestFailedException](checkDataset(Seq(subsetOfA).toDS(), mapA))
+    intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), subsetOfA))
+
+    // Comparing a map with another map differing by a single value
+    intercept[TestFailedException](checkDataset(Seq(notEqualToA1).toDS(), mapA))
+    intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), notEqualToA1))
+
+    // Comparing a map with another map differing by a single key
+    intercept[TestFailedException](checkDataset(Seq(notEqualToA2).toDS(), mapA))
+    intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), notEqualToA2))
+  }
+
   test("toDS with RDD") {
     val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS()
     checkDataset(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d83deb17a090..f8298c9da97e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -341,9 +341,9 @@ object QueryTest {
     case (a: Array[_], b: Array[_]) =>
       a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)}
     case (a: Map[_, _], b: Map[_, _]) =>
-      val entries1 = a.iterator.toSeq.sortBy(_.toString())
-      val entries2 = b.iterator.toSeq.sortBy(_.toString())
-      compare(entries1, entries2)
+      a.size == b.size && a.keys.forall { aKey =>
+        b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey)))
+      }
     case (a: Iterable[_], b: Iterable[_]) =>
       a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)}
     case (a: Product, b: Product) =>
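
For illustration only, not part of the patch: the old comparison sorted map entries by _.toString(), but Array.toString is identity-based (e.g. "[B@1b6d3586"), so two byte arrays with identical contents sort unpredictably and content-equal maps could compare as unequal. The standalone sketch below, with hypothetical helper names deepEquals and mapsMatch, shows the key-matching idea the new compare case uses: for each key of one map, find a structurally equal key in the other and compare the associated values.

// Standalone sketch; deepEquals and mapsMatch are hypothetical names, not part of QueryTest.
object MapCompareSketch {
  // Structural equality that also handles Array keys/values element-wise,
  // mirroring the relevant cases of QueryTest.compare.
  def deepEquals(a: Any, b: Any): Boolean = (a, b) match {
    case (x: Array[_], y: Array[_]) =>
      x.length == y.length && x.zip(y).forall { case (l, r) => deepEquals(l, r) }
    case (x, y) => x == y
  }

  // Two maps match if they have the same size and every key of `a` has a
  // structurally equal key in `b` that maps to a structurally equal value.
  def mapsMatch[K, V](a: Map[K, V], b: Map[K, V]): Boolean =
    a.size == b.size && a.keys.forall { aKey =>
      b.keys.find(bKey => deepEquals(aKey, bKey))
        .exists(bKey => deepEquals(a(aKey), b(bKey)))
    }

  def main(args: Array[String]): Unit = {
    val m1 = Map(Array[Byte](0, 1) -> "one", Array[Byte](0, 2) -> "two")
    val m2 = Map(Array[Byte](0, 2) -> "two", Array[Byte](0, 1) -> "one")
    // Sorting entries by toString cannot reliably align content-equal array keys;
    // key-wise matching treats m1 and m2 as equal.
    println(mapsMatch(m1, m2)) // true
  }
}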