Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 157 additions & 8 deletions python/pyspark/sql/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
import io
import logging
import typing
import urllib.parse
import uuid

import grpc # type: ignore
import pandas
import pandas as pd
import pyarrow as pa
import pandas

import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
Expand All @@ -42,6 +42,137 @@
logging.basicConfig(level=logging.INFO)


class ChannelBuilder:
    """
    This is a helper class that is used to create a GRPC channel based on the given
    connection string per the documentation of Spark Connect.

    The URL has the form ``sc://host:port/;param1=value1;param2=value2`` where the
    path component must be empty and parameter values are percent-encoded.

    .. versionadded:: 3.4.0

    Examples
    --------
    >>> cb = ChannelBuilder("sc://localhost")
    >>> cb.endpoint
    'localhost:15002'

    >>> cb = ChannelBuilder("sc://localhost/;use_ssl=true;token=aaa")
    >>> cb.secure
    True
    """

    # Reserved parameter names: these configure the channel itself and are
    # never forwarded as GRPC request metadata (see metadata()).
    PARAM_USE_SSL = "use_ssl"
    PARAM_TOKEN = "token"
    PARAM_USER_ID = "user_id"

    # Default Spark Connect server port, used when the URL omits ':port'.
    DEFAULT_PORT = 15002

    def __init__(self, url: str) -> None:
        """
        Parameters
        ----------
        url : str
            Connection string, e.g. ``sc://host:port/;use_ssl=true``.

        Raises
        ------
        AttributeError
            If the scheme is not ``sc``, the path component is non-empty,
            or a parameter is not a valid ``key=value`` pair.
        """
        # Explicitly check the scheme of the URL.
        if url[:5] != "sc://":
            raise AttributeError("URL scheme must be set to `sc`.")
        # Rewrite the URL to use http as the scheme so that we can leverage
        # Python's built-in parser.
        tmp_url = "http" + url[2:]
        self.url = urllib.parse.urlparse(tmp_url)
        self.params: typing.Dict[str, str] = {}
        if len(self.url.path) > 0 and self.url.path != "/":
            raise AttributeError(
                f"Path component for connection URI must be empty: {self.url.path}"
            )
        self._extract_attributes()

    def _extract_attributes(self) -> None:
        # urlparse exposes everything after the first ';' in the last path
        # segment as `params`; parse it as ';'-separated key=value pairs.
        if len(self.url.params) > 0:
            parts = self.url.params.split(";")
            for p in parts:
                # Split on the first '=' only so that values that themselves
                # contain '=' (e.g. base64-encoded tokens with '=' padding)
                # are preserved intact instead of being rejected.
                kv = p.split("=", 1)
                if len(kv) != 2:
                    raise AttributeError(f"Parameter '{p}' is not a valid parameter key-value pair")
                self.params[kv[0]] = urllib.parse.unquote(kv[1])

        # Split host[:port]; fall back to the default port when none is given.
        netloc = self.url.netloc.split(":")
        if len(netloc) == 1:
            self.host = netloc[0]
            self.port = ChannelBuilder.DEFAULT_PORT
        elif len(netloc) == 2:
            self.host = netloc[0]
            self.port = int(netloc[1])
        else:
            raise AttributeError(
                f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
            )

    def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
        """
        Builds the GRPC specific metadata list to be injected into the request. All
        parameters will be converted to metadata except ones that are explicitly used
        by the channel.

        Returns
        -------
        A list of tuples (key, value)
        """
        return [
            (k, self.params[k])
            for k in self.params
            if k
            not in [
                ChannelBuilder.PARAM_TOKEN,
                ChannelBuilder.PARAM_USE_SSL,
                ChannelBuilder.PARAM_USER_ID,
            ]
        ]

    @property
    def secure(self) -> bool:
        """True when the connection string requested TLS via ``use_ssl=true``."""
        value = self.params.get(ChannelBuilder.PARAM_USE_SSL, "")
        return value.lower() == "true"

    @property
    def endpoint(self) -> str:
        """The ``host:port`` target the channel will connect to."""
        return f"{self.host}:{self.port}"

    def get(self, key: str) -> Any:
        """
        Parameters
        ----------
        key : str
            Parameter key name.

        Returns
        -------
        The parameter value if present, raises exception otherwise.
        """
        return self.params[key]

    def to_channel(self) -> "grpc.Channel":
        """
        Applies the parameters of the connection string and creates a new
        GRPC channel according to the configuration.

        Returns
        -------
        GRPC Channel instance.

        Raises
        ------
        AttributeError
            If a token is configured without TLS; the token would otherwise
            be sent in clear text.
        """
        destination = f"{self.host}:{self.port}"
        if not self.secure:
            if self.params.get(ChannelBuilder.PARAM_TOKEN, None) is not None:
                raise AttributeError("Token based authentication cannot be used without TLS")
            return grpc.insecure_channel(destination)
        else:
            # Default SSL Credentials.
            opt_token = self.params.get(ChannelBuilder.PARAM_TOKEN, None)
            # When a token is present, attach it as per-call bearer credentials
            # composed on top of the TLS channel credentials.
            if opt_token is not None:
                ssl_creds = grpc.ssl_channel_credentials()
                composite_creds = grpc.composite_channel_credentials(
                    ssl_creds, grpc.access_token_call_credentials(opt_token)
                )
                return grpc.secure_channel(destination, credentials=composite_creds)
            else:
                return grpc.secure_channel(destination, credentials=grpc.ssl_channel_credentials())


