41 changes: 37 additions & 4 deletions python/pyspark/sql/connect/_typing.py
@@ -14,8 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union
from datetime import date, time, datetime

PrimitiveType = Union[str, int, bool, float]
LiteralType = Union[PrimitiveType, Union[date, time, datetime]]
import sys

if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol

from typing import Union, Optional
import datetime
import decimal

from pyspark.sql.connect.column import ScalarFunctionExpression, Expression, Column
from pyspark.sql.connect.function_builder import UserDefinedFunction

ExpressionOrString = Union[Expression, str]

ColumnOrName = Union[Column, str]

PrimitiveType = Union[bool, float, int, str]

OptionalPrimitiveType = Optional[PrimitiveType]

LiteralType = PrimitiveType

DecimalLiteral = decimal.Decimal

DateTimeLiteral = Union[datetime.datetime, datetime.date]


class FunctionBuilderCallable(Protocol):
def __call__(self, *_: ExpressionOrString) -> ScalarFunctionExpression:
...


class UserDefinedFunctionCallable(Protocol):
def __call__(self, *_: ColumnOrName) -> UserDefinedFunction:
...
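
The two Protocol classes added here type callables structurally: anything whose call signature matches `__call__` is accepted, with no inheritance required. A minimal sketch under that assumption, reusing the `_build` helper visible in the `function_builder.py` diff below (the partial application is illustrative, not taken from this PR):

```python
import functools

from pyspark.sql.connect._typing import FunctionBuilderCallable
from pyspark.sql.connect.function_builder import _build

# functools.partial(_build, "sum") calls through as
# (*args: ExpressionOrString) -> ScalarFunctionExpression, so it can be
# assigned to the Protocol type without subclassing it.
sum_builder: FunctionBuilderCallable = functools.partial(_build, "sum")
```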
25 changes: 10 additions & 15 deletions python/pyspark/sql/connect/client.py
@@ -18,7 +18,6 @@

import logging
import os
import typing
import urllib.parse
import uuid

@@ -35,9 +34,7 @@
from pyspark.sql.connect.plan import SQL, Range
from pyspark.sql.types import DataType, StructType, StructField, LongType, StringType

from typing import Optional, Any, Union

NumericType = typing.Union[int, float]
from typing import Iterable, Optional, Any, Union, List, Tuple, Dict

logging.basicConfig(level=logging.INFO)

@@ -74,7 +71,7 @@ def __init__(self, url: str) -> None:
# Python's built-in parser.
tmp_url = "http" + url[2:]
self.url = urllib.parse.urlparse(tmp_url)
self.params: typing.Dict[str, str] = {}
self.params: Dict[str, str] = {}
if len(self.url.path) > 0 and self.url.path != "/":
raise AttributeError(
f"Path component for connection URI must be empty: {self.url.path}"
@@ -102,7 +99,7 @@ def _extract_attributes(self) -> None:
f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
)

def metadata(self) -> typing.Iterable[typing.Tuple[str, str]]:
def metadata(self) -> Iterable[Tuple[str, str]]:
"""
Builds the GRPC specific metadata list to be injected into the request. All
parameters will be converted to metadata except ones that are explicitly used
@@ -198,7 +195,7 @@ def toChannel(self) -> grpc.Channel:


class MetricValue:
def __init__(self, name: str, value: NumericType, type: str):
def __init__(self, name: str, value: Union[int, float], type: str):
self._name = name
self._type = type
self._value = value
@@ -211,7 +208,7 @@ def name(self) -> str:
return self._name

@property
def value(self) -> NumericType:
def value(self) -> Union[int, float]:
return self._value

@property
@@ -220,7 +217,7 @@ def metric_type(self) -> str:


class PlanMetrics:
def __init__(self, name: str, id: int, parent: int, metrics: typing.List[MetricValue]):
def __init__(self, name: str, id: int, parent: int, metrics: List[MetricValue]):
self._name = name
self._id = id
self._parent_id = parent
@@ -242,7 +239,7 @@ def parent_plan_id(self) -> int:
return self._parent_id

@property
def metrics(self) -> typing.List[MetricValue]:
def metrics(self) -> List[MetricValue]:
return self._metrics


@@ -252,7 +249,7 @@ def __init__(self, schema: pb2.DataType, explain: str):
self.explain_string = explain

@classmethod
def fromProto(cls, pb: typing.Any) -> "AnalyzeResult":
def fromProto(cls, pb: Any) -> "AnalyzeResult":
return AnalyzeResult(pb.schema, pb.explain_string)


@@ -306,9 +303,7 @@ def register_udf(
self._execute_and_fetch(req)
return name

def _build_metrics(
self, metrics: "pb2.ExecutePlanResponse.Metrics"
) -> typing.List[PlanMetrics]:
def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
return [
PlanMetrics(
x.name,
@@ -450,7 +445,7 @@ def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
return rd.read_pandas()
return None

def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> Optional[pandas.DataFrame]:
import pandas as pd

m: Optional[pb2.ExecutePlanResponse.Metrics] = None
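
With the module-level `NumericType` alias gone, the metric classes carry `Union[int, float]` inline. A short usage sketch of `MetricValue` and `PlanMetrics` (the values here are made up purely for illustration; `_build_metrics` normally constructs these from the `ExecutePlanResponse` proto):

```python
from typing import List

from pyspark.sql.connect.client import MetricValue, PlanMetrics

# Hypothetical values, only to show the shapes involved.
values: List[MetricValue] = [
    MetricValue("numOutputRows", 42, "sum"),
    MetricValue("peakMemoryBytes", 1.5e6, "max"),
]
node = PlanMetrics("Range", 1, 0, values)

assert node.parent_plan_id == 0 and node.metrics[0].value == 42
```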
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/column.py
@@ -22,7 +22,6 @@
import datetime

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect._typing import PrimitiveType

if TYPE_CHECKING:
from pyspark.sql.connect.client import RemoteSparkSession
@@ -33,6 +32,8 @@ def _bin_op(
name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["Column", Any], "Expression"]:
def _(self: "Column", other: Any) -> "Expression":
from pyspark.sql.connect._typing import PrimitiveType

if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
if not reverse:
@@ -70,6 +71,8 @@ def __eq__(self, other: Any) -> "Expression":  # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
from pyspark.sql.connect._typing import PrimitiveType

if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
return ScalarFunctionExpression("==", self, other)
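
Moving the `PrimitiveType` import into the function bodies is what breaks the cycle: `_typing` now imports `Column` and friends from this module, so importing `_typing` at the top of `column.py` would fail at module load. A standalone sketch of the `get_args`-with-`isinstance` idiom used above (the alias is redeclared locally just for the example):

```python
from typing import Union, get_args

PrimitiveType = Union[bool, float, int, str]

# get_args() unpacks the Union into a tuple of classes, which is exactly
# the form isinstance() accepts.
print(get_args(PrimitiveType))                    # (bool, float, int, str)
print(isinstance(3.14, get_args(PrimitiveType)))  # True
print(isinstance(None, get_args(PrimitiveType)))  # False
```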
12 changes: 5 additions & 7 deletions python/pyspark/sql/connect/dataframe.py
@@ -44,11 +44,9 @@
)

if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString, LiteralType
from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString, LiteralType
from pyspark.sql.connect.client import RemoteSparkSession

ColumnOrName = Union[Column, str]


class GroupingFrame(object):

@@ -308,7 +306,7 @@ def distinct(self) -> "DataFrame":
plan.Deduplicate(child=self._plan, all_columns_as_keys=True), session=self._session
)

def drop(self, *cols: "ColumnOrString") -> "DataFrame":
def drop(self, *cols: "ColumnOrName") -> "DataFrame":
_cols = list(cols)
if any(not isinstance(c, (str, Column)) for c in _cols):
raise TypeError(
@@ -342,7 +340,7 @@ def first(self) -> Optional[Row]:
"""
return self.head()

def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
def groupBy(self, *cols: "ColumnOrName") -> GroupingFrame:
return GroupingFrame(self, *cols)

@overload
@@ -414,13 +412,13 @@ def limit(self, n: int) -> "DataFrame":
def offset(self, n: int) -> "DataFrame":
return DataFrame.withPlan(plan.Offset(child=self._plan, offset=n), session=self._session)

def sort(self, *cols: "ColumnOrString") -> "DataFrame":
def sort(self, *cols: "ColumnOrName") -> "DataFrame":
"""Sort by a specific column"""
return DataFrame.withPlan(
plan.Sort(self._plan, columns=list(cols), is_global=True), session=self._session
)

def sortWithinPartitions(self, *cols: "ColumnOrString") -> "DataFrame":
def sortWithinPartitions(self, *cols: "ColumnOrName") -> "DataFrame":
"""Sort within each partition by a specific column"""
return DataFrame.withPlan(
plan.Sort(self._plan, columns=list(cols), is_global=False), session=self._session
10 changes: 7 additions & 3 deletions python/pyspark/sql/connect/function_builder.py
@@ -28,9 +28,13 @@


if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
from pyspark.sql.connect._typing import (
ColumnOrName,
ExpressionOrString,
FunctionBuilderCallable,
UserDefinedFunctionCallable,
)
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.connect.typing import FunctionBuilderCallable, UserDefinedFunctionCallable


def _build(name: str, *args: "ExpressionOrString") -> ScalarFunctionExpression:
@@ -103,7 +107,7 @@ def __str__(self) -> str:
def _create_udf(
function: Any, return_type: Union[str, pyspark.sql.types.DataType]
) -> "UserDefinedFunctionCallable":
def wrapper(*cols: "ColumnOrString") -> UserDefinedFunction:
def wrapper(*cols: "ColumnOrName") -> UserDefinedFunction:
return UserDefinedFunction(func=function, return_type=return_type, args=cols)

return wrapper
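
`_create_udf` now returns a `UserDefinedFunctionCallable`, i.e. a closure that maps column names or `Column` objects to a `UserDefinedFunction` bound to the wrapped Python function. A hedged sketch of calling the private helper directly (the lambda, the return type, and the column name are illustrative only):

```python
from pyspark.sql.connect.function_builder import _create_udf

# Wrap a plain Python function; the returned callable is typed as
# UserDefinedFunctionCallable.
add_one = _create_udf(lambda x: x + 1, return_type="long")

# Applying it to a column name yields a UserDefinedFunction expression
# over that column; nothing executes until the plan runs.
expr = add_one("id")
```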
4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/plan.py
@@ -36,7 +36,7 @@


if TYPE_CHECKING:
from pyspark.sql.connect.typing import ColumnOrString, ExpressionOrString
from pyspark.sql.connect._typing import ColumnOrName, ExpressionOrString
from pyspark.sql.connect.client import RemoteSparkSession


@@ -58,7 +58,7 @@ def unresolved_attr(self, colName: str) -> proto.Expression:
return exp

def to_attr_or_expression(
self, col: "ColumnOrString", session: "RemoteSparkSession"
self, col: "ColumnOrName", session: "RemoteSparkSession"
) -> proto.Expression:
"""Returns either an instance of an unresolved attribute or the serialized
expression value of the column."""
5 changes: 1 addition & 4 deletions python/pyspark/sql/connect/readwriter.py
@@ -18,17 +18,14 @@

from typing import Dict, Optional

from pyspark.sql.connect.column import PrimitiveType
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, DataSource
from pyspark.sql.utils import to_str


OptionalPrimitiveType = Optional[PrimitiveType]

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.client import RemoteSparkSession


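
`OptionalPrimitiveType` is now imported only under `TYPE_CHECKING`, so the alias exists for mypy but never at runtime, which keeps `readwriter.py` from importing `_typing` in a way that could circle back. A generic sketch of the pattern (the `_normalize` helper is hypothetical, not part of this PR):

```python
from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from pyspark.sql.connect._typing import OptionalPrimitiveType


def _normalize(options: Dict[str, "OptionalPrimitiveType"]) -> Dict[str, str]:
    # Hypothetical helper: drop unset options and stringify the rest,
    # roughly how reader/writer options end up as strings.
    return {key: str(value) for key, value in options.items() if value is not None}
```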
35 changes: 0 additions & 35 deletions python/pyspark/sql/connect/typing/__init__.pyi

This file was deleted.