Skip to content

Commit 083673c

Browse files
committed
Enhance the unit test case to cover more positive and negative cases.
1 parent d91b44a commit 083673c

File tree

1 file changed

+30
-5
lines changed

1 file changed

+30
-5
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,36 @@ class ComplexDataSuite extends SparkFunSuite {
106106
}
107107

108108
test("SPARK-24659: GenericArrayData.equals should respect element type differences") {
109-
// Spark SQL considers array<int> and array<long> to be incompatible,
110-
// so an underlying implementation of array type should return false in this case.
111-
val array1 = new GenericArrayData(Array[Int](123))
112-
val array2 = new GenericArrayData(Array[Long](123L))
109+
import scala.reflect.ClassTag
113110

114-
assert(!array1.equals(array2))
111+
// Expected positive cases
112+
def arraysShouldEqual[T: ClassTag](element: T*): Unit = {
113+
val array1 = new GenericArrayData(Array[T](element: _*))
114+
val array2 = new GenericArrayData(Array[T](element: _*))
115+
assert(array1.equals(array2))
116+
}
117+
arraysShouldEqual(true, false) // Boolean
118+
arraysShouldEqual(0.toByte, 123.toByte, (-123).toByte) // Byte
119+
arraysShouldEqual(0.toShort, 123.toShort, (-256).toShort) // Short
120+
arraysShouldEqual(0, 123, -65536) // Int
121+
arraysShouldEqual(0L, 123L, -65536L) // Long
122+
arraysShouldEqual(0.0F, 123.0F, -65536.0F) // Float
123+
arraysShouldEqual(0.0, 123.0, -65536.0) // Double
124+
arraysShouldEqual(Array[Byte](123.toByte), null) // Binary (Array[Byte])
125+
arraysShouldEqual(UTF8String.fromString("foo"), null) // String (UTF8String)
126+
127+
// Expected negative cases
128+
// Spark SQL considers cases like array<int> vs array<long> to be incompatible,
129+
// so an underlying implementation of array type should return false in such cases.
130+
def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = {
131+
val array1 = new GenericArrayData(Array[T](element1))
132+
val array2 = new GenericArrayData(Array[U](element2))
133+
assert(!array1.equals(array2))
134+
}
135+
arraysShouldNotEqual(true, 1) // Boolean <-> Int
136+
arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int
137+
arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long
138+
arraysShouldNotEqual(123.toShort, 123) // Short <-> Int
139+
arraysShouldNotEqual(123, 123L) // Int <-> Long
115140
}
116141
}

0 commit comments

Comments
 (0)