diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cccd5ebd33d6..367a5be33746 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -59,7 +59,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
-import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
+import org.apache.spark.sql.connect.plugin.{CommandPlugin, CommandPluginWithQueryPlanningTracker, SparkConnectPluginRegistry}
 import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, SparkConnectService}
 import org.apache.spark.sql.connect.utils.MetricGenerator
 import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -2626,7 +2626,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
       // Lazily traverse the collection.
       .view
       // Apply the transformation.
-      .map(p => p.process(extension, this))
+      .map(p =>
+        p match {
+          case p: CommandPluginWithQueryPlanningTracker =>
+            val tracker = executeHolder.eventsManager.createQueryPlanningTracker
+            p.process(extension, this, tracker)
+          case p: CommandPlugin => p.process(extension, this)
+        })
       // Find the first non-empty transformation or throw.
       .find(_.nonEmpty)
       .flatten
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/CommandPluginWithQueryPlanningTracker.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/CommandPluginWithQueryPlanningTracker.scala
new file mode 100644
index 000000000000..f99cd1851d9c
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/plugin/CommandPluginWithQueryPlanningTracker.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.plugin
+
+import com.google.protobuf
+
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+
+/**
+ * Behavior trait for supporting a trackable extension mechanism for the Spark Connect planner.
+ *
+ * Classes implementing this trait must be trivially constructable and should not rely on
+ * internal state. Every registered extension will be passed the Any instance. If the plugin
+ * supports handling this type, it is responsible for constructing the logical expression from
+ * this object and, if necessary, traversing its children.
+ */
+trait CommandPluginWithQueryPlanningTracker extends CommandPlugin {
+  def process(
+      command: protobuf.Any,
+      planner: SparkConnectPlanner,
+      tracker: QueryPlanningTracker): Option[Unit]
+}
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index fdb903237941..6eb5548c918e 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkContext, SparkEnv, SparkException}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Relation
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.common.InvalidPlanInput
@@ -96,6 +97,25 @@ class ExampleCommandPlugin extends CommandPlugin {
   }
 }
 
+class ExampleCommandPluginWithQueryPlanningTracker extends CommandPluginWithQueryPlanningTracker {
+  override def process(command: protobuf.Any, planner: SparkConnectPlanner): Option[Unit] = {
+    throw new SparkException("This should not be called here")
+  }
+
+  override def process(
+      command: protobuf.Any,
+      planner: SparkConnectPlanner,
+      tracker: QueryPlanningTracker): Option[Unit] = {
+    if (!command.is(classOf[proto.ExamplePluginCommand])) {
+      return None
+    }
+    val cmd = command.unpack(classOf[proto.ExamplePluginCommand])
+    assert(planner.session != null)
+    SparkContext.getActive.get.setLocalProperty("testingProperty", cmd.getCustomField)
+    Some(())
+  }
+}
+
 class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConnectPlanTest {
 
   override def beforeEach(): Unit = {
@@ -202,6 +222,28 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
     }
   }
 
+  test("SPARK-45204: End to end Command test - CommandPluginWithQueryPlanningTracker") {
+    withSparkConf(
+      Connect.CONNECT_EXTENSIONS_COMMAND_CLASSES.key ->
+        "org.apache.spark.sql.connect.plugin.ExampleCommandPluginWithQueryPlanningTracker") {
+      spark.sparkContext.setLocalProperty("testingProperty", "notset")
+      val plan = proto.Command
+        .newBuilder()
+        .setExtension(
+          protobuf.Any.pack(
+            proto.ExamplePluginCommand
+              .newBuilder()
+              .setCustomField("Robert")
+              .build()))
+        .build()
+
+      val executeHolder = buildExecutePlanHolder(plan)
+      new SparkConnectPlanner(executeHolder.sessionHolder)
+        .process(plan, new MockObserver(), executeHolder)
+      assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Robert"))
+    }
+  }
+
   test("Exception handling for plugin classes") {
     withSparkConf(
       Connect.CONNECT_EXTENSIONS_RELATION_CLASSES.key ->
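For context, a minimal sketch of what a downstream implementation of the new trait could look like. The class is hypothetical and not part of this change; it reuses the `proto.ExamplePluginCommand` message from the test suite above and the public `sessionHolder` constructor parameter visible in the planner hunk, and it assumes the standard `QueryPlanningTracker.measurePhase` / `phases` API from catalyst for recording phase timings:

```scala
import com.google.protobuf

import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.plugin.CommandPluginWithQueryPlanningTracker

// Hypothetical third-party plugin, for illustration only. It handles the
// ExamplePluginCommand message used by the test suite and attributes its
// work to the "analysis" phase of the tracker the planner passes in.
class TrackedExampleCommandPlugin extends CommandPluginWithQueryPlanningTracker {

  // Inherited two-argument CommandPlugin entry point. The planner prefers the
  // three-argument overload for plugins implementing the new trait, so
  // delegate with a standalone tracker in case this variant is invoked.
  override def process(command: protobuf.Any, planner: SparkConnectPlanner): Option[Unit] =
    process(command, planner, new QueryPlanningTracker)

  override def process(
      command: protobuf.Any,
      planner: SparkConnectPlanner,
      tracker: QueryPlanningTracker): Option[Unit] = {
    // Returning None signals that this plugin does not handle the message,
    // so the registry keeps probing the remaining registered plugins.
    if (!command.is(classOf[proto.ExamplePluginCommand])) {
      return None
    }
    val cmd = command.unpack(classOf[proto.ExamplePluginCommand])

    // Work performed inside measurePhase is recorded against the named phase
    // of the tracker created by the planner for this execution.
    tracker.measurePhase(QueryPlanningTracker.ANALYSIS) {
      planner.sessionHolder.session.sql(s"SELECT '${cmd.getCustomField}' AS custom").collect()
    }

    // The phase summary now carries wall-clock timing for the command; a real
    // plugin would forward this to its own metrics or logging system.
    tracker.phases.get(QueryPlanningTracker.ANALYSIS).foreach { summary =>
      println(s"ExamplePluginCommand handled in ${summary.durationMs} ms")
    }
    Some(())
  }
}
```

Because `CommandPluginWithQueryPlanningTracker` extends `CommandPlugin`, the two-argument `process` must still be implemented; the planner's pattern match above guarantees the tracked overload wins whenever the plugin implements the new trait.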