diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 30f36fa6ceb5..e93bb02894a0 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -41,6 +41,7 @@ message Relation { Aggregate aggregate = 9; SQL sql = 10; LocalRelation local_relation = 11; + Sample sample = 12; Unknown unknown = 999; } @@ -167,3 +168,12 @@ message LocalRelation { repeated Expression.QualifiedAttribute attributes = 1; // TODO: support local data. } + +// Relation of type [[Sample]] that samples a fraction of the dataset. +message Sample { + Relation input = 1; + double lower_bound = 2; + double upper_bound = 3; + bool with_replacement = 4; + int64 seed = 5; +} diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala index 96d82366fd5d..579f190156ff 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala @@ -46,7 +46,7 @@ package object dsl { .build() def struct( - attrs: proto.Expression.QualifiedAttribute*): proto.Expression.QualifiedAttribute = { + attrs: proto.Expression.QualifiedAttribute*): proto.Expression.QualifiedAttribute = { val structExpr = proto.DataType.Struct.newBuilder() for (attr <- attrs) { val structField = proto.DataType.StructField.newBuilder() @@ -54,7 +54,8 @@ package object dsl { structField.setType(attr.getType) structExpr.addFields(structField) } - proto.Expression.QualifiedAttribute.newBuilder() + proto.Expression.QualifiedAttribute + .newBuilder() .setName(s) .setType(proto.DataType.newBuilder().setStruct(structExpr)) .build() @@ -65,8 +66,9 @@ package object dsl { 
proto.DataType.newBuilder().setI32(proto.DataType.I32.newBuilder()).build()) private def protoQualifiedAttrWithType( - dataType: proto.DataType): proto.Expression.QualifiedAttribute = - proto.Expression.QualifiedAttribute.newBuilder() + dataType: proto.DataType): proto.Expression.QualifiedAttribute = + proto.Expression.QualifiedAttribute + .newBuilder() .setName(s) .setType(dataType) .build() @@ -180,6 +182,24 @@ package object dsl { .build() } + def sample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long): proto.Relation = { + proto.Relation + .newBuilder() + .setSample( + proto.Sample + .newBuilder() + .setInput(logicalPlan) + .setUpperBound(upperBound) + .setLowerBound(lowerBound) + .setWithReplacement(withReplacement) + .setSeed(seed)) + .build() + } + def groupBy(groupingExprs: proto.Expression*)( aggregateExprs: proto.Expression*): proto.Relation = { val agg = proto.Aggregate.newBuilder() diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 46072ec089e0..61352c17a231 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias} import org.apache.spark.sql.types._ final case class 
InvalidPlanInput( @@ -62,6 +62,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql) case proto.Relation.RelTypeCase.LOCAL_RELATION => transformLocalRelation(rel.getLocalRelation) + case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample) case proto.Relation.RelTypeCase.RELTYPE_NOT_SET => throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.") case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.") @@ -72,6 +73,21 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) { session.sessionState.sqlParser.parsePlan(sql.getQuery) } + /** + * All fields of [[proto.Sample]] are optional. However, given those are proto primitive types, + * we cannot tell whether a field is unset or explicitly set when its value equals the type + * default value. In the future if this ever becomes a problem, one solution could be to + * wrap such fields into proto messages. 
+ */ + private def transformSample(rel: proto.Sample): LogicalPlan = { + Sample( + rel.getLowerBound, + rel.getUpperBound, + rel.getWithReplacement, + rel.getSeed, + transformRelation(rel.getInput)) + } + private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = { val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index d3f286d848a6..b13e74c2125b 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -91,6 +91,15 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { } } + test("Test sample") { + val connectPlan = { + import org.apache.spark.sql.connect.dsl.plans._ + transform(connectTestRelation.sample(0, 0.2, false, 1)) + } + val sparkPlan = sparkTestRelation.sample(0, 0.2, false, 1) + comparePlans(connectPlan.analyze, sparkPlan.analyze, false) + } + test("column alias") { val connectPlan = { import org.apache.spark.sql.connect.dsl.expressions._ @@ -98,6 +107,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { transform(connectTestRelation.select("id".protoAttr.as("id2"))) } val sparkPlan = sparkTestRelation.select($"id".as("id2")) + comparePlans(connectPlan.analyze, sparkPlan.analyze, false) } test("Aggregate with more than 1 grouping expressions") { diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 1280236a1501..99116f1d59af 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 
+32,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xed\x04\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05\x66\x65tch\x18\x08 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"z\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a"\n\nNamedTable\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05partsB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"d\n\x05\x46\x65tch\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit\x12\x16\n\x06offset\x18\x03 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributesB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x9e\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05\x66\x65tch\x18\x08 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 
\x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"z\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a"\n\nNamedTable\x12\x14\n\x05parts\x18\x01 \x03(\tR\x05partsB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"d\n\x05\x46\x65tch\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit\x12\x16\n\x06offset\x18\x03 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xb8\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12\x12\n\x04seed\x18\x05 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -42,43 +42,45 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 703 - _UNKNOWN._serialized_start = 705 - _UNKNOWN._serialized_end = 714 - _RELATIONCOMMON._serialized_start = 716 - _RELATIONCOMMON._serialized_end = 
787 - _SQL._serialized_start = 789 - _SQL._serialized_end = 816 - _READ._serialized_start = 818 - _READ._serialized_end = 940 - _READ_NAMEDTABLE._serialized_start = 893 - _READ_NAMEDTABLE._serialized_end = 927 - _PROJECT._serialized_start = 942 - _PROJECT._serialized_end = 1059 - _FILTER._serialized_start = 1061 - _FILTER._serialized_end = 1173 - _JOIN._serialized_start = 1176 - _JOIN._serialized_end = 1589 - _JOIN_JOINTYPE._serialized_start = 1402 - _JOIN_JOINTYPE._serialized_end = 1589 - _UNION._serialized_start = 1592 - _UNION._serialized_end = 1797 - _UNION_UNIONTYPE._serialized_start = 1713 - _UNION_UNIONTYPE._serialized_end = 1797 - _FETCH._serialized_start = 1799 - _FETCH._serialized_end = 1899 - _AGGREGATE._serialized_start = 1902 - _AGGREGATE._serialized_end = 2227 - _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2131 - _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2227 - _SORT._serialized_start = 2230 - _SORT._serialized_end = 2732 - _SORT_SORTFIELD._serialized_start = 2350 - _SORT_SORTFIELD._serialized_end = 2538 - _SORT_SORTDIRECTION._serialized_start = 2540 - _SORT_SORTDIRECTION._serialized_end = 2648 - _SORT_SORTNULLS._serialized_start = 2650 - _SORT_SORTNULLS._serialized_end = 2732 - _LOCALRELATION._serialized_start = 2734 - _LOCALRELATION._serialized_end = 2827 + _RELATION._serialized_end = 752 + _UNKNOWN._serialized_start = 754 + _UNKNOWN._serialized_end = 763 + _RELATIONCOMMON._serialized_start = 765 + _RELATIONCOMMON._serialized_end = 836 + _SQL._serialized_start = 838 + _SQL._serialized_end = 865 + _READ._serialized_start = 867 + _READ._serialized_end = 989 + _READ_NAMEDTABLE._serialized_start = 942 + _READ_NAMEDTABLE._serialized_end = 976 + _PROJECT._serialized_start = 991 + _PROJECT._serialized_end = 1108 + _FILTER._serialized_start = 1110 + _FILTER._serialized_end = 1222 + _JOIN._serialized_start = 1225 + _JOIN._serialized_end = 1638 + _JOIN_JOINTYPE._serialized_start = 1451 + _JOIN_JOINTYPE._serialized_end = 1638 + 
_UNION._serialized_start = 1641 + _UNION._serialized_end = 1846 + _UNION_UNIONTYPE._serialized_start = 1762 + _UNION_UNIONTYPE._serialized_end = 1846 + _FETCH._serialized_start = 1848 + _FETCH._serialized_end = 1948 + _AGGREGATE._serialized_start = 1951 + _AGGREGATE._serialized_end = 2276 + _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2180 + _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2276 + _SORT._serialized_start = 2279 + _SORT._serialized_end = 2781 + _SORT_SORTFIELD._serialized_start = 2399 + _SORT_SORTFIELD._serialized_end = 2587 + _SORT_SORTDIRECTION._serialized_start = 2589 + _SORT_SORTDIRECTION._serialized_end = 2697 + _SORT_SORTNULLS._serialized_start = 2699 + _SORT_SORTNULLS._serialized_end = 2781 + _LOCALRELATION._serialized_start = 2783 + _LOCALRELATION._serialized_end = 2876 + _SAMPLE._serialized_start = 2879 + _SAMPLE._serialized_end = 3063 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 952f476c9127..402c08083c7a 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -70,6 +70,7 @@ class Relation(google.protobuf.message.Message): AGGREGATE_FIELD_NUMBER: builtins.int SQL_FIELD_NUMBER: builtins.int LOCAL_RELATION_FIELD_NUMBER: builtins.int + SAMPLE_FIELD_NUMBER: builtins.int UNKNOWN_FIELD_NUMBER: builtins.int @property def common(self) -> global___RelationCommon: ... @@ -94,6 +95,8 @@ class Relation(google.protobuf.message.Message): @property def local_relation(self) -> global___LocalRelation: ... @property + def sample(self) -> global___Sample: ... + @property def unknown(self) -> global___Unknown: ... 
def __init__( self, @@ -109,6 +112,7 @@ class Relation(google.protobuf.message.Message): aggregate: global___Aggregate | None = ..., sql: global___SQL | None = ..., local_relation: global___LocalRelation | None = ..., + sample: global___Sample | None = ..., unknown: global___Unknown | None = ..., ) -> None: ... def HasField( @@ -132,6 +136,8 @@ class Relation(google.protobuf.message.Message): b"read", "rel_type", b"rel_type", + "sample", + b"sample", "sort", b"sort", "sql", @@ -163,6 +169,8 @@ class Relation(google.protobuf.message.Message): b"read", "rel_type", b"rel_type", + "sample", + b"sample", "sort", b"sort", "sql", @@ -186,6 +194,7 @@ class Relation(google.protobuf.message.Message): "aggregate", "sql", "local_relation", + "sample", "unknown", ] | None: ... @@ -694,3 +703,49 @@ class LocalRelation(google.protobuf.message.Message): ) -> None: ... global___LocalRelation = LocalRelation + +class Sample(google.protobuf.message.Message): + """Relation of type [[Sample]] that samples a fraction of the dataset.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + LOWER_BOUND_FIELD_NUMBER: builtins.int + UPPER_BOUND_FIELD_NUMBER: builtins.int + WITH_REPLACEMENT_FIELD_NUMBER: builtins.int + SEED_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: ... + lower_bound: builtins.float + upper_bound: builtins.float + with_replacement: builtins.bool + seed: builtins.int + def __init__( + self, + *, + input: global___Relation | None = ..., + lower_bound: builtins.float = ..., + upper_bound: builtins.float = ..., + with_replacement: builtins.bool = ..., + seed: builtins.int = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... 
+ def ClearField( + self, + field_name: typing_extensions.Literal[ + "input", + b"input", + "lower_bound", + b"lower_bound", + "seed", + b"seed", + "upper_bound", + b"upper_bound", + "with_replacement", + b"with_replacement", + ], + ) -> None: ... + +global___Sample = Sample