@@ -312,14 +312,15 @@ message Expression {
 message CommonInlineUserDefinedFunction {
   // (Required) Name of the user-defined function.
   string function_name = 1;
-  // (Required) Indicate if the user-defined function is deterministic.
+  // (Optional) Indicate if the user-defined function is deterministic.
   bool deterministic = 2;
@xinrong-meng (Member, Author) commented on Mar 6, 2023:
JavaUDF has no deterministic field, but the server has no logic that relies on that field, so only the comment is changed to minimize changes to the proto.

   // (Optional) Function arguments. Empty arguments are allowed.
   repeated Expression arguments = 3;
   // (Required) Indicate the function type of the user-defined function.
   oneof function {
     PythonUDF python_udf = 4;
     ScalarScalaUDF scalar_scala_udf = 5;
+    JavaUDF java_udf = 6;
   }
 }

@@ -345,3 +346,13 @@ message ScalarScalaUDF {
   bool nullable = 4;
 }
 
+message JavaUDF {
+  // (Required) Fully qualified name of Java class
+  string class_name = 1;
+
+  // (Optional) Output type of the Java UDF
+  optional string output_type = 2;
+
+  // (Required) Indicate if the Java user-defined function is an aggregate function
+  bool aggregate = 3;
+}
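For orientation, a hedged sketch of how a client could populate the new message directly through the generated Python bindings (org.example.StrLen is a hypothetical class; setting a field under the function oneof selects that arm automatically):

from pyspark.sql.connect.proto import expressions_pb2 as pb2

fun = pb2.CommonInlineUserDefinedFunction()
fun.function_name = "strlen"
fun.java_udf.class_name = "org.example.StrLen"  # hypothetical Java UDF class
fun.java_udf.output_type = '"integer"'          # JSON form of IntegerType, as DataType.json() produces
fun.java_udf.aggregate = False
assert fun.WhichOneof("function") == "java_udf"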
@@ -1550,6 +1550,8 @@ class SparkConnectPlanner(val session: SparkSession) {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         handleRegisterPythonUDF(fun)
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.JAVA_UDF =>
+        handleRegisterJavaUDF(fun)
       case _ =>
         throw InvalidPlanInput(
           s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
@@ -1575,6 +1577,25 @@
     session.udf.registerPython(fun.getFunctionName, udpf)
   }
 
+  private def handleRegisterJavaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    val udf = fun.getJavaUdf
+    val dataType =
+      if (udf.hasOutputType) {
+        DataType.parseTypeWithFallback(
+          schema = udf.getOutputType,
+          parser = DataType.fromDDL,
+          fallbackParser = DataType.fromJson) match {
+          case s: DataType => s
+          case other => throw InvalidPlanInput(s"Invalid return type $other")
+        }
+      } else null
+    if (udf.getAggregate) {
+      session.udf.registerJavaUDAF(fun.getFunctionName, udf.getClassName)
+    } else {
+      session.udf.registerJava(fun.getFunctionName, udf.getClassName, dataType)
+    }
+  }
+
   private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
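The handler above parses the output type with a DDL-first, JSON-fallback strategy, so both spellings below would be accepted by the server (a hedged illustration; the JSON form is what the client code later in this diff actually sends):

from pyspark.sql.types import IntegerType

ddl_form = "integer"              # parsed by DataType.fromDDL
json_form = IntegerType().json()  # '"integer"', parsed by the DataType.fromJson fallback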
39 changes: 37 additions & 2 deletions python/pyspark/sql/connect/client.py
@@ -47,6 +47,7 @@
     Callable,
     Generator,
     Type,
+    TYPE_CHECKING,
 )
 
 import pandas as pd
@@ -69,6 +70,7 @@
 from pyspark.sql.connect.expressions import (
     PythonUDF,
     CommonInlineUserDefinedFunction,
+    JavaUDF,
 )
 from pyspark.sql.connect.types import parse_data_type
 from pyspark.sql.types import (
@@ -80,6 +82,10 @@
 from pyspark.rdd import PythonEvalType
 
 
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import DataTypeOrString
+
+
 def _configure_logging() -> logging.Logger:
     """Configure logging for the Spark Connect clients."""
     logger = logging.getLogger(__name__)
@@ -534,7 +540,7 @@ def __init__(
     def register_udf(
         self,
         function: Any,
-        return_type: Union[str, DataType],
+        return_type: "DataTypeOrString",
         name: Optional[str] = None,
         eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
         deterministic: bool = True,
@@ -561,9 +567,9 @@
         # construct a CommonInlineUserDefinedFunction
         fun = CommonInlineUserDefinedFunction(
             function_name=name,
-            deterministic=deterministic,
             arguments=[],
             function=py_udf,
+            deterministic=deterministic,
         ).to_plan_udf(self)
 
         # construct the request
@@ -573,6 +579,35 @@
         self._execute(req)
         return name
 
+    def register_java(
+        self,
+        name: str,
+        javaClassName: str,
+        return_type: Optional["DataTypeOrString"] = None,
+        aggregate: bool = False,
+    ) -> None:
+        # convert str return_type to DataType
+        if isinstance(return_type, str):
+            return_type = parse_data_type(return_type)
+
+        # construct a JavaUDF
+        if return_type is None:
+            java_udf = JavaUDF(class_name=javaClassName, aggregate=aggregate)
+        else:
+            java_udf = JavaUDF(
+                class_name=javaClassName,
+                output_type=return_type.json(),
+            )
+        fun = CommonInlineUserDefinedFunction(
+            function_name=name,
+            function=java_udf,
+        ).to_plan_judf(self)
+        # construct the request
+        req = self._execute_plan_request_with_metadata()
+        req.plan.command.register_function.CopyFrom(fun)
+
+        self._execute(req)
+
     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
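A hedged usage sketch of the new client method (register_java is internal and is normally reached through the UDFRegistration wrappers added in udf.py below; org.example.StrLen and org.example.MyAvg are hypothetical classes assumed to be on the server classpath):

# client is a SparkConnectClient instance
client.register_java("strlen", "org.example.StrLen")                 # server infers the return type
client.register_java("strlen2", "org.example.StrLen", "integer")     # explicit return type, parsed client-side
client.register_java("my_avg", "org.example.MyAvg", aggregate=True)  # registers a Java UDAF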
44 changes: 39 additions & 5 deletions python/pyspark/sql/connect/expressions.py
@@ -19,6 +19,7 @@
 check_dependencies(__name__, __file__)
 
 from typing import (
+    cast,
     TYPE_CHECKING,
     Any,
     Union,
@@ -520,16 +521,41 @@ def __repr__(self) -> str:
         )
 
 
+class JavaUDF:
+    """Represents a Java (aggregate) user-defined function."""
+
+    def __init__(
+        self,
+        class_name: str,
+        output_type: Optional[str] = None,
+        aggregate: bool = False,
+    ) -> None:
+        self._class_name = class_name
+        self._output_type = output_type
+        self._aggregate = aggregate
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
+        expr = proto.JavaUDF()
+        expr.class_name = self._class_name
+        if self._output_type is not None:
+            expr.output_type = self._output_type
+        expr.aggregate = self._aggregate
+        return expr
+
+    def __repr__(self) -> str:
+        return f"{self._class_name}, {self._output_type}"
+
+
 class CommonInlineUserDefinedFunction(Expression):
     """Represents a user-defined function with an inlined defined function body of any programming
     languages."""
 
     def __init__(
         self,
         function_name: str,
-        deterministic: bool,
-        arguments: Sequence[Expression],
-        function: PythonUDF,
+        function: Union[PythonUDF, JavaUDF],
+        deterministic: bool = False,
+        arguments: Sequence[Expression] = [],
     ):
         self._function_name = function_name
         self._deterministic = deterministic
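A hedged sketch of the two new pieces in isolation (JavaUDF.to_plan only reads the instance fields, so the session argument is unused at runtime and None stands in for a real client here):

from pyspark.sql.connect.expressions import JavaUDF, CommonInlineUserDefinedFunction

java_udf = JavaUDF(class_name="org.example.StrLen", output_type='"integer"')  # hypothetical class
fun = CommonInlineUserDefinedFunction(function_name="strlen", function=java_udf)
plan = fun.to_plan_judf(None)  # proto.CommonInlineUserDefinedFunction with the java_udf oneof set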
@@ -545,7 +571,7 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
             [arg.to_plan(session) for arg in self._arguments]
         )
         expr.common_inline_user_defined_function.python_udf.CopyFrom(
-            self._function.to_plan(session)
+            cast(proto.PythonUDF, self._function.to_plan(session))
         )
         return expr
 
@@ -557,7 +583,15 @@ def to_plan_udf(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
         expr.deterministic = self._deterministic
         if len(self._arguments) > 0:
             expr.arguments.extend([arg.to_plan(session) for arg in self._arguments])
-        expr.python_udf.CopyFrom(self._function.to_plan(session))
+        expr.python_udf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session)))
         return expr
 
+    def to_plan_judf(
+        self, session: "SparkConnectClient"
+    ) -> "proto.CommonInlineUserDefinedFunction":
+        expr = proto.CommonInlineUserDefinedFunction()
+        expr.function_name = self._function_name
+        expr.java_udf.CopyFrom(cast(proto.JavaUDF, self._function.to_plan(session)))
+        return expr
 
     def __repr__(self) -> str:
26 changes: 20 additions & 6 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

56 changes: 54 additions & 2 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1171,10 +1171,11 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     ARGUMENTS_FIELD_NUMBER: builtins.int
     PYTHON_UDF_FIELD_NUMBER: builtins.int
     SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
+    JAVA_UDF_FIELD_NUMBER: builtins.int
     function_name: builtins.str
     """(Required) Name of the user-defined function."""
     deterministic: builtins.bool
-    """(Required) Indicate if the user-defined function is deterministic."""
+    """(Optional) Indicate if the user-defined function is deterministic."""
     @property
     def arguments(
         self,
@@ -1184,6 +1185,8 @@
     def python_udf(self) -> global___PythonUDF: ...
     @property
     def scalar_scala_udf(self) -> global___ScalarScalaUDF: ...
+    @property
+    def java_udf(self) -> global___JavaUDF: ...
     def __init__(
         self,
         *,
@@ -1192,12 +1195,15 @@
         arguments: collections.abc.Iterable[global___Expression] | None = ...,
         python_udf: global___PythonUDF | None = ...,
         scalar_scala_udf: global___ScalarScalaUDF | None = ...,
+        java_udf: global___JavaUDF | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
             "function",
             b"function",
+            "java_udf",
+            b"java_udf",
             "python_udf",
             b"python_udf",
             "scalar_scala_udf",
@@ -1215,6 +1221,8 @@
             b"function",
             "function_name",
             b"function_name",
+            "java_udf",
+            b"java_udf",
             "python_udf",
             b"python_udf",
             "scalar_scala_udf",
@@ -1223,7 +1231,7 @@
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["function", b"function"]
-    ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ...
+    ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf", "java_udf"] | None: ...
 
 global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction
 
@@ -1314,3 +1322,47 @@
     ) -> None: ...
 
 global___ScalarScalaUDF = ScalarScalaUDF
+
+class JavaUDF(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    CLASS_NAME_FIELD_NUMBER: builtins.int
+    OUTPUT_TYPE_FIELD_NUMBER: builtins.int
+    AGGREGATE_FIELD_NUMBER: builtins.int
+    class_name: builtins.str
+    """(Required) Fully qualified name of Java class"""
+    output_type: builtins.str
+    """(Optional) Output type of the Java UDF"""
+    aggregate: builtins.bool
+    """(Required) Indicate if the Java user-defined function is an aggregate function"""
+    def __init__(
+        self,
+        *,
+        class_name: builtins.str = ...,
+        output_type: builtins.str | None = ...,
+        aggregate: builtins.bool = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_output_type", b"_output_type", "output_type", b"output_type"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_output_type",
+            b"_output_type",
+            "aggregate",
+            b"aggregate",
+            "class_name",
+            b"class_name",
+            "output_type",
+            b"output_type",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_output_type", b"_output_type"]
+    ) -> typing_extensions.Literal["output_type"] | None: ...
+
+global___JavaUDF = JavaUDF
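The proto3 `optional` on output_type gives that field explicit presence tracking, which is what the synthetic _output_type oneof in this stub reflects. A hedged illustration:

from pyspark.sql.connect.proto import expressions_pb2 as pb2

udf = pb2.JavaUDF(class_name="org.example.StrLen")  # hypothetical class
assert not udf.HasField("output_type")  # unset: the server will infer the type
udf.output_type = '"integer"'
assert udf.HasField("output_type")      # set: explicit presence via proto3 `optional`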
17 changes: 16 additions & 1 deletion python/pyspark/sql/connect/udf.py
@@ -127,9 +127,9 @@ def _build_common_inline_user_defined_function(
         )
         return CommonInlineUserDefinedFunction(
             function_name=self._name,
+            function=py_udf,
             deterministic=self.deterministic,
             arguments=arg_exprs,
-            function=py_udf,
         )
 
     def __call__(self, *cols: "ColumnOrName") -> Column:
@@ -232,3 +232,18 @@ def register(
         return return_udf
 
     register.__doc__ = PySparkUDFRegistration.register.__doc__
+
+    def registerJavaFunction(
+        self,
+        name: str,
+        javaClassName: str,
+        returnType: Optional["DataTypeOrString"] = None,
+    ) -> None:
+        self.sparkSession._client.register_java(name, javaClassName, returnType)
+
+    registerJavaFunction.__doc__ = PySparkUDFRegistration.registerJavaFunction.__doc__
+
+    def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
+        self.sparkSession._client.register_java(name, javaClassName, aggregate=True)
+
+    registerJavaUDAF.__doc__ = PySparkUDFRegistration.registerJavaUDAF.__doc__
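End to end, the new API mirrors the classic PySpark UDFRegistration (a hedged sketch; org.example.StrLen and org.example.MyAvg are hypothetical UDF/UDAF classes that must be on the server's classpath, and spark is a Spark Connect SparkSession):

from pyspark.sql.types import IntegerType

spark.udf.registerJavaFunction("strlen", "org.example.StrLen", IntegerType())
spark.sql("SELECT strlen('connect')").show()

spark.udf.registerJavaUDAF("my_avg", "org.example.MyAvg")
spark.sql("SELECT my_avg(id) FROM range(10)").show()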