diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto index 557b9db91237..b222f663cd0a 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -47,6 +47,7 @@ message Expression { UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14; CommonInlineUserDefinedFunction common_inline_user_defined_function = 15; CallFunction call_function = 16; + NamedArgumentExpression named_argument_expression = 17; // This field is used to mark extensions to the protocol. When plugins generate arbitrary // relations they can add them here. During the planning the correct resolution is done. @@ -380,3 +381,11 @@ message CallFunction { // (Optional) Function arguments. Empty arguments are allowed. repeated Expression arguments = 2; } + +message NamedArgumentExpression { + // (Required) The key of the named argument. + string key = 1; + + // (Required) The value expression of the named argument. + Expression value = 2; +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 45f962f79202..9bac48b50dbc 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1383,6 +1383,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction) case proto.Expression.ExprTypeCase.CALL_FUNCTION => transformCallFunction(exp.getCallFunction) + case proto.Expression.ExprTypeCase.NAMED_ARGUMENT_EXPRESSION => + transformNamedArgumentExpression(exp.getNamedArgumentExpression) case _ => throw InvalidPlanInput( s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported") @@ -1504,6 +1506,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { false) } + private def transformNamedArgumentExpression( + namedArg: proto.NamedArgumentExpression): Expression = { + NamedArgumentExpression(namedArg.getKey, transformExpression(namedArg.getValue)) + } + private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = { unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf) } diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 087cfaaa20b2..3a6d6e1cea76 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -73,9 +73,23 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject: return _to_java_column(col).expr() +@overload +def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject: + pass + + +@overload def _to_seq( sc: SparkContext, cols: Iterable["ColumnOrName"], + converter: Optional[Callable[["ColumnOrName"], JavaObject]], +) -> JavaObject: + pass + + +def _to_seq( + sc: SparkContext, + cols: Union[Iterable["ColumnOrName"], Iterable[JavaObject]], converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None, ) -> JavaObject: """ diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index d0a9b1d69aee..34aa4da11179 100644 --- a/python/pyspark/sql/connect/expressions.py +++ 
b/python/pyspark/sql/connect/expressions.py @@ -1054,3 +1054,23 @@ def __repr__(self) -> str: return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})" else: return f"CallFunction('{self._name}')" + + +class NamedArgumentExpression(Expression): + def __init__(self, key: str, value: Expression): + super().__init__() + + assert isinstance(key, str) + self._key = key + + assert isinstance(value, Expression) + self._value = value + + def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": + expr = proto.Expression() + expr.named_argument_expression.key = self._key + expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session)) + return expr + + def __repr__(self) -> str: + return f"{self._key} => {self._value}" diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py index 51d1a5d48a16..51ad47bb1c8b 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.py +++ b/python/pyspark/sql/connect/proto/expressions_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xd9+\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 
\x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\x9b\x0c\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 
\x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x82\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xe3\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x81\x01\n\x06Struct\x12\x38\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lementsB\x0e\n\x0cliteral_type\x1ap\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1aR\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x42\x12\n\x10_unparsed_target\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"\xec\x02\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 
\x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdfB\n\n\x08\x66unction"\x9b\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer"\xb8\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targumentsB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbf,\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunction\x12Y\n\x11\x65xpression_string\x18\x04 \x01(\x0b\x32*.spark.connect.Expression.ExpressionStringH\x00R\x10\x65xpressionString\x12S\n\x0funresolved_star\x18\x05 \x01(\x0b\x32(.spark.connect.Expression.UnresolvedStarH\x00R\x0eunresolvedStar\x12\x37\n\x05\x61lias\x18\x06 \x01(\x0b\x32\x1f.spark.connect.Expression.AliasH\x00R\x05\x61lias\x12\x34\n\x04\x63\x61st\x18\x07 \x01(\x0b\x32\x1e.spark.connect.Expression.CastH\x00R\x04\x63\x61st\x12V\n\x10unresolved_regex\x18\x08 \x01(\x0b\x32).spark.connect.Expression.UnresolvedRegexH\x00R\x0funresolvedRegex\x12\x44\n\nsort_order\x18\t \x01(\x0b\x32#.spark.connect.Expression.SortOrderH\x00R\tsortOrder\x12S\n\x0flambda_function\x18\n \x01(\x0b\x32(.spark.connect.Expression.LambdaFunctionH\x00R\x0elambdaFunction\x12:\n\x06window\x18\x0b \x01(\x0b\x32 .spark.connect.Expression.WindowH\x00R\x06window\x12l\n\x18unresolved_extract_value\x18\x0c \x01(\x0b\x32\x30.spark.connect.Expression.UnresolvedExtractValueH\x00R\x16unresolvedExtractValue\x12M\n\rupdate_fields\x18\r \x01(\x0b\x32&.spark.connect.Expression.UpdateFieldsH\x00R\x0cupdateFields\x12\x82\x01\n unresolved_named_lambda_variable\x18\x0e \x01(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableH\x00R\x1dunresolvedNamedLambdaVariable\x12~\n#common_inline_user_defined_function\x18\x0f \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x1f\x63ommonInlineUserDefinedFunction\x12\x42\n\rcall_function\x18\x10 \x01(\x0b\x32\x1b.spark.connect.CallFunctionH\x00R\x0c\x63\x61llFunction\x12\x64\n\x19named_argument_expression\x18\x11 \x01(\x0b\x32&.spark.connect.NamedArgumentExpressionH\x00R\x17namedArgumentExpression\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x8f\x06\n\x06Window\x12\x42\n\x0fwindow_function\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0ewindowFunction\x12@\n\x0epartition_spec\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\rpartitionSpec\x12\x42\n\norder_spec\x18\x03 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\torderSpec\x12K\n\nframe_spec\x18\x04 \x01(\x0b\x32,.spark.connect.Expression.Window.WindowFrameR\tframeSpec\x1a\xed\x03\n\x0bWindowFrame\x12U\n\nframe_type\x18\x01 \x01(\x0e\x32\x36.spark.connect.Expression.Window.WindowFrame.FrameTypeR\tframeType\x12P\n\x05lower\x18\x02 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05lower\x12P\n\x05upper\x18\x03 \x01(\x0b\x32:.spark.connect.Expression.Window.WindowFrame.FrameBoundaryR\x05upper\x1a\x91\x01\n\rFrameBoundary\x12!\n\x0b\x63urrent_row\x18\x01 \x01(\x08H\x00R\ncurrentRow\x12\x1e\n\tunbounded\x18\x02 \x01(\x08H\x00R\tunbounded\x12\x31\n\x05value\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionH\x00R\x05valueB\n\n\x08\x62oundary"O\n\tFrameType\x12\x18\n\x14\x46RAME_TYPE_UNDEFINED\x10\x00\x12\x12\n\x0e\x46RAME_TYPE_ROW\x10\x01\x12\x14\n\x10\x46RAME_TYPE_RANGE\x10\x02\x1a\xa9\x03\n\tSortOrder\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12O\n\tdirection\x18\x02 \x01(\x0e\x32\x31.spark.connect.Expression.SortOrder.SortDirectionR\tdirection\x12U\n\rnull_ordering\x18\x03 \x01(\x0e\x32\x30.spark.connect.Expression.SortOrder.NullOrderingR\x0cnullOrdering"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"U\n\x0cNullOrdering\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x1a\x91\x01\n\x04\x43\x61st\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12-\n\x04type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04type\x12\x1b\n\x08type_str\x18\x03 \x01(\tH\x00R\x07typeStrB\x0e\n\x0c\x63\x61st_to_type\x1a\x9b\x0c\n\x07Literal\x12-\n\x04null\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x04null\x12\x18\n\x06\x62inary\x18\x02 \x01(\x0cH\x00R\x06\x62inary\x12\x1a\n\x07\x62oolean\x18\x03 \x01(\x08H\x00R\x07\x62oolean\x12\x14\n\x04\x62yte\x18\x04 \x01(\x05H\x00R\x04\x62yte\x12\x16\n\x05short\x18\x05 \x01(\x05H\x00R\x05short\x12\x1a\n\x07integer\x18\x06 \x01(\x05H\x00R\x07integer\x12\x14\n\x04long\x18\x07 \x01(\x03H\x00R\x04long\x12\x16\n\x05\x66loat\x18\n \x01(\x02H\x00R\x05\x66loat\x12\x18\n\x06\x64ouble\x18\x0b \x01(\x01H\x00R\x06\x64ouble\x12\x45\n\x07\x64\x65\x63imal\x18\x0c \x01(\x0b\x32).spark.connect.Expression.Literal.DecimalH\x00R\x07\x64\x65\x63imal\x12\x18\n\x06string\x18\r \x01(\tH\x00R\x06string\x12\x14\n\x04\x64\x61te\x18\x10 \x01(\x05H\x00R\x04\x64\x61te\x12\x1e\n\ttimestamp\x18\x11 \x01(\x03H\x00R\ttimestamp\x12%\n\rtimestamp_ntz\x18\x12 \x01(\x03H\x00R\x0ctimestampNtz\x12\x61\n\x11\x63\x61lendar_interval\x18\x13 \x01(\x0b\x32\x32.spark.connect.Expression.Literal.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12\x30\n\x13year_month_interval\x18\x14 \x01(\x05H\x00R\x11yearMonthInterval\x12,\n\x11\x64\x61y_time_interval\x18\x15 \x01(\x03H\x00R\x0f\x64\x61yTimeInterval\x12?\n\x05\x61rray\x18\x16 \x01(\x0b\x32\'.spark.connect.Expression.Literal.ArrayH\x00R\x05\x61rray\x12\x39\n\x03map\x18\x17 \x01(\x0b\x32%.spark.connect.Expression.Literal.MapH\x00R\x03map\x12\x42\n\x06struct\x18\x18 \x01(\x0b\x32(.spark.connect.Expression.Literal.StructH\x00R\x06struct\x1au\n\x07\x44\x65\x63imal\x12\x14\n\x05value\x18\x01 \x01(\tR\x05value\x12!\n\tprecision\x18\x02 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x19\n\x05scale\x18\x03 
\x01(\x05H\x01R\x05scale\x88\x01\x01\x42\x0c\n\n_precisionB\x08\n\x06_scale\x1a\x62\n\x10\x43\x61lendarInterval\x12\x16\n\x06months\x18\x01 \x01(\x05R\x06months\x12\x12\n\x04\x64\x61ys\x18\x02 \x01(\x05R\x04\x64\x61ys\x12"\n\x0cmicroseconds\x18\x03 \x01(\x03R\x0cmicroseconds\x1a\x82\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lements\x1a\xe3\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12\x35\n\x04keys\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x04keys\x12\x39\n\x06values\x18\x04 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1a\x81\x01\n\x06Struct\x12\x38\n\x0bstruct_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\nstructType\x12=\n\x08\x65lements\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x08\x65lementsB\x0e\n\x0cliteral_type\x1ap\n\x13UnresolvedAttribute\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\xcc\x01\n\x12UnresolvedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x1f\n\x0bis_distinct\x18\x03 \x01(\x08R\nisDistinct\x12\x37\n\x18is_user_defined_function\x18\x04 \x01(\x08R\x15isUserDefinedFunction\x1a\x32\n\x10\x45xpressionString\x12\x1e\n\nexpression\x18\x01 \x01(\tR\nexpression\x1aR\n\x0eUnresolvedStar\x12,\n\x0funparsed_target\x18\x01 \x01(\tH\x00R\x0eunparsedTarget\x88\x01\x01\x42\x12\n\x10_unparsed_target\x1aV\n\x0fUnresolvedRegex\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x42\n\n\x08_plan_id\x1a\x84\x01\n\x16UnresolvedExtractValue\x12/\n\x05\x63hild\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05\x63hild\x12\x39\n\nextraction\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nextraction\x1a\xbb\x01\n\x0cUpdateFields\x12\x46\n\x11struct_expression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x10structExpression\x12\x1d\n\nfield_name\x18\x02 \x01(\tR\tfieldName\x12\x44\n\x10value_expression\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0fvalueExpression\x1ax\n\x05\x41lias\x12-\n\x04\x65xpr\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x04\x65xpr\x12\x12\n\x04name\x18\x02 \x03(\tR\x04name\x12\x1f\n\x08metadata\x18\x03 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x9e\x01\n\x0eLambdaFunction\x12\x35\n\x08\x66unction\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08\x66unction\x12U\n\targuments\x18\x02 \x03(\x0b\x32\x37.spark.connect.Expression.UnresolvedNamedLambdaVariableR\targuments\x1a>\n\x1dUnresolvedNamedLambdaVariable\x12\x1d\n\nname_parts\x18\x01 \x03(\tR\tnamePartsB\x0b\n\texpr_type"\xec\x02\n\x1f\x43ommonInlineUserDefinedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12\x39\n\npython_udf\x18\x04 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\tpythonUdf\x12I\n\x10scalar_scala_udf\x18\x05 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\x0escalarScalaUdf\x12\x33\n\x08java_udf\x18\x06 
\x01(\x0b\x32\x16.spark.connect.JavaUDFH\x00R\x07javaUdfB\n\n\x08\x66unction"\x9b\x01\n\tPythonUDF\x12\x38\n\x0boutput_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVer"\xb8\x01\n\x0eScalarScalaUDF\x12\x18\n\x07payload\x18\x01 \x01(\x0cR\x07payload\x12\x37\n\ninputTypes\x18\x02 \x03(\x0b\x32\x17.spark.connect.DataTypeR\ninputTypes\x12\x37\n\noutputType\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeR\noutputType\x12\x1a\n\x08nullable\x18\x04 \x01(\x08R\x08nullable"\x95\x01\n\x07JavaUDF\x12\x1d\n\nclass_name\x18\x01 \x01(\tR\tclassName\x12=\n\x0boutput_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\noutputType\x88\x01\x01\x12\x1c\n\taggregate\x18\x03 \x01(\x08R\taggregateB\x0e\n\x0c_output_type"l\n\x0c\x43\x61llFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\\\n\x17NamedArgumentExpression\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05valueB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -45,63 +45,65 @@ b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated" ) _EXPRESSION._serialized_start = 105 - _EXPRESSION._serialized_end = 5698 - _EXPRESSION_WINDOW._serialized_start = 1543 - _EXPRESSION_WINDOW._serialized_end = 2326 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1833 - _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2326 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2100 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2245 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2247 - _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2326 - _EXPRESSION_SORTORDER._serialized_start = 2329 - _EXPRESSION_SORTORDER._serialized_end = 2754 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2559 - _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2667 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2669 - _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2754 - _EXPRESSION_CAST._serialized_start = 2757 - _EXPRESSION_CAST._serialized_end = 2902 - _EXPRESSION_LITERAL._serialized_start = 2905 - _EXPRESSION_LITERAL._serialized_end = 4468 - _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3740 - _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3857 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3859 - _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3957 - _EXPRESSION_LITERAL_ARRAY._serialized_start = 3960 - _EXPRESSION_LITERAL_ARRAY._serialized_end = 4090 - _EXPRESSION_LITERAL_MAP._serialized_start = 4093 - _EXPRESSION_LITERAL_MAP._serialized_end = 4320 - _EXPRESSION_LITERAL_STRUCT._serialized_start = 4323 - _EXPRESSION_LITERAL_STRUCT._serialized_end = 4452 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4470 - _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4582 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4585 - _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4789 - _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4791 - _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4841 - _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4843 - _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4925 - 
_EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4927 - _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5013 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5016 - _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5148 - _EXPRESSION_UPDATEFIELDS._serialized_start = 5151 - _EXPRESSION_UPDATEFIELDS._serialized_end = 5338 - _EXPRESSION_ALIAS._serialized_start = 5340 - _EXPRESSION_ALIAS._serialized_end = 5460 - _EXPRESSION_LAMBDAFUNCTION._serialized_start = 5463 - _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5621 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5623 - _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5685 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5701 - _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6065 - _PYTHONUDF._serialized_start = 6068 - _PYTHONUDF._serialized_end = 6223 - _SCALARSCALAUDF._serialized_start = 6226 - _SCALARSCALAUDF._serialized_end = 6410 - _JAVAUDF._serialized_start = 6413 - _JAVAUDF._serialized_end = 6562 - _CALLFUNCTION._serialized_start = 6564 - _CALLFUNCTION._serialized_end = 6672 + _EXPRESSION._serialized_end = 5800 + _EXPRESSION_WINDOW._serialized_start = 1645 + _EXPRESSION_WINDOW._serialized_end = 2428 + _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1935 + _EXPRESSION_WINDOW_WINDOWFRAME._serialized_end = 2428 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_start = 2202 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMEBOUNDARY._serialized_end = 2347 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_start = 2349 + _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE._serialized_end = 2428 + _EXPRESSION_SORTORDER._serialized_start = 2431 + _EXPRESSION_SORTORDER._serialized_end = 2856 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_start = 2661 + _EXPRESSION_SORTORDER_SORTDIRECTION._serialized_end = 2769 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_start = 2771 + _EXPRESSION_SORTORDER_NULLORDERING._serialized_end = 2856 + _EXPRESSION_CAST._serialized_start = 2859 + _EXPRESSION_CAST._serialized_end = 3004 + _EXPRESSION_LITERAL._serialized_start = 3007 + _EXPRESSION_LITERAL._serialized_end = 4570 + _EXPRESSION_LITERAL_DECIMAL._serialized_start = 3842 + _EXPRESSION_LITERAL_DECIMAL._serialized_end = 3959 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3961 + _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 4059 + _EXPRESSION_LITERAL_ARRAY._serialized_start = 4062 + _EXPRESSION_LITERAL_ARRAY._serialized_end = 4192 + _EXPRESSION_LITERAL_MAP._serialized_start = 4195 + _EXPRESSION_LITERAL_MAP._serialized_end = 4422 + _EXPRESSION_LITERAL_STRUCT._serialized_start = 4425 + _EXPRESSION_LITERAL_STRUCT._serialized_end = 4554 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 4572 + _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 4684 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 4687 + _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4891 + _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4893 + _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4943 + _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4945 + _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 5027 + _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 5029 + _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 5115 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 5118 + _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 5250 + _EXPRESSION_UPDATEFIELDS._serialized_start = 5253 + _EXPRESSION_UPDATEFIELDS._serialized_end = 5440 + _EXPRESSION_ALIAS._serialized_start = 5442 + _EXPRESSION_ALIAS._serialized_end = 5562 + 
_EXPRESSION_LAMBDAFUNCTION._serialized_start = 5565 + _EXPRESSION_LAMBDAFUNCTION._serialized_end = 5723 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5725 + _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5787 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5803 + _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 6167 + _PYTHONUDF._serialized_start = 6170 + _PYTHONUDF._serialized_end = 6325 + _SCALARSCALAUDF._serialized_start = 6328 + _SCALARSCALAUDF._serialized_end = 6512 + _JAVAUDF._serialized_start = 6515 + _JAVAUDF._serialized_end = 6664 + _CALLFUNCTION._serialized_start = 6666 + _CALLFUNCTION._serialized_end = 6774 + _NAMEDARGUMENTEXPRESSION._serialized_start = 6776 + _NAMEDARGUMENTEXPRESSION._serialized_end = 6868 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index b9b16ce35e3f..2b418ef23f60 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -1102,6 +1102,7 @@ class Expression(google.protobuf.message.Message): UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int CALL_FUNCTION_FIELD_NUMBER: builtins.int + NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int @property def literal(self) -> global___Expression.Literal: ... @@ -1138,6 +1139,8 @@ class Expression(google.protobuf.message.Message): @property def call_function(self) -> global___CallFunction: ... @property + def named_argument_expression(self) -> global___NamedArgumentExpression: ... + @property def extension(self) -> google.protobuf.any_pb2.Any: """This field is used to mark extensions to the protocol. When plugins generate arbitrary relations they can add them here. During the planning the correct resolution is done. @@ -1162,6 +1165,7 @@ class Expression(google.protobuf.message.Message): | None = ..., common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ..., call_function: global___CallFunction | None = ..., + named_argument_expression: global___NamedArgumentExpression | None = ..., extension: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... def HasField( @@ -1185,6 +1189,8 @@ class Expression(google.protobuf.message.Message): b"lambda_function", "literal", b"literal", + "named_argument_expression", + b"named_argument_expression", "sort_order", b"sort_order", "unresolved_attribute", @@ -1226,6 +1232,8 @@ class Expression(google.protobuf.message.Message): b"lambda_function", "literal", b"literal", + "named_argument_expression", + b"named_argument_expression", "sort_order", b"sort_order", "unresolved_attribute", @@ -1265,6 +1273,7 @@ class Expression(google.protobuf.message.Message): "unresolved_named_lambda_variable", "common_inline_user_defined_function", "call_function", + "named_argument_expression", "extension", ] | None: ... @@ -1505,3 +1514,28 @@ class CallFunction(google.protobuf.message.Message): ) -> None: ... 
global___CallFunction = CallFunction + +class NamedArgumentExpression(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + """(Required) The key of the named argument.""" + @property + def value(self) -> global___Expression: + """(Required) The value expression of the named argument.""" + def __init__( + self, + *, + key: builtins.str = ..., + value: global___Expression | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["value", b"value"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + +global___NamedArgumentExpression = NamedArgumentExpression diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index c8495626292c..ce37832854cd 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -22,11 +22,11 @@ check_dependencies(__name__) import warnings -from typing import Type, TYPE_CHECKING, Optional, Union +from typing import List, Type, TYPE_CHECKING, Optional, Union from pyspark.rdd import PythonEvalType from pyspark.sql.connect.column import Column -from pyspark.sql.connect.expressions import ColumnReference +from pyspark.sql.connect.expressions import ColumnReference, Expression, NamedArgumentExpression from pyspark.sql.connect.plan import ( CommonInlineUserDefinedTableFunction, PythonUDTF, @@ -146,12 +146,14 @@ def __init__( self.deterministic = deterministic def _build_common_inline_user_defined_table_function( - self, *cols: "ColumnOrName" + self, *args: "ColumnOrName", **kwargs: "ColumnOrName" ) -> CommonInlineUserDefinedTableFunction: - arg_cols = [ - col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols + def to_expr(col: "ColumnOrName") -> Expression: + return col._expr if isinstance(col, Column) else ColumnReference(col) + + arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [ + NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items() ] - arg_exprs = [col._expr for col in arg_cols] udtf = PythonUDTF( func=self.func, @@ -166,13 +168,13 @@ def _build_common_inline_user_defined_table_function( arguments=arg_exprs, ) - def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame": from pyspark.sql.connect.session import SparkSession from pyspark.sql.connect.dataframe import DataFrame session = SparkSession.active() - plan = self._build_common_inline_user_defined_table_function(*cols) + plan = self._build_common_inline_user_defined_table_function(*args, **kwargs) return DataFrame.withPlan(plan, session) def asNondeterministic(self) -> "UserDefinedTableFunction": diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fdb4ec8111ed..9cc364cc1f8c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -15547,6 +15547,12 @@ def udtf( .. versionadded:: 3.5.0 + .. versionchanged:: 4.0.0 + Supports Python side analysis. + + .. versionchanged:: 4.0.0 + Supports keyword-arguments. + Parameters ---------- cls : class @@ -15623,6 +15629,38 @@ def udtf( | 1| x| +---+---+ + UDTF can use keyword arguments: + + >>> @udtf + ... class TestUDTFWithKwargs: + ... @staticmethod + ... def analyze( + ... a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: AnalyzeArgument + ... 
) -> AnalyzeResult: + ... return AnalyzeResult( + ... StructType().add("a", a.data_type) + ... .add("b", b.data_type) + ... .add("x", kwargs["x"].data_type) + ... ) + ... + ... def eval(self, a, b, **kwargs): + ... yield a, b, kwargs["x"] + ... + >>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show() + +---+---+---+ + | a| b| x| + +---+---+---+ + | 1| b| x| + +---+---+---+ + + >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs) + >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show() + +---+---+---+ + | a| b| x| + +---+---+---+ + | 1| b| x| + +---+---+---+ + Arrow optimization can be explicitly enabled when creating UDTFs: >>> @udtf(returnType="c1: int, c2: int", useArrow=True) diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 300067716e9d..cd0604ccacee 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -1565,6 +1565,7 @@ def eval(self, **kwargs): expected = [Row(c1="hello", c2="world")] assertDataFrameEqual(TestUDTF(), expected) assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected) + assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected) with self.assertRaisesRegex( AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given" @@ -1795,6 +1796,93 @@ def terminate(self): assertSchemaEqual(df.schema, StructType().add("col1", IntegerType())) assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)]) + def test_udtf_with_named_arguments(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a, b): + yield a, + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate( + [ + self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), + self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), + TestUDTF(a=lit(10), b=lit("x")), + TestUDTF(b=lit("x"), a=lit(10)), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(a=10)]) + + def test_udtf_with_named_arguments_negative(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self, a, b): + yield a, + + self.spark.udtf.register("test_udtf", TestUDTF) + + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show() + + with self.assertRaisesRegex( + PythonException, r"eval\(\) got an unexpected keyword argument 'c'" + ): + self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show() + + def test_udtf_with_kwargs(self): + @udtf(returnType="a: int, b: string") + class TestUDTF: + def eval(self, **kwargs): + yield kwargs["a"], kwargs["b"] + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate( + [ + self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), + self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), + TestUDTF(a=lit(10), b=lit("x")), + TestUDTF(b=lit("x"), a=lit(10)), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(a=10, b="x")]) + + def test_udtf_with_analyze_kwargs(self): + @udtf + class TestUDTF: + @staticmethod + def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult: + return AnalyzeResult( + StructType( + [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())] + ) + ) + + def eval(self, **kwargs): + yield tuple(value for _, value in 
sorted(kwargs.items())) + + self.spark.udtf.register("test_udtf", TestUDTF) + + for i, df in enumerate( + [ + self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"), + self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"), + TestUDTF(a=lit(10), b=lit("x")), + TestUDTF(b=lit("x"), a=lit(10)), + ] + ): + with self.subTest(query_no=i): + assertDataFrameEqual(df, [Row(a=10, b="x")]) + class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 027a2646a465..1ca87aae758d 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -29,7 +29,7 @@ from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError from pyspark.rdd import PythonEvalType -from pyspark.sql.column import _to_java_column, _to_seq +from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.sql.types import DataType, StructType, _parse_datatype_string from pyspark.sql.udf import _wrap_function @@ -148,9 +148,9 @@ def _vectorize_udtf(cls: Type) -> Type: # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]: @wraps(f) - def evaluate(*a: Any) -> Any: + def evaluate(*a: Any, **kw: Any) -> Any: try: - return f(*a) + return f(*a, **kw) except Exception as e: raise PySparkRuntimeError( error_class="UDTF_EXEC_ERROR", @@ -168,18 +168,22 @@ def __init__(self) -> None: ): @staticmethod - def analyze(*args: AnalyzeArgument) -> AnalyzeResult: - return cls.analyze(*args) + def analyze(*args: AnalyzeArgument, **kwargs: AnalyzeArgument) -> AnalyzeResult: + return cls.analyze(*args, **kwargs) - def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]: - if len(args) == 0: + def eval(self, *args: pd.Series, **kwargs: pd.Series) -> Iterator[pd.DataFrame]: + if len(args) == 0 and len(kwargs) == 0: yield pd.DataFrame(wrap_func(self.func.eval)()) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. 
- row_tuples = zip(*args) + keys = list(kwargs.keys()) + len_args = len(args) + row_tuples = zip(*args, *[kwargs[key] for key in keys]) for row in row_tuples: - res = wrap_func(self.func.eval)(*row) + res = wrap_func(self.func.eval)( + *row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)} + ) if res is not None and not isinstance(res, Iterable): raise PySparkRuntimeError( error_class="UDTF_RETURN_NOT_ITERABLE", @@ -339,14 +343,24 @@ def _create_judtf(self, func: Type) -> JavaObject: ) return judtf - def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame": from pyspark.sql import DataFrame, SparkSession spark = SparkSession._getActiveSessionOrCreate() sc = spark.sparkContext + assert sc._jvm is not None + jcols = [_to_java_column(arg) for arg in args] + [ + sc._jvm.Column( + sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression( + key, _to_java_expr(value) + ) + ) + for key, value in kwargs.items() + ] + judtf = self._judtf - jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column)) + jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, jcols)) return DataFrame(jPythonUDTF, spark) def asNondeterministic(self) -> "UserDefinedTableFunction": diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index 9ffa03541e69..7ba0789fa7b7 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -19,7 +19,7 @@ import os import sys import traceback -from typing import List, IO +from typing import Dict, List, IO, Tuple from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkRuntimeError, PySparkValueError @@ -69,11 +69,12 @@ def read_udtf(infile: IO) -> type: return handler -def read_arguments(infile: IO) -> List[AnalyzeArgument]: +def read_arguments(infile: IO) -> Tuple[List[AnalyzeArgument], Dict[str, AnalyzeArgument]]: """Reads the arguments for `analyze` static method.""" # Receive arguments num_args = read_int(infile) args: List[AnalyzeArgument] = [] + kwargs: Dict[str, AnalyzeArgument] = {} for _ in range(num_args): dt = _parse_datatype_json_string(utf8_deserializer.loads(infile)) if read_bool(infile): # is foldable @@ -83,8 +84,15 @@ def read_arguments(infile: IO) -> List[AnalyzeArgument]: else: value = None is_table = read_bool(infile) # is table argument - args.append(AnalyzeArgument(data_type=dt, value=value, is_table=is_table)) - return args + argument = AnalyzeArgument(data_type=dt, value=value, is_table=is_table) + + is_named_arg = read_bool(infile) + if is_named_arg: + name = utf8_deserializer.loads(infile) + kwargs[name] = argument + else: + args.append(argument) + return args, kwargs def main(infile: IO, outfile: IO) -> None: @@ -107,9 +115,9 @@ def main(infile: IO, outfile: IO) -> None: _accumulatorRegistry.clear() handler = read_udtf(infile) - args = read_arguments(infile) + args, kwargs = read_arguments(infile) - result = handler.analyze(*args) # type: ignore[attr-defined] + result = handler.analyze(*args, **kwargs) # type: ignore[attr-defined] if not isinstance(result, AnalyzeResult): raise PySparkValueError( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6f27400387e7..8916a794001c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -550,7 +550,16 @@ def read_udtf(pickleSer, infile, eval_type): # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand' num_arg = 
read_int(infile) - arg_offsets = [read_int(infile) for _ in range(num_arg)] + args_offsets = [] + kwargs_offsets = {} + for _ in range(num_arg): + offset = read_int(infile) + if read_bool(infile): + name = utf8_deserializer.loads(infile) + kwargs_offsets[name] = offset + else: + args_offsets.append(offset) + handler = read_command(pickleSer, infile) if not isinstance(handler, type): raise PySparkRuntimeError( @@ -619,7 +628,9 @@ def verify_result(result): ) return result - return lambda *a: map(lambda res: (res, arrow_return_type), map(verify_result, f(*a))) + return lambda *a, **kw: map( + lambda res: (res, arrow_return_type), map(verify_result, f(*a, **kw)) + ) eval = wrap_arrow_udtf(getattr(udtf, "eval"), return_type) @@ -633,7 +644,10 @@ def mapper(_, it): for a in it: # The eval function yields an iterator. Each element produced by this # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type). - yield from eval(*[a[o] for o in arg_offsets]) + yield from eval( + *[a[o] for o in args_offsets], + **{k: a[o] for k, o in kwargs_offsets.items()}, + ) finally: if terminate is not None: yield from terminate() @@ -667,9 +681,9 @@ def verify_and_convert_result(result): return toInternal(result) # Evaluate the function and return a tuple back to the executor. - def evaluate(*a) -> tuple: + def evaluate(*a, **kw) -> tuple: try: - res = f(*a) + res = f(*a, **kw) except Exception as e: raise PySparkRuntimeError( error_class="UDTF_EXEC_ERROR", @@ -705,7 +719,10 @@ def evaluate(*a) -> tuple: def mapper(_, it): try: for a in it: - yield eval(*[a[o] for o in arg_offsets]) + yield eval( + *[a[o] for o in args_offsets], + **{k: a[o] for k, o in kwargs_offsets.items()}, + ) finally: if terminate is not None: yield terminate() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 1088655f60cd..d13bfab6d702 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -72,6 +72,38 @@ trait FunctionBuilderBase[T] { } object NamedParametersSupport { + /** + * This method splits named arguments from the argument list. 
+ * Also checks that: + * - no positional arguments appear after the first named argument + * - the named arguments don't use duplicate names + * + * @param args The argument list provided in the function invocation + * @param functionName The name of the function, used in error messages + * @return A tuple of a list of positional arguments and a list of named arguments + */ + def splitAndCheckNamedArguments( + args: Seq[Expression], + functionName: String): (Seq[Expression], Seq[NamedArgumentExpression]) = { + val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression]) + + val namedParametersSet = collection.mutable.Set[String]() + + (positionalArgs, + namedArgs.zipWithIndex.map { + case (namedArg @ NamedArgumentExpression(parameterName, _), _) => + if (namedParametersSet.contains(parameterName)) { + throw QueryCompilationErrors.doubleNamedArgumentReference( + functionName, parameterName) + } + namedParametersSet.add(parameterName) + namedArg + case (_, index) => + throw QueryCompilationErrors.unexpectedPositionalArgument( + functionName, namedArgs(index - 1).asInstanceOf[NamedArgumentExpression].key) + }) + } + /** * This method is the default routine which rearranges the arguments in positional order according * to the function signature provided. This will also fill in any default values that exists for @@ -93,7 +125,7 @@ object NamedParametersSupport { functionName, functionSignature) } - val (positionalArgs, namedArgs) = args.span(!_.isInstanceOf[NamedArgumentExpression]) + val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, functionName) val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) // The following loop checks for the following: @@ -102,28 +134,16 @@ val allParameterNames: Seq[String] = parameters.map(_.name) val parameterNamesSet: Set[String] = allParameterNames.toSet val positionalParametersSet = allParameterNames.take(positionalArgs.size).toSet - val namedParametersSet = collection.mutable.Set[String]() - namedArgs.zipWithIndex.foreach { case (arg, index) => - arg match { - case namedArg: NamedArgumentExpression => - val parameterName = namedArg.key - if (!parameterNamesSet.contains(parameterName)) { - throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, - parameterNamesSet.toSeq) - } - if (positionalParametersSet.contains(parameterName)) { - throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( - functionName, namedArg.key) - } - if (namedParametersSet.contains(parameterName)) { - throw QueryCompilationErrors.doubleNamedArgumentReference( - functionName, namedArg.key) - } - namedParametersSet.add(namedArg.key) - case _ => - throw QueryCompilationErrors.unexpectedPositionalArgument( - functionName, namedArgs(index - 1).asInstanceOf[NamedArgumentExpression].key) + namedArgs.foreach { namedArg => + val parameterName = namedArg.key + if (!parameterNamesSet.contains(parameterName)) { + throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, + parameterNamesSet.toSeq) + } + if (positionalParametersSet.contains(parameterName)) { + throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( + functionName, namedArg.key) + } } @@ -136,8 +156,7 @@ } // This constructs a map from argument name to value for argument rearrangement.
- val namedArgMap = namedArgs.map { arg => - val namedArg = arg.asInstanceOf[NamedArgumentExpression] + val namedArgMap = namedArgs.map { namedArg => namedArg.key -> namedArg.value }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala index 9c0addfd2ae3..8ebd8a3a106c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} @@ -52,7 +53,7 @@ case class ArrowEvalPythonUDTFExec( private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) override protected def evaluate( - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[Iterator[InternalRow]] = { @@ -64,7 +65,7 @@ case class ArrowEvalPythonUDTFExec( val columnarBatchIter = new ArrowPythonUDTFRunner( udtf, evalType, - argOffsets, + argMetas, schema, sessionLocalTimeZone, largeVarTypes, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index 1dd06c2dc73a..c0fa8b58bee0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -23,6 +23,7 @@ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PythonUDTF import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -33,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class ArrowPythonUDTFRunner( udtf: PythonUDTF, evalType: Int, - offsets: Array[Int], + argMetas: Array[ArgumentMetadata], protected override val schema: StructType, protected override val timeZoneId: String, protected override val largeVarTypes: Boolean, @@ -41,7 +42,8 @@ class ArrowPythonUDTFRunner( val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(offsets), jobArtifactUUID) + Seq(ChainedPythonFunctions(Seq(udtf.func))), + evalType, Array(argMetas.map(_.offset)), jobArtifactUUID) with BasicPythonArrowInput with BasicPythonArrowOutput { @@ -49,7 +51,7 @@ class ArrowPythonUDTFRunner( dataOut: DataOutputStream, funcs: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]]): Unit = { - PythonUDTFRunner.writeUDTF(dataOut, udtf, offsets) + PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) } override val pythonExec: String = diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index 6c8412f8b377..cbc90f34a37e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata import org.apache.spark.sql.types.StructType /** @@ -55,7 +56,7 @@ case class BatchEvalPythonUDTFExec( * an iterator of internal rows for every input row. */ override protected def evaluate( - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[Iterator[InternalRow]] = { @@ -66,7 +67,7 @@ case class BatchEvalPythonUDTFExec( // Output iterator for results from Python. val outputIterator = - new PythonUDTFRunner(udtf, argOffsets, pythonMetrics, jobArtifactUUID) + new PythonUDTFRunner(udtf, argMetas, pythonMetrics, jobArtifactUUID) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler @@ -93,12 +94,12 @@ case class BatchEvalPythonUDTFExec( class PythonUDTFRunner( udtf: PythonUDTF, - argOffsets: Array[Int], + argMetas: Array[ArgumentMetadata], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonUDFRunner( Seq(ChainedPythonFunctions(Seq(udtf.func))), - PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, jobArtifactUUID) { + PythonEvalType.SQL_TABLE_UDF, Array(argMetas.map(_.offset)), pythonMetrics, jobArtifactUUID) { protected override def newWriter( env: SparkEnv, @@ -109,7 +110,7 @@ class PythonUDTFRunner( new PythonUDFWriter(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { - PythonUDTFRunner.writeUDTF(dataOut, udtf, argOffsets) + PythonUDTFRunner.writeUDTF(dataOut, udtf, argMetas) } } } @@ -117,10 +118,21 @@ class PythonUDTFRunner( object PythonUDTFRunner { - def writeUDTF(dataOut: DataOutputStream, udtf: PythonUDTF, argOffsets: Array[Int]): Unit = { - dataOut.writeInt(argOffsets.length) - argOffsets.foreach { offset => - dataOut.writeInt(offset) + def writeUDTF( + dataOut: DataOutputStream, + udtf: PythonUDTF, + argMetas: Array[ArgumentMetadata]): Unit = { + dataOut.writeInt(argMetas.length) + argMetas.foreach { + case ArgumentMetadata(offset, name) => + dataOut.writeInt(offset) + name match { + case Some(name) => + dataOut.writeBoolean(true) + PythonWorkerUtils.writeUTF(name, dataOut) + case _ => + dataOut.writeBoolean(false) + } } dataOut.writeInt(udtf.func.command.length) dataOut.write(udtf.func.command.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala index fab417a0f86f..410209e0adad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonUDTFExec.scala @@ -26,9 +26,20 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.execution.python.EvalPythonUDTFExec.ArgumentMetadata
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils

+object EvalPythonUDTFExec {
+  /**
+   * Metadata for arguments of Python UDTF.
+   *
+   * @param offset the offset of the argument
+   * @param name the name of the argument if it's a `NamedArgumentExpression`
+   */
+  case class ArgumentMetadata(offset: Int, name: Option[String])
+}
+
 /**
  * A physical plan that evaluates a [[PythonUDTF]], one partition of tuples at a time.
  * This is similar to [[EvalPythonExec]].
@@ -45,7 +56,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
   override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)

   protected def evaluate(
-      argOffsets: Array[Int],
+      argMetas: Array[ArgumentMetadata],
       iter: Iterator[InternalRow],
       schema: StructType,
       context: TaskContext): Iterator[Iterator[InternalRow]]
@@ -68,13 +79,19 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
     // flatten all the arguments
     val allInputs = new ArrayBuffer[Expression]
     val dataTypes = new ArrayBuffer[DataType]
-    val argOffsets = udtf.children.map { e =>
-      if (allInputs.exists(_.semanticEquals(e))) {
-        allInputs.indexWhere(_.semanticEquals(e))
+    val argMetas = udtf.children.map { e =>
+      val (key, value) = e match {
+        case NamedArgumentExpression(key, value) =>
+          (Some(key), value)
+        case _ =>
+          (None, e)
+      }
+      if (allInputs.exists(_.semanticEquals(value))) {
+        ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
       } else {
-        allInputs += e
-        dataTypes += e.dataType
-        allInputs.length - 1
+        allInputs += value
+        dataTypes += value.dataType
+        ArgumentMetadata(allInputs.length - 1, key)
       }
     }.toArray
     val projection = MutableProjection.create(allInputs.toSeq, child.output)
@@ -93,7 +110,7 @@ trait EvalPythonUDTFExec extends UnaryExecNode {
         projection(inputRow)
       }

-    val outputRowIterator = evaluate(argOffsets, projectedRowIter, schema, context)
+    val outputRowIterator = evaluate(argMetas, projectedRowIter, schema, context)

     val pruneChildForResult: InternalRow => InternalRow =
       if (child.outputSet == AttributeSet(requiredChildOutput)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 5fa9c89b3d15..38d521c16d59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -33,7 +33,7 @@ import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, PythonUDAF, PythonUDF, PythonUDTF, UnresolvedPolymorphicPythonUDTF}
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -101,6 +101,13 @@ case class UserDefinedPythonTableFunction(
   }

   def builder(exprs: Seq[Expression]): LogicalPlan = {
+    /*
+     * Check if the named arguments:
+     * - don't have duplicated names
+     * - don't contain positional arguments
+     */
+    NamedParametersSupport.splitAndCheckNamedArguments(exprs, name)
+
     val udtf = returnType match {
       case Some(rt) =>
         PythonUDTF(
@@ -213,8 +220,6 @@ object UserDefinedPythonTableFunction {
     val bufferStream = new DirectByteBufferOutputStream()
     try {
       val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
-      val dataIn = new DataInputStream(new BufferedInputStream(
-        new WorkerInputStream(worker, bufferStream), bufferSize))

       PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
       PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
@@ -237,11 +242,22 @@ object UserDefinedPythonTableFunction {
           dataOut.writeBoolean(false)
         }
         dataOut.writeBoolean(is_table)
+        // If the expr is NamedArgumentExpression, send its name.
+        expr match {
+          case NamedArgumentExpression(key, _) =>
+            dataOut.writeBoolean(true)
+            PythonWorkerUtils.writeUTF(key, dataOut)
+          case _ =>
+            dataOut.writeBoolean(false)
+        }
       }
       dataOut.writeInt(SpecialLengths.END_OF_STREAM)
       dataOut.flush()

+      val dataIn = new DataInputStream(new BufferedInputStream(
+        new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
+
       // Receive the schema
       val schema = dataIn.readInt() match {
         case length if length >= 0 =>
@@ -273,9 +289,13 @@ object UserDefinedPythonTableFunction {
       case eof: EOFException =>
         throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
     } finally {
-      if (!releasedOrClosed) {
-        // An error happened. Force to close the worker.
-        env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+      try {
+        bufferStream.close()
+      } finally {
+        if (!releasedOrClosed) {
+          // An error happened. Force to close the worker.
+          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+        }
       }
     }
   }
@@ -288,8 +308,7 @@ object UserDefinedPythonTableFunction {
   * This is a port and simplified version of `PythonRunner.ReaderInputStream`,
   * and only supports to write all at once and then read all.
   */
-  private class WorkerInputStream(
-      worker: PythonWorker, bufferStream: DirectByteBufferOutputStream) extends InputStream {
+  private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream {

    private[this] val temp = new Array[Byte](1)

@@ -312,14 +331,15 @@ object UserDefinedPythonTableFunction {
           n = worker.channel.read(buf)
         }
         if (worker.selectionKey.isWritable) {
-          val buffer = bufferStream.toByteBuffer
           var acceptsInput = true
           while (acceptsInput && buffer.hasRemaining) {
             val n = worker.channel.write(buffer)
             acceptsInput = n > 0
           }
-          // We no longer have any data to write to the socket.
-          worker.selectionKey.interestOps(SelectionKey.OP_READ)
+          if (!buffer.hasRemaining) {
+            // We no longer have any data to write to the socket.
+            worker.selectionKey.interestOps(SelectionKey.OP_READ)
+          }
         }
       }
       n
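The analyze handshake above amounts to a simple per-argument framing: after the argument value and the is_table flag, a boolean says whether a name follows, and the name is written as UTF only for NamedArgumentExpression arguments. A small sketch of just that tail, assuming a DataOutputStream as in the hunks above (the helper name writeOptionalArgName is hypothetical, used here only to illustrate the layout):

    import java.io.DataOutputStream
    import org.apache.spark.api.python.PythonWorkerUtils

    // Hypothetical helper, for illustration only: mirrors the boolean-flag-plus-UTF
    // layout the patch writes for each argument's optional name.
    def writeOptionalArgName(dataOut: DataOutputStream, name: Option[String]): Unit = name match {
      case Some(key) =>
        dataOut.writeBoolean(true)             // a name follows
        PythonWorkerUtils.writeUTF(key, dataOut)
      case None =>
        dataOut.writeBoolean(false)            // positional argument, no name
    }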