diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 644784fa3db6c..9af2e7cb46616 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2992,6 +2992,14 @@ class SparkConnectPlanner(
     // the SQL command and defer the actual analysis and execution to the flow function.
     if (insidePipelineFlowFunction) {
       result.setRelation(relation)
+      executeHolder.eventsManager.postFinished()
+      responseObserver.onNext(
+        ExecutePlanResponse
+          .newBuilder()
+          .setSessionId(sessionHolder.sessionId)
+          .setServerSideSessionId(sessionHolder.serverSessionId)
+          .setSqlCommandResult(result)
+          .build)
       return
     }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
index c9551646385c2..3cb45fa6e1720 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
@@ -850,4 +850,98 @@ class SparkDeclarativePipelinesServerSuite
       }
     }
   }
+
+  test(
+    "SPARK-54452: spark.sql() inside a pipeline flow function should return a sql_command_result") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipelineAnalysisContext = proto.PipelineAnalysisContext
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .setFlowName("flow1")
+        .build()
+      val userContext = proto.UserContext
+        .newBuilder()
+        .addExtensions(com.google.protobuf.Any.pack(pipelineAnalysisContext))
+        .setUserId("test_user")
+        .build()
+
+      val relation = proto.Plan
+        .newBuilder()
+        .setCommand(
+          proto.Command
+            .newBuilder()
+            .setSqlCommand(
+              proto.SqlCommand
+                .newBuilder()
+                .setInput(
+                  proto.Relation
+                    .newBuilder()
+                    .setRead(proto.Read
+                      .newBuilder()
+                      .setNamedTable(
+                        proto.Read.NamedTable.newBuilder().setUnparsedIdentifier("table"))
+                      .build())
+                    .build()))
+            .build())
+        .build()
+
+      val sparkSqlRequest = proto.ExecutePlanRequest
+        .newBuilder()
+        .setUserContext(userContext)
+        .setPlan(relation)
+        .setSessionId(UUID.randomUUID().toString)
+        .build()
+      val sparkSqlResponse = stub.executePlan(sparkSqlRequest).next()
+      assert(sparkSqlResponse.hasSqlCommandResult)
+      assert(
+        sparkSqlResponse.getSqlCommandResult.getRelation ==
+          relation.getCommand.getSqlCommand.getInput)
+    }
+  }
+
+  test(
+    "SPARK-54452: spark.sql() outside a pipeline flow function should return a " +
+      "sql_command_result") {
+    withRawBlockingStub { implicit stub =>
+      val graphId = createDataflowGraph
+      val pipelineAnalysisContext = proto.PipelineAnalysisContext
+        .newBuilder()
+        .setDataflowGraphId(graphId)
+        .build()
+      val userContext = proto.UserContext
+        .newBuilder()
+        .addExtensions(com.google.protobuf.Any.pack(pipelineAnalysisContext))
+        .setUserId("test_user")
+        .build()
+
+      val relation = proto.Plan
+        .newBuilder()
+        .setCommand(
+          proto.Command
+            .newBuilder()
+            .setSqlCommand(
+              proto.SqlCommand
+                .newBuilder()
+                .setInput(proto.Relation
+                  .newBuilder()
+                  .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM RANGE(5)"))
+                  .build())
+                .build())
+            .build())
+        .build()
+
+      val sparkSqlRequest = proto.ExecutePlanRequest
+        .newBuilder()
+        .setUserContext(userContext)
+        .setPlan(relation)
+        .setSessionId(UUID.randomUUID().toString)
+        .build()
+      val sparkSqlResponse = stub.executePlan(sparkSqlRequest).next()
+      assert(sparkSqlResponse.hasSqlCommandResult)
+      assert(
+        sparkSqlResponse.getSqlCommandResult.getRelation ==
+          relation.getCommand.getSqlCommand.getInput)
+    }
+  }
 }