class ChannelBuilder:
    """
    Helper class that creates a GRPC channel from a Spark Connect
    connection string.

    The URL scheme must be ``sc``. The path component must be empty; optional
    parameters are appended after ``/;`` as ``key=value`` pairs separated by
    ``;`` (values are percent-decoded).

    .. versionadded:: 3.4.0

    Examples
    --------
    >>> cb = ChannelBuilder("sc://localhost")
    >>> cb.endpoint
    'localhost:15002'

    >>> cb = ChannelBuilder("sc://localhost/;use_ssl=true;token=aaa")
    >>> cb.secure
    True
    """

    # Parameters that configure the channel itself; they are consumed here
    # and never forwarded as GRPC request metadata.
    PARAM_USE_SSL = "use_ssl"
    PARAM_TOKEN = "token"
    PARAM_USER_ID = "user_id"

    # Default Spark Connect server port used when the URL omits one.
    DEFAULT_PORT = 15002

    def __init__(self, url: str) -> None:
        """
        Parameters
        ----------
        url : str
            Spark Connect connection string, e.g. ``sc://host:port/;key=value``.

        Raises
        ------
        AttributeError
            If the scheme is not ``sc``, the path component is non-empty,
            or a parameter is not a valid ``key=value`` pair.
        """
        # Explicitly check the scheme of the URL.
        if not url.startswith("sc://"):
            raise AttributeError("URL scheme must be set to `sc`.")
        # Rewrite the URL to use http as the scheme so that we can leverage
        # Python's built-in parser.
        tmp_url = "http" + url[2:]
        self.url = urllib.parse.urlparse(tmp_url)
        self.params: typing.Dict[str, str] = {}
        if len(self.url.path) > 0 and self.url.path != "/":
            raise AttributeError(
                f"Path component for connection URI must be empty: {self.url.path}"
            )
        self._extract_attributes()

    def _extract_attributes(self) -> None:
        """Populate ``self.params``, ``self.host`` and ``self.port`` from the URL."""
        if len(self.url.params) > 0:
            parts = self.url.params.split(";")
            for p in parts:
                kv = p.split("=")
                if len(kv) != 2:
                    raise AttributeError(f"Parameter '{p}' is not a valid parameter key-value pair")
                self.params[kv[0]] = urllib.parse.unquote(kv[1])

        netloc = self.url.netloc.split(":")
        if len(netloc) == 1:
            self.host = netloc[0]
            self.port = ChannelBuilder.DEFAULT_PORT
        elif len(netloc) == 2:
            self.host = netloc[0]
            self.port = int(netloc[1])
        else:
            raise AttributeError(
                f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
            )

    def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
        """
        Builds the GRPC specific metadata list to be injected into the request. All
        parameters will be converted to metadata except ones that are explicitly used
        by the channel.

        Returns
        -------
        A list of tuples (key, value)
        """
        return [
            (k, self.params[k])
            for k in self.params
            if k
            not in [
                ChannelBuilder.PARAM_TOKEN,
                ChannelBuilder.PARAM_USE_SSL,
                ChannelBuilder.PARAM_USER_ID,
            ]
        ]

    @property
    def secure(self) -> bool:
        # TLS is only enabled when the parameter is explicitly "true"
        # (case-insensitive); absence means an insecure channel.
        value = self.params.get(ChannelBuilder.PARAM_USE_SSL, "")
        return value.lower() == "true"

    @property
    def endpoint(self) -> str:
        # "host:port" destination string passed to GRPC channel constructors.
        return f"{self.host}:{self.port}"

    def get(self, key: str) -> typing.Any:
        """
        Parameters
        ----------
        key : str
            Parameter key name.

        Returns
        -------
        The parameter value if present, raises exception otherwise.
        """
        return self.params[key]

    def to_channel(self) -> "grpc.Channel":
        """
        Applies the parameters of the connection string and creates a new
        GRPC channel according to the configuration.

        Returns
        -------
        GRPC Channel instance.

        Raises
        ------
        AttributeError
            If a token is configured without TLS (credentials must never be
            sent over an insecure channel).
        """
        if not self.secure:
            if self.params.get(ChannelBuilder.PARAM_TOKEN, None) is not None:
                raise AttributeError("Token based authentication cannot be used without TLS")
            return grpc.insecure_channel(self.endpoint)
        else:
            # Default SSL Credentials.
            opt_token = self.params.get(ChannelBuilder.PARAM_TOKEN, None)
            # When a token is present, pass the token to the channel.
            if opt_token is not None:
                ssl_creds = grpc.ssl_channel_credentials()
                composite_creds = grpc.composite_channel_credentials(
                    ssl_creds, grpc.access_token_call_credentials(opt_token)
                )
                return grpc.secure_channel(self.endpoint, credentials=composite_creds)
            else:
                return grpc.secure_channel(
                    self.endpoint, credentials=grpc.ssl_channel_credentials()
                )
