diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6df8d66ee7f2c..f359fb98be3a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -857,6 +857,90 @@ class DatasetSuite extends QueryTest
       1 -> "a", 2 -> "bc", 3 -> "d")
   }
 
+  test("cogroup with complex key types") {
+    // Group both sides by a ClassData (nested structure) key and count the rows seen per side.
+    val leftDs = Seq(
+      ClassData("x", 1) -> "left1",
+      ClassData("x", 1) -> "left2",
+      ClassData("y", 2) -> "left3"
+    ).toDS()
+    val rightDs = Seq(
+      ClassData("x", 1) -> 100,
+      ClassData("z", 3) -> 200
+    ).toDS()
+
+    val leftGroups = leftDs.groupByKey(_._1)
+    val rightGroups = rightDs.groupByKey(_._1)
+    val sizesByKey = leftGroups.cogroup(rightGroups) { (key, leftRows, rightRows) =>
+      Iterator((key.a, key.b, leftRows.size, rightRows.size))
+    }
+
+    // Expected rows are (key.a, key.b, #left rows, #right rows).
+    checkDatasetUnorderly(
+      sizesByKey,
+      ("x", 1, 2, 1), // ClassData("x", 1): present on both sides
+      ("y", 2, 1, 0), // ClassData("y", 2): left only
+      ("z", 3, 0, 1)) // ClassData("z", 3): right only
+  }
+
+  test("cogroup with null keys") {
+    // Rows keyed by None (the null key) are expected to be grouped together on each side.
+    val leftDs = Seq[(Option[Int], String)](
+      (Some(1), "a"),
+      (Some(1), "b"),
+      (None, "c"),
+      (None, "d"),
+      (Some(2), "e")
+    ).toDS()
+    val rightDs = Seq[(Option[Int], Int)](
+      (Some(1), 10),
+      (None, 20),
+      (Some(3), 30)
+    ).toDS()
+
+    val sizesByKey = leftDs.groupByKey(_._1).cogroup(rightDs.groupByKey(_._1)) {
+      (key, leftRows, rightRows) => Iterator((key, leftRows.size, rightRows.size))
+    }
+
+    // Expected rows are (key, #left rows, #right rows).
+    checkDatasetUnorderly(
+      sizesByKey,
+      (Some(1), 2, 1), // left "a", "b"; right 10
+      (None, 2, 1), // left "c", "d"; right 20
+      (Some(2), 1, 0), // left "e"; nothing on the right
+      (Some(3), 0, 1) // nothing on the left; right 30
+    )
+  }
+
+  test("cogroup with empty datasets") {
+    // Each result row has the shape (key, (#left rows, #right rows)).
+    val stringPairs = Seq((1, "a"), (2, "b")).toDS()
+    val intPairs = Seq((2, 100), (3, 200)).toDS()
+    val noStringPairs = spark.emptyDataset[(Int, String)]
+    val noLongPairs = spark.emptyDataset[(Int, Long)]
+
+    // Left side empty: only the right-hand keys appear, each with a left count of 0.
+    val rightOnly = noStringPairs.groupByKey(_._1).cogroup(intPairs.groupByKey(_._1)) {
+      (key, leftRows, rightRows) => Iterator((key, (leftRows.size, rightRows.size)))
+    }
+    val rightOnlyRows = rightOnly.collect().sortBy(_._1)
+    assert(rightOnlyRows === Array((2, (0, 1)), (3, (0, 1))))
+
+    // Right side empty: only the left-hand keys appear, each with a right count of 0.
+    val leftOnly = stringPairs.groupByKey(_._1).cogroup(noStringPairs.groupByKey(_._1)) {
+      (key, leftRows, rightRows) => Iterator((key, (leftRows.size, rightRows.size)))
+    }
+    val leftOnlyRows = leftOnly.collect().sortBy(_._1)
+    assert(leftOnlyRows === Array((1, (1, 0)), (2, (1, 0))))
+
+    // Both sides empty: the collected result is expected to contain no rows.
+    val neither = noStringPairs.groupByKey(_._1).cogroup(noLongPairs.groupByKey(_._1)) {
+      (key, leftRows, rightRows) => Iterator((key, (leftRows.size, rightRows.size)))
+    }
+    val neitherRows = neither.collect()
+    assert(neitherRows.isEmpty)
+  }
+
   test("cogroup with groupBy and sorted") {
     val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abc", 3 -> "ijk").toDS()
     val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDS()