diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto index 390a8b156dc4..dff1734335e2 100644 --- a/connector/connect/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/src/main/protobuf/spark/connect/base.proto @@ -19,6 +19,7 @@ syntax = 'proto3'; package spark.connect; +import "google/protobuf/any.proto"; import "spark/connect/commands.proto"; import "spark/connect/relations.proto"; @@ -51,6 +52,12 @@ message Request { message UserContext { string user_id = 1; string user_name = 2; + + // To extend the existing user context message that is used to identify incoming requests, + // Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other + // messages into this message. Extensions are stored as a `repeated` type to be able to + // handle multiple active extensions. + repeated google.protobuf.Any extensions = 999; } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala new file mode 100644 index 000000000000..4132cca91086 --- /dev/null +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/messages/ConnectProtoMessagesSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.messages + +import org.apache.spark.SparkFunSuite +import org.apache.spark.connect.proto + +class ConnectProtoMessagesSuite extends SparkFunSuite { + test("UserContext can deal with extensions") { + // Create the builder. + val builder = proto.Request.UserContext.newBuilder().setUserId("1").setUserName("Martin") + + // Create the extension value. + val lit = proto.Expression + .newBuilder() + .setLiteral(proto.Expression.Literal.newBuilder().setI32(32).build()) + // Pack the extension into Any. + val aval = com.google.protobuf.Any.pack(lit.build()) + // Add Any to the repeated field list. + builder.addExtensions(aval) + // Create serialized value. + val serialized = builder.build().toByteArray + + // Now, read the serialized value. 
+ val result = proto.Request.UserContext.parseFrom(serialized) + assert(result.getUserId.equals("1")) + assert(result.getUserName.equals("Martin")) + assert(result.getExtensionsCount == 1) + + val ext = result.getExtensions(0) + assert(ext.is(classOf[proto.Expression])) + val extLit = ext.unpack(classOf[proto.Expression]) + assert(extLit.hasLiteral) + assert(extLit.getLiteral.hasI32) + assert(extLit.getLiteral.getI32 == 32) + } +} diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 8de6565bae86..408872dbb66e 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -28,12 +28,13 @@ _sym_db = _symbol_database.Default() +from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xdb\x01\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x43\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 
\x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -44,28 +45,28 @@ DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" 
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001" - _PLAN._serialized_start = 104 - _PLAN._serialized_end = 220 - _REQUEST._serialized_start = 223 - _REQUEST._serialized_end = 442 - _REQUEST_USERCONTEXT._serialized_start = 375 - _REQUEST_USERCONTEXT._serialized_end = 442 - _RESPONSE._serialized_start = 445 - _RESPONSE._serialized_end = 1413 - _RESPONSE_ARROWBATCH._serialized_start = 674 - _RESPONSE_ARROWBATCH._serialized_end = 849 - _RESPONSE_JSONBATCH._serialized_start = 851 - _RESPONSE_JSONBATCH._serialized_end = 911 - _RESPONSE_METRICS._serialized_start = 914 - _RESPONSE_METRICS._serialized_end = 1398 - _RESPONSE_METRICS_METRICOBJECT._serialized_start = 998 - _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1308 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1196 - _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1308 - _RESPONSE_METRICS_METRICVALUE._serialized_start = 1310 - _RESPONSE_METRICS_METRICVALUE._serialized_end = 1398 - _ANALYZERESPONSE._serialized_start = 1416 - _ANALYZERESPONSE._serialized_end = 1571 - _SPARKCONNECTSERVICE._serialized_start = 1574 - _SPARKCONNECTSERVICE._serialized_end = 1736 + _PLAN._serialized_start = 131 + _PLAN._serialized_end = 247 + _REQUEST._serialized_start = 250 + _REQUEST._serialized_end = 524 + _REQUEST_USERCONTEXT._serialized_start = 402 + _REQUEST_USERCONTEXT._serialized_end = 524 + _RESPONSE._serialized_start = 527 + _RESPONSE._serialized_end = 1495 + _RESPONSE_ARROWBATCH._serialized_start = 756 + _RESPONSE_ARROWBATCH._serialized_end = 931 + _RESPONSE_JSONBATCH._serialized_start = 933 + _RESPONSE_JSONBATCH._serialized_end = 993 + _RESPONSE_METRICS._serialized_start = 996 + _RESPONSE_METRICS._serialized_end = 1480 + _RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080 + _RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278 + _RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390 + _RESPONSE_METRICS_METRICVALUE._serialized_start = 1392 + _RESPONSE_METRICS_METRICVALUE._serialized_end = 1480 + _ANALYZERESPONSE._serialized_start = 1498 + _ANALYZERESPONSE._serialized_end = 1653 + _SPARKCONNECTSERVICE._serialized_start = 1656 + _SPARKCONNECTSERVICE._serialized_end = 1818 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index bcac7d11a808..bb3a6578cf71 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -35,6 +35,7 @@ limitations under the License. """ import builtins import collections.abc +import google.protobuf.any_pb2 import google.protobuf.descriptor import google.protobuf.internal.containers import google.protobuf.message @@ -102,17 +103,32 @@ class Request(google.protobuf.message.Message): USER_ID_FIELD_NUMBER: builtins.int USER_NAME_FIELD_NUMBER: builtins.int + EXTENSIONS_FIELD_NUMBER: builtins.int user_id: builtins.str user_name: builtins.str + @property + def extensions( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + google.protobuf.any_pb2.Any + ]: + """To extend the existing user context message that is used to identify incoming requests, + Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other + messages into this message. 
Extensions are stored as a `repeated` type to be able to + handle multiple active extensions. + """ def __init__( self, *, user_id: builtins.str = ..., user_name: builtins.str = ..., + extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ..., ) -> None: ... def ClearField( self, - field_name: typing_extensions.Literal["user_id", b"user_id", "user_name", b"user_name"], + field_name: typing_extensions.Literal[ + "extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name" + ], ) -> None: ... CLIENT_ID_FIELD_NUMBER: builtins.int
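
Below the patch, a minimal usage sketch (not part of the diff) of the new `extensions` field from the Python client side. The `pyspark.sql.connect.proto.base_pb2` module is the file regenerated above; the `expressions_pb2` module and its `Expression`/`Literal` messages with an `i32` field are assumptions inferred from `proto.Expression` in the Scala suite, not shown in this diff.

from google.protobuf import any_pb2

from pyspark.sql.connect.proto import base_pb2
from pyspark.sql.connect.proto import expressions_pb2  # assumed module, mirrors proto.Expression

# Build a user context and pack an arbitrary message into the new
# repeated `extensions` field via google.protobuf.Any.
ctx = base_pb2.Request.UserContext(user_id="1", user_name="Martin")
ext = any_pb2.Any()
ext.Pack(expressions_pb2.Expression(literal=expressions_pb2.Expression.Literal(i32=32)))
ctx.extensions.append(ext)

# Round-trip through the wire format and unpack the extension again,
# mirroring the assertions in ConnectProtoMessagesSuite above.
parsed = base_pb2.Request.UserContext.FromString(ctx.SerializeToString())
assert parsed.extensions[0].Is(expressions_pb2.Expression.DESCRIPTOR)
unpacked = expressions_pb2.Expression()
parsed.extensions[0].Unpack(unpacked)
assert unpacked.literal.i32 == 32

Because the field is `repeated google.protobuf.Any`, multiple independent extensions can coexist on a single request, and a server that does not recognize a packed type can inspect the Any's type URL and skip it.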