+ + Returns + ------- + GRPC Channel instance. + """ + destination = f"{self.host}:{self.port}" + if not self.secure: + if self.params.get(ChannelBuilder.PARAM_TOKEN, None) is not None: + raise AttributeError("Token based authentication cannot be used without TLS") + return grpc.insecure_channel(destination) + else: + # Default SSL Credentials. + opt_token = self.params.get(ChannelBuilder.PARAM_TOKEN, None) + # When a token is present, pass the token to the channel. + if opt_token is not None: + ssl_creds = grpc.ssl_channel_credentials() + composite_creds = grpc.composite_channel_credentials( + ssl_creds, grpc.access_token_call_credentials(opt_token) + ) + return grpc.secure_channel(destination, credentials=composite_creds) + else: + return grpc.secure_channel(destination, credentials=grpc.ssl_channel_credentials()) + + class MetricValue: def __init__(self, name: str, value: NumericType, type: str): self._name = name @@ -104,11 +235,25 @@ def fromProto(cls, pb: typing.Any) -> "AnalyzeResult": class RemoteSparkSession(object): """Conceptually the remote spark session that communicates with the server""" - def __init__(self, user_id: str, host: Optional[str] = None, port: int = 15002): - self._host = "localhost" if host is None else host - self._port = port + def __init__(self, user_id: str, connection_string: str = "sc://localhost"): + """ + Creates a new RemoteSparkSession for the Spark Connect interface. + + Parameters + ---------- + user_id : str + Unique User ID that is used to differentiate multiple users and + isolate their Spark Sessions. + connection_string: str + Connection string that is used to extract the connection parameters and configure + the GRPC connection. + """ + + # Parse the connection string. 
+ self._builder = ChannelBuilder(connection_string) self._user_id = user_id - self._channel = grpc.insecure_channel(f"{self._host}:{self._port}") + + self._channel = self._builder.to_channel() self._stub = grpc_lib.SparkConnectServiceStub(self._channel) # Create the reader @@ -226,10 +371,12 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult: req.user_context.user_id = self._user_id req.plan.CopyFrom(plan) - resp = self._stub.AnalyzePlan(req) + resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) return AnalyzeResult.fromProto(resp) def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]: + import pandas as pd + if b.batch is not None and len(b.batch.data) > 0: with pa.ipc.open_stream(b.batch.data) as rd: return rd.read_pandas() @@ -238,10 +385,12 @@ def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]: return None def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]: + import pandas as pd + m: Optional[pb2.Response.Metrics] = None result_dfs = [] - for b in self._stub.ExecutePlan(req): + for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): if b.metrics is not None: m = b.metrics diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index e9a06f9c5457c..0f345d622771c 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -18,6 +18,9 @@ import unittest import shutil import tempfile + +import grpc # type: ignore + from pyspark.testing.sqlutils import have_pandas if have_pandas: @@ -27,7 +30,7 @@ from pyspark.sql.types import StructType, StructField, LongType, StringType if have_pandas: - from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit 
class ChannelBuilderTests(ReusedPySparkTestCase):
    def test_invalid_connection_strings(self):
        # Each of these violates the connection-string contract: wrong
        # scheme, malformed scheme separator, non-empty path, or a
        # parameter that is not a key=value pair.
        bad_urls = [
            "scc://host:12",
            "http://host",
            "sc:/host:1234/path",
            "sc://host/path",
            "sc://host/;parm1;param2",
        ]
        for bad in bad_urls:
            self.assertRaises(AttributeError, ChannelBuilder, bad)

        # A token without TLS is rejected only at channel-creation time.
        self.assertRaises(AttributeError, ChannelBuilder("sc://host/;token=123").to_channel)

    def test_valid_channel_creation(self):
        # Plain insecure channel.
        self.assertIsInstance(ChannelBuilder("sc://host").to_channel(), grpc.Channel)

        # TLS channel carrying a bearer token.
        secure_with_token = ChannelBuilder("sc://host/;use_ssl=true;token=abc").to_channel()
        self.assertIsInstance(secure_with_token, grpc.Channel)

        # TLS channel without a token.
        secure_only = ChannelBuilder("sc://host/;use_ssl=true").to_channel()
        self.assertIsInstance(secure_only, grpc.Channel)

    def test_channel_properties(self):
        chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021")
        self.assertEqual("host:15002", chan.endpoint)
        self.assertEqual(True, chan.secure)
        # Percent-encoded values are decoded on parse.
        self.assertEqual("120 21", chan.get("param1"))

    def test_metadata(self):
        chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd")
        md = chan.metadata()
        # Channel-level parameters (use_ssl, token) are filtered out.
        self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], md)