diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8ca004d520c5..d7b3c057d927 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1508,8 +1508,10 @@ class SparkConnectPlanner(val session: SparkSession) { maxRecordsPerBatch, maxBatchSize, timeZoneId) - assert(batches.size == 1) - batches.next() + assert(batches.hasNext) + val bytes = batches.next() + assert(!batches.hasNext, s"remaining batches: ${batches.size}") + bytes } // To avoid explicit handling of the result on the client, we build the expected input diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 806fe6e2329b..dad303d3463e 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -2876,6 +2876,12 @@ def test_unsupported_io_functions(self): with self.assertRaises(NotImplementedError): getattr(df.write, f)() + def test_sql_with_command(self): + # SPARK-42705: spark.sql should return values from the command. + self.assertEqual( + self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect() + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ClientTests(unittest.TestCase):