diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 6eb769ad27e0..0aee3ca13b9e 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -312,7 +312,7 @@ message Expression { message CommonInlineUserDefinedFunction { // (Required) Name of the user-defined function. string function_name = 1; - // (Required) Indicate if the user-defined function is deterministic. + // (Optional) Indicate if the user-defined function is deterministic. bool deterministic = 2; // (Optional) Function arguments. Empty arguments are allowed. repeated Expression arguments = 3; @@ -320,6 +320,7 @@ message CommonInlineUserDefinedFunction { oneof function { PythonUDF python_udf = 4; ScalarScalaUDF scalar_scala_udf = 5; + JavaUDF java_udf = 6; } } @@ -345,3 +346,13 @@ message ScalarScalaUDF { bool nullable = 4; } +message JavaUDF { + // (Required) Fully qualified name of Java class + string class_name = 1; + + // (Optional) Output type of the Java UDF + optional string output_type = 2; + + // (Required) Indicate if the Java user-defined function is an aggregate function + bool aggregate = 3; +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8ca004d520c5..3d921f911cdd 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1550,6 +1550,8 @@ class SparkConnectPlanner(val session: SparkSession) { fun.getFunctionCase match { case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF => handleRegisterPythonUDF(fun) + case proto.CommonInlineUserDefinedFunction.FunctionCase.JAVA_UDF => + handleRegisterJavaUDF(fun) case _ => throw InvalidPlanInput( s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported") @@ -1575,6 +1577,25 @@ class SparkConnectPlanner(val session: SparkSession) { session.udf.registerPython(fun.getFunctionName, udpf) } + private def handleRegisterJavaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = { + val udf = fun.getJavaUdf + val dataType = + if (udf.hasOutputType) { + DataType.parseTypeWithFallback( + schema = udf.getOutputType, + parser = DataType.fromDDL, + fallbackParser = DataType.fromJson) match { + case s: DataType => s + case other => throw InvalidPlanInput(s"Invalid return type $other") + } + } else null + if (udf.getAggregate) { + session.udf.registerJavaUDAF(fun.getFunctionName, udf.getClassName) + } else { + session.udf.registerJava(fun.getFunctionName, udf.getClassName, dataType) + } + } + private def handleCommandPlugin(extension: ProtoAny): Unit = { SparkConnectPluginRegistry.commandRegistry // Lazily traverse the collection. diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index 8c85f17bb5fd..6334036fca4a 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -47,6 +47,7 @@ Callable, Generator, Type, + TYPE_CHECKING, ) import pandas as pd @@ -69,6 +70,7 @@ from pyspark.sql.connect.expressions import ( PythonUDF, CommonInlineUserDefinedFunction, + JavaUDF, ) from pyspark.sql.connect.types import parse_data_type from pyspark.sql.types import ( @@ -80,6 +82,10 @@ from pyspark.rdd import PythonEvalType +if TYPE_CHECKING: + from pyspark.sql.connect._typing import DataTypeOrString + + def _configure_logging() -> logging.Logger: """Configure logging for the Spark Connect clients.""" logger = logging.getLogger(__name__) @@ -534,7 +540,7 @@ def __init__( def register_udf( self, function: Any, - return_type: Union[str, DataType], + return_type: "DataTypeOrString", name: Optional[str] = None, eval_type: int = PythonEvalType.SQL_BATCHED_UDF, deterministic: bool = True, @@ -561,9 +567,9 @@ def register_udf( # construct a CommonInlineUserDefinedFunction fun = CommonInlineUserDefinedFunction( function_name=name, - deterministic=deterministic, arguments=[], function=py_udf, + deterministic=deterministic, ).to_plan_udf(self) # construct the request @@ -573,6 +579,35 @@ def register_udf( self._execute(req) return name + def register_java( + self, + name: str, + javaClassName: str, + return_type: Optional["DataTypeOrString"] = None, + aggregate: bool = False, + ) -> None: + # convert str return_type to DataType + if isinstance(return_type, str): + return_type = parse_data_type(return_type) + + # construct a JavaUDF + if return_type is None: + java_udf = JavaUDF(class_name=javaClassName, aggregate=aggregate) + else: + java_udf = JavaUDF( + class_name=javaClassName, + output_type=return_type.json(), + ) + fun = CommonInlineUserDefinedFunction( + function_name=name, + function=java_udf, + ).to_plan_judf(self) + # construct the request + req = self._execute_plan_request_with_metadata() + req.plan.command.register_function.CopyFrom(fun) + + self._execute(req) + def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]: return [ PlanMetrics( diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 2b1901167c14..0d059740032c 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -19,6 +19,7 @@ check_dependencies(__name__, __file__) from typing import ( + cast, TYPE_CHECKING, Any, Union, @@ -520,6 +521,31 @@ def __repr__(self) -> str: ) +class JavaUDF: + """Represents a Java (aggregate) user-defined function.""" + + def __init__( + self, + class_name: str, + output_type: Optional[str] = None, + aggregate: bool = False, + ) -> None: + self._class_name = class_name + self._output_type = output_type + self._aggregate = aggregate + + def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF: + expr = proto.JavaUDF() + expr.class_name = self._class_name + if self._output_type is not None: + expr.output_type = self._output_type + expr.aggregate = self._aggregate + return expr + + def __repr__(self) -> str: + return f"{self._class_name}, {self._output_type}" + + class CommonInlineUserDefinedFunction(Expression): """Represents a user-defined function with an inlined defined function body of any programming languages.""" @@ -527,9 +553,9 @@ class CommonInlineUserDefinedFunction(Expression): def __init__( self, function_name: str, - deterministic: bool, - arguments: Sequence[Expression], - function: PythonUDF, + function: Union[PythonUDF, JavaUDF], + deterministic: bool = False, + arguments: Sequence[Expression] = [], ): self._function_name = function_name self._deterministic = deterministic @@ -545,7 +571,7 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": [arg.to_plan(session) for arg in self._arguments] ) expr.common_inline_user_defined_function.python_udf.CopyFrom( - self._function.to_plan(session) + cast(proto.PythonUDF, self._function.to_plan(session)) ) return expr @@ -557,7 +583,15 @@ def to_plan_udf(self, session: "SparkConnectClient") -> "proto.CommonInlineUserD expr.deterministic = self._deterministic if len(self._arguments) > 0: expr.arguments.extend([arg.to_plan(session) for arg in self._arguments]) - expr.python_udf.CopyFrom(self._function.to_plan(session)) + expr.python_udf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session))) + return expr + + def to_plan_judf( + self, session: "SparkConnectClient" + ) -> "proto.CommonInlineUserDefinedFunction": + expr = proto.CommonInlineUserDefinedFunction() + expr.function_name = self._function_name + expr.java_udf.CopyFrom(cast(proto.JavaUDF, self._function.to_plan(session))) return expr def __repr__(self) -> str: diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index d0db2ad56cc5..24dd11364801 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -34,7 +34,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xae\x08\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x7f\n\x05\x41rray\x12\x39\n\x0b\x65lementType\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12;\n\x07\x65lement\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x07\x65lementB\x0e\n\x0cliteral_type\x1ap\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1aR\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x42\x12\n\x10_unparsed_target\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"\xb7\x02\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdfB\n\n\x08\x66unction"\x82\x01\n\tPythonUDF\x12\x1f\n\x0boutput_type\x18\x01 \x01(\tR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer"\xb8\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullableB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\xae\x08\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 \x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x7f\n\x05\x41rray\x12\x39\n\x0b\x65lementType\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12;\n\x07\x65lement\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x07\x65lementB\x0e\n\x0cliteral_type\x1ap\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1aR\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x42\x12\n\x10_unparsed_target\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"\xec\x02\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 \x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdfB\n\n\x08\x66unction"\x82\x01\n\tPythonUDF\x12\x1f\n\x0boutput_type\x18\x01 \x01(\tR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer"\xb8\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable"|\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12$\n\x0boutput_type\x18\x02 \x01(\tH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_typeB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -67,6 +67,7 @@ ] _PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"] _SCALARSCALAUDF = DESCRIPTOR.message_types_by_name["ScalarScalaUDF"] +_JAVAUDF = DESCRIPTOR.message_types_by_name["JavaUDF"] _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[ "FrameType" ] @@ -306,6 +307,17 @@ ) _sym_db.RegisterMessage(ScalarScalaUDF) +JavaUDF = _reflection.GeneratedProtocolMessageType( + "JavaUDF", + (_message.Message,), + { + "DESCRIPTOR": _JAVAUDF, + "__module__": "spark.connect.expressions_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.JavaUDF) + }, +) +_sym_db.RegisterMessage(JavaUDF) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -357,9 +369,11 @@ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062 _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124 _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5451 - _PYTHONUDF._serialized_start = 5454 - _PYTHONUDF._serialized_end = 5584 - _SCALARSCALAUDF._serialized_start = 5587 - _SCALARSCALAUDF._serialized_end = 5771 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504 + _PYTHONUDF._serialized_start = 5507 + _PYTHONUDF._serialized_end = 5637 + _SCALARSCALAUDF._serialized_start = 5640 + _SCALARSCALAUDF._serialized_end = 5824 + _JAVAUDF._serialized_start = 5826 + _JAVAUDF._serialized_end = 5950 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 37db24ff91ae..19b47c7ab91c 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -1171,10 +1171,11 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): ARGUMENTS_FIELD_NUMBER: builtins.int PYTHON_UDF_FIELD_NUMBER: builtins.int SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int + JAVA_UDF_FIELD_NUMBER: builtins.int function_name: builtins.str """(Required) Name of the user-defined function.""" deterministic: builtins.bool - """(Required) Indicate if the user-defined function is deterministic.""" + """(Optional) Indicate if the user-defined function is deterministic.""" @property def arguments( self, @@ -1184,6 +1185,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): def python_udf(self) -> global___PythonUDF: ... @property def scalar_scala_udf(self) -> global___ScalarScalaUDF: ... + @property + def java_udf(self) -> global___JavaUDF: ... def __init__( self, *, @@ -1192,12 +1195,15 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): arguments: collections.abc.Iterable[global___Expression] | None = ..., python_udf: global___PythonUDF | None = ..., scalar_scala_udf: global___ScalarScalaUDF | None = ..., + java_udf: global___JavaUDF | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ "function", b"function", + "java_udf", + b"java_udf", "python_udf", b"python_udf", "scalar_scala_udf", @@ -1215,6 +1221,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): b"function", "function_name", b"function_name", + "java_udf", + b"java_udf", "python_udf", b"python_udf", "scalar_scala_udf", @@ -1223,7 +1231,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message): ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["function", b"function"] - ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ... + ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf", "java_udf"] | None: ... global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction @@ -1314,3 +1322,47 @@ class ScalarScalaUDF(google.protobuf.message.Message): ) -> None: ... global___ScalarScalaUDF = ScalarScalaUDF + +class JavaUDF(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLASS_NAME_FIELD_NUMBER: builtins.int + OUTPUT_TYPE_FIELD_NUMBER: builtins.int + AGGREGATE_FIELD_NUMBER: builtins.int + class_name: builtins.str + """(Required) Fully qualified name of Java class""" + output_type: builtins.str + """(Optional) Output type of the Java UDF""" + aggregate: builtins.bool + """(Required) Indicate if the Java user-defined function is an aggregate function""" + def __init__( + self, + *, + class_name: builtins.str = ..., + output_type: builtins.str | None = ..., + aggregate: builtins.bool = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_output_type", b"_output_type", "output_type", b"output_type" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_output_type", + b"_output_type", + "aggregate", + b"aggregate", + "class_name", + b"class_name", + "output_type", + b"output_type", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_output_type", b"_output_type"] + ) -> typing_extensions.Literal["output_type"] | None: ... + +global___JavaUDF = JavaUDF diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index c6bff4a3caab..03e53cbd89e7 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -127,9 +127,9 @@ def _build_common_inline_user_defined_function( ) return CommonInlineUserDefinedFunction( function_name=self._name, + function=py_udf, deterministic=self.deterministic, arguments=arg_exprs, - function=py_udf, ) def __call__(self, *cols: "ColumnOrName") -> Column: @@ -232,3 +232,18 @@ def register( return return_udf register.__doc__ = PySparkUDFRegistration.register.__doc__ + + def registerJavaFunction( + self, + name: str, + javaClassName: str, + returnType: Optional["DataTypeOrString"] = None, + ) -> None: + self.sparkSession._client.register_java(name, javaClassName, returnType) + + registerJavaFunction.__doc__ = PySparkUDFRegistration.registerJavaFunction.__doc__ + + def registerJavaUDAF(self, name: str, javaClassName: str) -> None: + self.sparkSession._client.register_java(name, javaClassName, aggregate=True) + + registerJavaUDAF.__doc__ = PySparkUDFRegistration.registerJavaUDAF.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py index 293f4b0f41a0..b38b4c28a25c 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udf.py @@ -25,6 +25,7 @@ sql.udf.UserDefinedFunction = UserDefinedFunction +from pyspark.errors import AnalysisException from pyspark.sql.tests.test_udf import BaseUDFTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.sql.types import IntegerType @@ -103,30 +104,13 @@ def test_udf3(self): def test_udf_registration_return_type_none(self): super().test_udf_registration_return_type_none() - # TODO(SPARK-42210): implement `spark.udf` - @unittest.skip("Fails in Spark Connect, should enable.") - def test_non_existed_udaf(self): - super().test_non_existed_udaf() - - # TODO(SPARK-42210): implement `spark.udf` - @unittest.skip("Fails in Spark Connect, should enable.") def test_non_existed_udf(self): - super().test_non_existed_udf() - - # TODO(SPARK-42210): implement `spark.udf` - @unittest.skip("Fails in Spark Connect, should enable.") - def test_register_java_function(self): - super().test_register_java_function() - - # TODO(SPARK-42210): implement `spark.udf` - @unittest.skip("Fails in Spark Connect, should enable.") - def test_register_java_udaf(self): - super().test_register_java_udaf() - - # TODO(SPARK-42210): implement `spark.udf` - @unittest.skip("Fails in Spark Connect, should enable.") - def test_udf_in_left_outer_join_condition(self): - super().test_udf_in_left_outer_join_condition() + spark = self.spark + self.assertRaisesRegex( + AnalysisException, + "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"), + ) def test_udf_registration_returns_udf(self): df = self.spark.range(10) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 9f8e3e469775..0b9b082ade3d 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -622,6 +622,9 @@ def registerJavaFunction( .. versionadded:: 2.3.0 + .. versionchanged:: 3.4.0 + Support Spark Connect. + Parameters ---------- name : str @@ -666,6 +669,9 @@ def registerJavaUDAF(self, name: str, javaClassName: str) -> None: .. versionadded:: 2.3.0 + .. versionchanged:: 3.4.0 + Support Spark Connect. + name : str name of the user-defined aggregate function javaClassName : str