@@ -106,11 +106,36 @@ class ComplexDataSuite extends SparkFunSuite {
106106 }
107107
108108 test(" SPARK-24659: GenericArrayData.equals should respect element type differences" ) {
109- // Spark SQL considers array<int> and array<long> to be incompatible,
110- // so an underlying implementation of array type should return false in this case.
111- val array1 = new GenericArrayData (Array [Int ](123 ))
112- val array2 = new GenericArrayData (Array [Long ](123L ))
109+ import scala .reflect .ClassTag
113110
114- assert(! array1.equals(array2))
111+ // Expected positive cases
112+ def arraysShouldEqual [T : ClassTag ](element : T * ): Unit = {
113+ val array1 = new GenericArrayData (Array [T ](element : _* ))
114+ val array2 = new GenericArrayData (Array [T ](element : _* ))
115+ assert(array1.equals(array2))
116+ }
117+ arraysShouldEqual(true , false ) // Boolean
118+ arraysShouldEqual(0 .toByte, 123 .toByte, (- 123 ).toByte) // Byte
119+ arraysShouldEqual(0 .toShort, 123 .toShort, (- 256 ).toShort) // Short
120+ arraysShouldEqual(0 , 123 , - 65536 ) // Int
121+ arraysShouldEqual(0L , 123L , - 65536L ) // Long
122+ arraysShouldEqual(0.0F , 123.0F , - 65536.0F ) // Float
123+ arraysShouldEqual(0.0 , 123.0 , - 65536.0 ) // Double
124+ arraysShouldEqual(Array [Byte ](123 .toByte), null ) // Binary (Array[Byte])
125+ arraysShouldEqual(UTF8String .fromString(" foo" ), null ) // String (UTF8String)
126+
127+ // Expected negative cases
128+ // Spark SQL considers cases like array<int> vs array<long> to be incompatible,
129+ // so an underlying implementation of array type should return false in such cases.
130+ def arraysShouldNotEqual [T : ClassTag , U : ClassTag ](element1 : T , element2 : U ): Unit = {
131+ val array1 = new GenericArrayData (Array [T ](element1))
132+ val array2 = new GenericArrayData (Array [U ](element2))
133+ assert(! array1.equals(array2))
134+ }
135+ arraysShouldNotEqual(true , 1 ) // Boolean <-> Int
136+ arraysShouldNotEqual(123 .toByte, 123 ) // Byte <-> Int
137+ arraysShouldNotEqual(123 .toByte, 123L ) // Byte <-> Long
138+ arraysShouldNotEqual(123 .toShort, 123 ) // Short <-> Int
139+ arraysShouldNotEqual(123 , 123L ) // Int <-> Long
115140 }
116141}
0 commit comments