diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 618720ef4931..353fbebd0460 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -37,11 +37,12 @@ message Relation {
     Join join = 5;
     Union union = 6;
     Sort sort = 7;
-    Fetch fetch = 8;
+    Limit limit = 8;
     Aggregate aggregate = 9;
     SQL sql = 10;
     LocalRelation local_relation = 11;
     Sample sample = 12;
+    Offset offset = 13;
 
     Unknown unknown = 999;
   }
@@ -121,11 +122,17 @@ message Union {
   }
 }
 
-// Relation of type [[Fetch]] that is used to read `limit` / `offset` rows from the input relation.
-message Fetch {
+// Relation of type [[Limit]] that is used to `limit` rows from the input relation.
+message Limit {
   Relation input = 1;
   int32 limit = 2;
-  int32 offset = 3;
+}
+
+// Relation of type [[Offset]] that is used to read rows starting from the `offset` on
+// the input relation.
+message Offset {
+  Relation input = 1;
+  int32 offset = 2;
 }
 
 // Relation of type [[Aggregate]].
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 7b8b58e1abac..8a267dff7d78 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -186,6 +186,28 @@ package object dsl {
         .build()
     }
 
+    def limit(limit: Int): proto.Relation = {
+      proto.Relation
+        .newBuilder()
+        .setLimit(
+          proto.Limit
+            .newBuilder()
+            .setInput(logicalPlan)
+            .setLimit(limit))
+        .build()
+    }
+
+    def offset(offset: Int): proto.Relation = {
+      proto.Relation
+        .newBuilder()
+        .setOffset(
+          proto.Offset
+            .newBuilder()
+            .setInput(logicalPlan)
+            .setOffset(offset))
+        .build()
+    }
+
     def where(condition: proto.Expression): proto.Relation = {
       proto.Relation
         .newBuilder()
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a5606f278f90..6a6b5a15a087 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -55,7 +55,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead, common)
       case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject, common)
       case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
-      case proto.Relation.RelTypeCase.FETCH => transformFetch(rel.getFetch)
+      case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
+      case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
       case proto.Relation.RelTypeCase.JOIN => transformJoin(rel.getJoin)
       case proto.Relation.RelTypeCase.UNION => transformUnion(rel.getUnion)
       case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
@@ -194,10 +195,16 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
   }
 
-  private def transformFetch(limit: proto.Fetch): LogicalPlan = {
+  private def transformLimit(limit: proto.Limit): LogicalPlan = {
     logical.Limit(
-      child = transformRelation(limit.getInput),
-      limitExpr = expressions.Literal(limit.getLimit, IntegerType))
+      limitExpr = expressions.Literal(limit.getLimit, IntegerType),
+      transformRelation(limit.getInput))
+  }
+
+  private def transformOffset(offset: proto.Offset): LogicalPlan = {
+    logical.Offset(
+      offsetExpr = expressions.Literal(offset.getOffset, IntegerType),
+      transformRelation(offset.getInput))
   }
 
   /**
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index ef73eb8d21e2..fc3d219ec6ba 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -77,7 +77,9 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   test("Simple Limit") {
     assertThrows[IndexOutOfBoundsException] {
       new SparkConnectPlanner(
-        proto.Relation.newBuilder.setFetch(proto.Fetch.newBuilder.setLimit(10).build()).build(),
+        proto.Relation.newBuilder
+          .setLimit(proto.Limit.newBuilder.setLimit(10))
+          .build(),
         None.orNull)
         .transform()
     }
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index ef4b358798ec..d8bb1684cb84 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -167,6 +167,36 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
   }
 
+  test("Test limit offset") {
+    val connectPlan = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.limit(10))
+    }
+    val sparkPlan = sparkTestRelation.limit(10)
+    comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
+
+    val connectPlan2 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.offset(2))
+    }
+    val sparkPlan2 = sparkTestRelation.offset(2)
+    comparePlans(connectPlan2.analyze, sparkPlan2.analyze, false)
+
+    val connectPlan3 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.limit(10).offset(2))
+    }
+    val sparkPlan3 = sparkTestRelation.limit(10).offset(2)
+    comparePlans(connectPlan3.analyze, sparkPlan3.analyze, false)
+
+    val connectPlan4 = {
+      import org.apache.spark.sql.connect.dsl.plans._
+      transform(connectTestRelation.offset(2).limit(10))
+    }
+    val sparkPlan4 = sparkTestRelation.offset(2).limit(10)
+    comparePlans(connectPlan4.analyze, sparkPlan4.analyze, false)
+  }
+
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7bb2a3356c8f..41a2629db886 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -231,8 +231,8 @@ def __init__(self, child: Optional["LogicalPlan"], limit: int, offset: int = 0)
     def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
         assert self._child is not None
         plan = proto.Relation()
-        plan.fetch.input.CopyFrom(self._child.plan(session))
-        plan.fetch.limit = self.limit
+        plan.limit.input.CopyFrom(self._child.plan(session))
+        plan.limit.limit = self.limit
         return plan
 
     def print(self, indent: int = 0) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 4b5c4e60d57d..c3b7b7ec2eaf 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x9e\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05\x66\x65tch\x18\x08 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x95\x01\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifierB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"d\n\x05\x46\x65tch\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit\x12\x16\n\x06offset\x18\x03 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xb8\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12\x12\n\x04seed\x18\x05 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x95\x01\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifierB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xb8\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12\x12\n\x04seed\x18\x05 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -42,45 +42,47 @@
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 752
-    _UNKNOWN._serialized_start = 754
-    _UNKNOWN._serialized_end = 763
-    _RELATIONCOMMON._serialized_start = 765
-    _RELATIONCOMMON._serialized_end = 836
-    _SQL._serialized_start = 838
-    _SQL._serialized_end = 865
-    _READ._serialized_start = 868
-    _READ._serialized_end = 1017
-    _READ_NAMEDTABLE._serialized_start = 943
-    _READ_NAMEDTABLE._serialized_end = 1004
-    _PROJECT._serialized_start = 1019
-    _PROJECT._serialized_end = 1136
-    _FILTER._serialized_start = 1138
-    _FILTER._serialized_end = 1250
-    _JOIN._serialized_start = 1253
-    _JOIN._serialized_end = 1666
-    _JOIN_JOINTYPE._serialized_start = 1479
-    _JOIN_JOINTYPE._serialized_end = 1666
-    _UNION._serialized_start = 1669
-    _UNION._serialized_end = 1874
-    _UNION_UNIONTYPE._serialized_start = 1790
-    _UNION_UNIONTYPE._serialized_end = 1874
-    _FETCH._serialized_start = 1876
-    _FETCH._serialized_end = 1976
-    _AGGREGATE._serialized_start = 1979
-    _AGGREGATE._serialized_end = 2304
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2208
-    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2304
-    _SORT._serialized_start = 2307
-    _SORT._serialized_end = 2809
-    _SORT_SORTFIELD._serialized_start = 2427
-    _SORT_SORTFIELD._serialized_end = 2615
-    _SORT_SORTDIRECTION._serialized_start = 2617
-    _SORT_SORTDIRECTION._serialized_end = 2725
-    _SORT_SORTNULLS._serialized_start = 2727
-    _SORT_SORTNULLS._serialized_end = 2809
-    _LOCALRELATION._serialized_start = 2811
-    _LOCALRELATION._serialized_end = 2904
-    _SAMPLE._serialized_start = 2907
-    _SAMPLE._serialized_end = 3091
+    _RELATION._serialized_end = 801
+    _UNKNOWN._serialized_start = 803
+    _UNKNOWN._serialized_end = 812
+    _RELATIONCOMMON._serialized_start = 814
+    _RELATIONCOMMON._serialized_end = 885
+    _SQL._serialized_start = 887
+    _SQL._serialized_end = 914
+    _READ._serialized_start = 917
+    _READ._serialized_end = 1066
+    _READ_NAMEDTABLE._serialized_start = 992
+    _READ_NAMEDTABLE._serialized_end = 1053
+    _PROJECT._serialized_start = 1068
+    _PROJECT._serialized_end = 1185
+    _FILTER._serialized_start = 1187
+    _FILTER._serialized_end = 1299
+    _JOIN._serialized_start = 1302
+    _JOIN._serialized_end = 1715
+    _JOIN_JOINTYPE._serialized_start = 1528
+    _JOIN_JOINTYPE._serialized_end = 1715
+    _UNION._serialized_start = 1718
+    _UNION._serialized_end = 1923
+    _UNION_UNIONTYPE._serialized_start = 1839
+    _UNION_UNIONTYPE._serialized_end = 1923
+    _LIMIT._serialized_start = 1925
+    _LIMIT._serialized_end = 2001
+    _OFFSET._serialized_start = 2003
+    _OFFSET._serialized_end = 2082
+    _AGGREGATE._serialized_start = 2085
+    _AGGREGATE._serialized_end = 2410
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2314
+    _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2410
+    _SORT._serialized_start = 2413
+    _SORT._serialized_end = 2915
+    _SORT_SORTFIELD._serialized_start = 2533
+    _SORT_SORTFIELD._serialized_end = 2721
+    _SORT_SORTDIRECTION._serialized_start = 2723
+    _SORT_SORTDIRECTION._serialized_end = 2831
+    _SORT_SORTNULLS._serialized_start = 2833
+    _SORT_SORTNULLS._serialized_end = 2915
+    _LOCALRELATION._serialized_start = 2917
+    _LOCALRELATION._serialized_end = 3010
+    _SAMPLE._serialized_start = 3013
+    _SAMPLE._serialized_end = 3197
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 1036414143a0..3354fc86f45d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -66,11 +66,12 @@ class Relation(google.protobuf.message.Message):
     JOIN_FIELD_NUMBER: builtins.int
     UNION_FIELD_NUMBER: builtins.int
     SORT_FIELD_NUMBER: builtins.int
-    FETCH_FIELD_NUMBER: builtins.int
+    LIMIT_FIELD_NUMBER: builtins.int
     AGGREGATE_FIELD_NUMBER: builtins.int
    SQL_FIELD_NUMBER: builtins.int
     LOCAL_RELATION_FIELD_NUMBER: builtins.int
     SAMPLE_FIELD_NUMBER: builtins.int
+    OFFSET_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -87,7 +88,7 @@ class Relation(google.protobuf.message.Message):
     @property
     def sort(self) -> global___Sort: ...
     @property
-    def fetch(self) -> global___Fetch: ...
+    def limit(self) -> global___Limit: ...
     @property
     def aggregate(self) -> global___Aggregate: ...
     @property
@@ -97,6 +98,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def sample(self) -> global___Sample: ...
     @property
+    def offset(self) -> global___Offset: ...
+    @property
     def unknown(self) -> global___Unknown: ...
     def __init__(
         self,
@@ -108,11 +111,12 @@ class Relation(google.protobuf.message.Message):
         join: global___Join | None = ...,
         union: global___Union | None = ...,
         sort: global___Sort | None = ...,
-        fetch: global___Fetch | None = ...,
+        limit: global___Limit | None = ...,
        aggregate: global___Aggregate | None = ...,
         sql: global___SQL | None = ...,
         local_relation: global___LocalRelation | None = ...,
         sample: global___Sample | None = ...,
+        offset: global___Offset | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -122,14 +126,16 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
-            "fetch",
-            b"fetch",
             "filter",
             b"filter",
             "join",
             b"join",
+            "limit",
+            b"limit",
             "local_relation",
             b"local_relation",
+            "offset",
+            b"offset",
             "project",
             b"project",
             "read",
@@ -155,14 +161,16 @@ class Relation(google.protobuf.message.Message):
             b"aggregate",
             "common",
             b"common",
-            "fetch",
-            b"fetch",
             "filter",
             b"filter",
             "join",
             b"join",
+            "limit",
+            b"limit",
             "local_relation",
             b"local_relation",
+            "offset",
+            b"offset",
             "project",
             b"project",
             "read",
@@ -190,11 +198,12 @@ class Relation(google.protobuf.message.Message):
         "join",
         "union",
         "sort",
-        "fetch",
+        "limit",
         "aggregate",
         "sql",
         "local_relation",
         "sample",
+        "offset",
         "unknown",
     ] | None: ...
 
@@ -479,36 +488,57 @@ class Union(google.protobuf.message.Message):
 
 global___Union = Union
 
-class Fetch(google.protobuf.message.Message):
-    """Relation of type [[Fetch]] that is used to read `limit` / `offset` rows from the input relation."""
+class Limit(google.protobuf.message.Message):
+    """Relation of type [[Limit]] that is used to `limit` rows from the input relation."""
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     INPUT_FIELD_NUMBER: builtins.int
     LIMIT_FIELD_NUMBER: builtins.int
-    OFFSET_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation: ...
     limit: builtins.int
-    offset: builtins.int
     def __init__(
         self,
         *,
         input: global___Relation | None = ...,
         limit: builtins.int = ...,
-        offset: builtins.int = ...,
     ) -> None: ...
     def HasField(
         self, field_name: typing_extensions.Literal["input", b"input"]
     ) -> builtins.bool: ...
     def ClearField(
+        self, field_name: typing_extensions.Literal["input", b"input", "limit", b"limit"]
+    ) -> None: ...
+
+global___Limit = Limit
+
+class Offset(google.protobuf.message.Message):
+    """Relation of type [[Offset]] that is used to read rows starting from the `offset` on
+    the input relation.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    OFFSET_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation: ...
+    offset: builtins.int
+    def __init__(
         self,
-        field_name: typing_extensions.Literal[
-            "input", b"input", "limit", b"limit", "offset", b"offset"
-        ],
+        *,
+        input: global___Relation | None = ...,
+        offset: builtins.int = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["input", b"input", "offset", b"offset"]
     ) -> None: ...
 
-global___Fetch = Fetch
+global___Offset = Offset
 
 class Aggregate(google.protobuf.message.Message):
     """Relation of type [[Aggregate]]."""
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index ca2cc216ff2f..bf2ffdf1dc09 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -77,7 +77,7 @@ def test_column_expressions(self):
             ProtoExpression.UnresolvedAttribute,
         )
         self.assertEqual(
-            mod_fun.unresolved_function.arguments[0].unresolved_attribute.parts, ["id"]
+            mod_fun.unresolved_function.arguments[0].unresolved_attribute.unparsed_identifier, "id"
         )
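Below is a small illustrative sketch (not part of the patch) of how the two new relations compose through the DSL methods added in connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala. `rel`, `limited`, and `paged` are placeholder names for any already-built proto.Relation (for example the connectTestRelation helper in SparkConnectProtoSuite), and the comments only restate what transformLimit/transformOffset above do:

    // Assumes `proto` is org.apache.spark.connect.proto and `rel: proto.Relation` exists.
    import org.apache.spark.sql.connect.dsl.plans._

    // Builds Relation { limit { input = rel, limit = 10 } }.
    val limited: proto.Relation = rel.limit(10)

    // Wraps the Limit relation in an Offset relation:
    // Relation { offset { input = limited, offset = 2 } }.
    // SparkConnectPlanner turns this into
    // logical.Offset(Literal(2), logical.Limit(Literal(10), <plan for rel>)),
    // so each relation now carries exactly one of the two fields the old Fetch message combined.
    val paged: proto.Relation = limited.offset(2)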