class MetricValue:
def __init__(self, name: str, value: NumericType, type: str):
self._name = name
Expand Down Expand Up @@ -104,11 +235,25 @@ def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
class RemoteSparkSession(object):
"""Conceptually the remote spark session that communicates with the server"""

def __init__(self, user_id: str, host: Optional[str] = None, port: int = 15002):
self._host = "localhost" if host is None else host
self._port = port
def __init__(self, user_id: str, connection_string: str = "sc://localhost"):
"""
Creates a new RemoteSparkSession for the Spark Connect interface.

Parameters
----------
user_id : str
Unique User ID that is used to differentiate multiple users and
isolate their Spark Sessions.
connection_string: str
Connection string that is used to extract the connection parameters and configure
the GRPC connection.
"""

# Parse the connection string.
self._builder = ChannelBuilder(connection_string)
self._user_id = user_id
self._channel = grpc.insecure_channel(f"{self._host}:{self._port}")

self._channel = self._builder.to_channel()
self._stub = grpc_lib.SparkConnectServiceStub(self._channel)

# Create the reader
Expand Down Expand Up @@ -226,10 +371,12 @@ def _analyze(self, plan: pb2.Plan) -> AnalyzeResult:
req.user_context.user_id = self._user_id
req.plan.CopyFrom(plan)

resp = self._stub.AnalyzePlan(req)
resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
return AnalyzeResult.fromProto(resp)

def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]:
import pandas as pd

if b.batch is not None and len(b.batch.data) > 0:
with pa.ipc.open_stream(b.batch.data) as rd:
return rd.read_pandas()
Expand All @@ -238,10 +385,12 @@ def _process_batch(self, b: pb2.Response) -> Optional[pandas.DataFrame]:
return None

def _execute_and_fetch(self, req: pb2.Request) -> typing.Optional[pandas.DataFrame]:
import pandas as pd

m: Optional[pb2.Response.Metrics] = None
result_dfs = []

for b in self._stub.ExecutePlan(req):
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
if b.metrics is not None:
m = b.metrics

Expand Down
43 changes: 42 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import unittest
import shutil
import tempfile

import grpc # type: ignore

from pyspark.testing.sqlutils import have_pandas

if have_pandas:
Expand All @@ -27,7 +30,7 @@
from pyspark.sql.types import StructType, StructField, LongType, StringType

if have_pandas:
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder
from pyspark.sql.connect.function_builder import udf
from pyspark.sql.connect.functions import lit
from pyspark.sql.dataframe import DataFrame
Expand Down Expand Up @@ -67,6 +70,7 @@ def setUpClass(cls: Any):
@classmethod
def tearDownClass(cls: Any) -> None:
    # Remove the test tables/files created in setUpClass before tearing
    # down the shared Spark context.
    cls.spark_connect_clean_up_test_data()
    # Chain to the base class so the shared SparkContext is stopped too.
    ReusedPySparkTestCase.tearDownClass()

@classmethod
def spark_connect_load_test_data(cls: Any):
Expand Down Expand Up @@ -167,6 +171,43 @@ def test_simple_datasource_read(self) -> None:
self.assertEqual(len(expectResult), len(actualResult))


class ChannelBuilderTests(ReusedPySparkTestCase):
Copy link
Member

@dongjoon-hyun dongjoon-hyun Dec 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be skipped by should_test_connect like SparkConnectSQLTestCase in this file.

@unittest.skipIf(not should_test_connect, connect_requirement_message)
class SparkConnectSQLTestCase(ReusedPySparkTestCase):

I made a PR for that.

def test_invalid_connection_strings(self):
    """Connection strings violating the sc:// URI contract must be rejected."""
    # Each entry is malformed in a different way: wrong scheme, missing
    # scheme separator, non-empty path component, or parameters that are
    # not key=value pairs.
    for url in (
        "scc://host:12",
        "http://host",
        "sc:/host:1234/path",
        "sc://host/path",
        "sc://host/;parm1;param2",
    ):
        with self.assertRaises(AttributeError):
            ChannelBuilder(url)

    # A token without TLS is rejected, but only when the channel is built.
    with self.assertRaises(AttributeError):
        ChannelBuilder("sc://host/;token=123").to_channel()

def test_valid_channel_creation(self):
    """Well-formed connection strings must yield a grpc.Channel instance."""
    # Plain channel: no TLS, no token.
    self.assertIsInstance(ChannelBuilder("sc://host").to_channel(), grpc.Channel)

    # TLS with a bearer token attached as call credentials.
    secure_with_token = ChannelBuilder("sc://host/;use_ssl=true;token=abc").to_channel()
    self.assertIsInstance(secure_with_token, grpc.Channel)

    # TLS only, default SSL credentials.
    secure_only = ChannelBuilder("sc://host/;use_ssl=true").to_channel()
    self.assertIsInstance(secure_only, grpc.Channel)

def test_channel_properties(self):
    """Host, port, TLS flag and decoded parameters are exposed as properties."""
    builder = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021")
    # The port falls back to the default 15002 when the URL omits one.
    self.assertEqual("host:15002", builder.endpoint)
    self.assertEqual(True, builder.secure)
    # Parameter values are percent-decoded ('%20' -> space).
    self.assertEqual("120 21", builder.get("param1"))

def test_metadata(self):
    """Only non-reserved parameters are forwarded as GRPC metadata."""
    builder = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021;x-my-header=abcd")
    # use_ssl and token configure the channel itself and are filtered out.
    self.assertEqual([("param1", "120 21"), ("x-my-header", "abcd")], builder.metadata())


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401

Expand Down