7 changes: 7 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -19,6 +19,7 @@ syntax = 'proto3';

package spark.connect;

import "google/protobuf/any.proto";
import "spark/connect/commands.proto";
import "spark/connect/relations.proto";

@@ -51,6 +52,12 @@ message Request {
message UserContext {
string user_id = 1;
string user_name = 2;

// To extend the existing user context message that is used to identify incoming requests,
// Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other
// messages into this message. Extensions are stored as a `repeated` type to be able to
// handle multiple active extensions.
repeated google.protobuf.Any extensions = 999;
Contributor:

Just for my self-education: is the `Any` type potentially a cause of security holes?

Contributor Author:

No. `Any` is only used to define extension points in the proto, so that at the time of writing the proto we don't have to define the message that gets embedded in it.

Interpreting the serialized message is up to the consumer and follows the same security model as any other serialized payload, so we have to be careful, but it's not a new threat.

Contributor:

I see. Makes sense.

}
}

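To make the extension mechanism concrete before the test suite below, here is a minimal client-side sketch in Python against the generated bindings from this change. It is an illustration, not part of the PR: the well-known `StringValue` wrapper stands in for whatever message a real extension would carry (the new Scala suite that follows packs an `Expression` instead).

```python
from google.protobuf import any_pb2, wrappers_pb2

from pyspark.sql.connect.proto import base_pb2

# Build a UserContext carrying one packed extension.
ctx = base_pb2.Request.UserContext(user_id="1", user_name="Martin")
ext = any_pb2.Any()
ext.Pack(wrappers_pb2.StringValue(value="my-extension-payload"))
ctx.extensions.append(ext)  # repeated: multiple extensions may be attached

# Round-trip through the wire format, as the server would see it.
parsed = base_pb2.Request.UserContext.FromString(ctx.SerializeToString())
assert parsed.user_id == "1" and parsed.user_name == "Martin"
assert parsed.extensions[0].Is(wrappers_pb2.StringValue.DESCRIPTOR)

payload = wrappers_pb2.StringValue()
parsed.extensions[0].Unpack(payload)  # copies the packed bytes out
assert payload.value == "my-extension-payload"
```

Because `extensions` is `repeated`, a client can attach several independently packed messages, and a server can simply skip any type it does not recognize.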
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.messages

import org.apache.spark.SparkFunSuite
import org.apache.spark.connect.proto

class ConnectProtoMessagesSuite extends SparkFunSuite {
  test("UserContext can deal with extensions") {
    // Create the builder.
    val builder = proto.Request.UserContext.newBuilder().setUserId("1").setUserName("Martin")

    // Create the extension value.
    val lit = proto.Expression
      .newBuilder()
      .setLiteral(proto.Expression.Literal.newBuilder().setI32(32).build())
    // Pack the extension into Any.
    val aval = com.google.protobuf.Any.pack(lit.build())
    // Add Any to the repeated field list.
    builder.addExtensions(aval)
    // Create serialized value.
    val serialized = builder.build().toByteArray

    // Now, read the serialized value.
    val result = proto.Request.UserContext.parseFrom(serialized)
    assert(result.getUserId.equals("1"))
    assert(result.getUserName.equals("Martin"))
    assert(result.getExtensionsCount == 1)

    val ext = result.getExtensions(0)
    assert(ext.is(classOf[proto.Expression]))
    val extLit = ext.unpack(classOf[proto.Expression])
    assert(extLit.hasLiteral)
    assert(extLit.getLiteral.hasI32)
    assert(extLit.getLiteral.getI32 == 32)
  }
}
51 changes: 26 additions & 25 deletions python/pyspark/sql/connect/proto/base_pb2.py
@@ -28,12 +28,13 @@
_sym_db = _symbol_database.Default()


from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from pyspark.sql.connect.proto import commands_pb2 as spark_dot_connect_dot_commands__pb2
from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xdb\x01\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x43\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\x92\x02\n\x07Request\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12\x45\n\x0cuser_context\x18\x02 \x01(\x0b\x32".spark.connect.Request.UserContextR\x0buserContext\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1az\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xc8\x07\n\x08Response\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12:\n\x05\x62\x61tch\x18\x02 \x01(\x0b\x32".spark.connect.Response.ArrowBatchH\x00R\x05\x62\x61tch\x12\x42\n\njson_batch\x18\x03 \x01(\x0b\x32!.spark.connect.Response.JSONBatchH\x00R\tjsonBatch\x12\x39\n\x07metrics\x18\x04 \x01(\x0b\x32\x1f.spark.connect.Response.MetricsR\x07metrics\x1a\xaf\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12-\n\x12uncompressed_bytes\x18\x02 \x01(\x03R\x11uncompressedBytes\x12)\n\x10\x63ompressed_bytes\x18\x03 \x01(\x03R\x0f\x63ompressedBytes\x12\x12\n\x04\x64\x61ta\x18\x04 \x01(\x0cR\x04\x64\x61ta\x12\x16\n\x06schema\x18\x05 \x01(\x0cR\x06schema\x1a<\n\tJSONBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x1a\xe4\x03\n\x07Metrics\x12\x46\n\x07metrics\x18\x01 \x03(\x0b\x32,.spark.connect.Response.Metrics.MetricObjectR\x07metrics\x1a\xb6\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12o\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32\x42.spark.connect.Response.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1ap\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32+.spark.connect.Response.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricTypeB\r\n\x0bresult_type"\x9b\x01\n\x0f\x41nalyzeResponse\x12\x1b\n\tclient_id\x18\x01 \x01(\tR\x08\x63lientId\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12!\n\x0c\x63olumn_types\x18\x03 \x03(\tR\x0b\x63olumnTypes\x12%\n\x0e\x65xplain_string\x18\x04 \x01(\tR\rexplainString2\xa2\x01\n\x13SparkConnectService\x12\x42\n\x0b\x45xecutePlan\x12\x16.spark.connect.Request\x1a\x17.spark.connect.Response"\x00\x30\x01\x12G\n\x0b\x41nalyzePlan\x12\x16.spark.connect.Request\x1a\x1e.spark.connect.AnalyzeResponse"\x00\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,28 +45,28 @@
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options = None
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options = b"8\001"
_PLAN._serialized_start = 104
_PLAN._serialized_end = 220
_REQUEST._serialized_start = 223
_REQUEST._serialized_end = 442
_REQUEST_USERCONTEXT._serialized_start = 375
_REQUEST_USERCONTEXT._serialized_end = 442
_RESPONSE._serialized_start = 445
_RESPONSE._serialized_end = 1413
_RESPONSE_ARROWBATCH._serialized_start = 674
_RESPONSE_ARROWBATCH._serialized_end = 849
_RESPONSE_JSONBATCH._serialized_start = 851
_RESPONSE_JSONBATCH._serialized_end = 911
_RESPONSE_METRICS._serialized_start = 914
_RESPONSE_METRICS._serialized_end = 1398
_RESPONSE_METRICS_METRICOBJECT._serialized_start = 998
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1308
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1196
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1308
_RESPONSE_METRICS_METRICVALUE._serialized_start = 1310
_RESPONSE_METRICS_METRICVALUE._serialized_end = 1398
_ANALYZERESPONSE._serialized_start = 1416
_ANALYZERESPONSE._serialized_end = 1571
_SPARKCONNECTSERVICE._serialized_start = 1574
_SPARKCONNECTSERVICE._serialized_end = 1736
_PLAN._serialized_start = 131
_PLAN._serialized_end = 247
_REQUEST._serialized_start = 250
_REQUEST._serialized_end = 524
_REQUEST_USERCONTEXT._serialized_start = 402
_REQUEST_USERCONTEXT._serialized_end = 524
_RESPONSE._serialized_start = 527
_RESPONSE._serialized_end = 1495
_RESPONSE_ARROWBATCH._serialized_start = 756
_RESPONSE_ARROWBATCH._serialized_end = 931
_RESPONSE_JSONBATCH._serialized_start = 933
_RESPONSE_JSONBATCH._serialized_end = 993
_RESPONSE_METRICS._serialized_start = 996
_RESPONSE_METRICS._serialized_end = 1480
_RESPONSE_METRICS_METRICOBJECT._serialized_start = 1080
_RESPONSE_METRICS_METRICOBJECT._serialized_end = 1390
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1278
_RESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1390
_RESPONSE_METRICS_METRICVALUE._serialized_start = 1392
_RESPONSE_METRICS_METRICVALUE._serialized_end = 1480
_ANALYZERESPONSE._serialized_start = 1498
_ANALYZERESPONSE._serialized_end = 1653
_SPARKCONNECTSERVICE._serialized_start = 1656
_SPARKCONNECTSERVICE._serialized_end = 1818
# @@protoc_insertion_point(module_scope)
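A note on the mechanical churn above: the new `google/protobuf/any.proto` dependency adds 27 bytes to the serialized descriptor (a 1-byte tag, a 1-byte length, and the 25-byte path), so every offset up to `UserContext` shifts by 27 (e.g. `_PLAN._serialized_start`: 104 → 131); `UserContext` itself grows by 55 bytes to describe the new repeated `extensions` field, so everything after it shifts by 82.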
18 changes: 17 additions & 1 deletion python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -35,6 +35,7 @@ limitations under the License.
"""
import builtins
import collections.abc
import google.protobuf.any_pb2
import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.message
@@ -102,17 +103,32 @@ class Request(google.protobuf.message.Message):

USER_ID_FIELD_NUMBER: builtins.int
USER_NAME_FIELD_NUMBER: builtins.int
EXTENSIONS_FIELD_NUMBER: builtins.int
user_id: builtins.str
user_name: builtins.str
@property
def extensions(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
google.protobuf.any_pb2.Any
]:
"""To extend the existing user context message that is used to identify incoming requests,
Spark Connect leverages the Any protobuf type that can be used to inject arbitrary other
messages into this message. Extensions are stored as a `repeated` type to be able to
handle multiple active extensions.
"""
def __init__(
self,
*,
user_id: builtins.str = ...,
user_name: builtins.str = ...,
extensions: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal["user_id", b"user_id", "user_name", b"user_name"],
field_name: typing_extensions.Literal[
"extensions", b"extensions", "user_id", b"user_id", "user_name", b"user_name"
],
) -> None: ...

CLIENT_ID_FIELD_NUMBER: builtins.int
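Finally, a hypothetical server-side sketch of how a consumer of the typed stub above might dispatch on the packed extension types. As with the client example, `StringValue` is just an illustrative payload and `handle_extensions` is a made-up helper, not an API this PR adds.

```python
from google.protobuf import wrappers_pb2

from pyspark.sql.connect.proto import base_pb2


def handle_extensions(ctx: base_pb2.Request.UserContext) -> None:
    """Walk the repeated Any field and unpack only the types we understand."""
    for ext in ctx.extensions:
        if ext.Is(wrappers_pb2.StringValue.DESCRIPTOR):
            payload = wrappers_pb2.StringValue()
            ext.Unpack(payload)
            print(f"string extension: {payload.value}")
        else:
            # Per the review discussion above: interpreting a packed message is
            # the consumer's job, so unknown types are skipped, not rejected.
            print(f"skipping unhandled extension type: {ext.TypeName()}")
```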