Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,9 @@ message Sample {
double lower_bound = 2;
double upper_bound = 3;
bool with_replacement = 4;
int64 seed = 5;
Seed seed = 5;

// Wrapper message for the sampling seed. Because the seed is a message field rather than a
// bare int64, a receiver can check whether it was set at all instead of reading the default 0.
message Seed {
int64 seed = 1;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ package object dsl {
.setUpperBound(upperBound)
.setLowerBound(lowerBound)
.setWithReplacement(withReplacement)
.setSeed(seed))
.setSeed(proto.Sample.Seed.newBuilder().setSeed(seed).build())
.build())
.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, LogicalPlan, Sa
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

final case class InvalidPlanInput(
private val message: String = "",
Expand Down Expand Up @@ -80,7 +81,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {

/**
* All fields of [[proto.Sample]] are optional. However, given those are proto primitive types,
* we cannot differentiate if the fied is not or set when the field's value equals to the type
* we cannot differentiate whether a field is set or not when its value equals the type
* default value. If this ever becomes a problem, one solution could be to
* wrap such fields into proto messages.
*/
Expand All @@ -89,7 +90,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
rel.getLowerBound,
rel.getUpperBound,
rel.getWithReplacement,
rel.getSeed,
if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong,
transformRelation(rel.getInput))
}

Expand Down
27 changes: 27 additions & 0 deletions python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,33 @@ def sort(self, *cols: "ColumnOrString") -> "DataFrame":
"""Sort by a specific column"""
return DataFrame.withPlan(plan.Sort(self._plan, *cols), session=self._session)

def sample(

@amaliujia amaliujia Oct 19, 2022

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pyspark dataframe API has

    @overload
    def sample(self, fraction: float, seed: Optional[int] = ...) -> "DataFrame":
        ...

    @overload
    def sample(
        self,
        withReplacement: Optional[bool],
        fraction: float,
        seed: Optional[int] = ...,
    ) -> "DataFrame":
        ...

Can we match (as easy as copy the API into connect dataframe.py)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can discard those ones ? @HyukjinKwon

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe my real question was: will we have trouble staying compatible with existing pyspark dataframe code (which needs different imports, of course) if we discard such APIs? I see many other similar APIs in the pyspark dataframe.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users may have to change their code for this migration, but I think this is also a chance to make some changes.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. We also can go to that direction.

self,
fraction: float,
*,
withReplacement: bool = False,
seed: Optional[int] = None,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should just leverage keyword-only arguments, which will make the logic much simpler. Actually, we wanted to do this in the PySpark API layer in the past. Since this is a new API layer, I think it's a good chance to replace them. cc @ueshin

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, that's a bit confusing at first glance.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if we can break the signature, it would be:

def sample(
    self,
    fraction: float,
    *,
    withReplacement: Optional[bool] = None,
    seed: Optional[int] = None,
) -> "DataFrame":
    ...

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

withReplacement can be : bool = False if the default is False.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this idea

) -> "DataFrame":
if not isinstance(fraction, float):
raise TypeError(f"'fraction' must be float, but got {type(fraction).__name__}")
if not isinstance(withReplacement, bool):
raise TypeError(
f"'withReplacement' must be bool, but got {type(withReplacement).__name__}"
)
if seed is not None and not isinstance(seed, int):
raise TypeError(f"'seed' must be None or int, but got {type(seed).__name__}")

return DataFrame.withPlan(
plan.Sample(
child=self._plan,
lower_bound=0.0,
upper_bound=fraction,
with_replacement=withReplacement,
seed=seed,
),
session=self._session,
)

# Placeholder: the body is only `...`, so this method currently does nothing and returns None.
def show(self, n: int, truncate: Optional[Union[bool, int]], vertical: Optional[bool]) -> None:
...

Expand Down
50 changes: 50 additions & 0 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,56 @@ def _repr_html_(self) -> str:
"""


class Sample(LogicalPlan):
    """Logical plan node describing a sample taken from its child plan.

    Holds a lower/upper bound pair, a with-replacement flag and an optional seed,
    and copies them into the ``sample`` relation of the proto message.
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        lower_bound: float,
        upper_bound: float,
        with_replacement: bool,
        seed: Optional[int],
    ) -> None:
        """Record the sampling parameters; ``seed`` may be None to leave it unset."""
        super().__init__(child)
        self.seed = seed
        self.with_replacement = with_replacement
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation:
        """Build the ``proto.Relation`` for this node.

        The child plan must be present. The ``seed`` sub-message is filled in only
        when a seed was supplied; otherwise it is left unset on the relation.
        """
        assert self._child is not None
        relation = proto.Relation()
        sample = relation.sample
        sample.input.CopyFrom(self._child.plan(session))
        sample.lower_bound = self.lower_bound
        sample.upper_bound = self.upper_bound
        sample.with_replacement = self.with_replacement
        if self.seed is not None:
            sample.seed.seed = self.seed
        return relation

    def print(self, indent: int = 0) -> str:
        """Return an indented one-line description of this node followed by its child's."""
        if self._child:
            child_text = self._child.print(indent + LogicalPlan.INDENT)
        else:
            child_text = ""
        header = (
            f"<Sample lowerBound={self.lower_bound}, upperBound={self.upper_bound}, "
            f"withReplacement={self.with_replacement}, seed={self.seed}>"
        )
        return " " * indent + header + "\n" + child_text

    def _repr_html_(self) -> str:
        """Return an HTML list item showing the sampling parameters and the child's repr."""
        # The template is kept flush-left: whitespace inside the literal is part of the output.
        return f"""
<ul>
<li>
<b>Sample</b><br />
LowerBound: {self.lower_bound} <br />
UpperBound: {self.upper_bound} <br />
WithReplacement: {self.with_replacement} <br />
Seed: {self.seed} <br />
{self._child_repr_()}
</li>
</uL>
"""


class Aggregate(LogicalPlan):
MeasureType = Tuple["ExpressionOrString", str]
MeasuresType = Sequence[MeasureType]
Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/sql/connect/proto/relations_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 
\x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"\x8e\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12-\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08R\x10\x61llColumnsAsKeys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xb8\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12\x12\n\x04seed\x18\x05 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8f\x06\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 
\x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"\x8e\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12-\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08R\x10\x61llColumnsAsKeys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xf0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12.\n\x04seed\x18\x05 \x01(\x0b\x32\x1a.spark.connect.Sample.SeedR\x04seed\x1a\x1a\n\x04Seed\x12\x12\n\x04seed\x18\x01 
\x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3'
)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
Expand Down Expand Up @@ -92,5 +92,7 @@
_LOCALRELATION._serialized_start = 3387
_LOCALRELATION._serialized_end = 3480
_SAMPLE._serialized_start = 3483
_SAMPLE._serialized_end = 3667
_SAMPLE._serialized_end = 3723
_SAMPLE_SEED._serialized_start = 3697
_SAMPLE_SEED._serialized_end = 3723
# @@protoc_insertion_point(module_scope)
19 changes: 16 additions & 3 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,18 @@ class Sample(google.protobuf.message.Message):

DESCRIPTOR: google.protobuf.descriptor.Descriptor

class Seed(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

SEED_FIELD_NUMBER: builtins.int
seed: builtins.int
def __init__(
self,
*,
seed: builtins.int = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["seed", b"seed"]) -> None: ...

INPUT_FIELD_NUMBER: builtins.int
LOWER_BOUND_FIELD_NUMBER: builtins.int
UPPER_BOUND_FIELD_NUMBER: builtins.int
Expand All @@ -849,18 +861,19 @@ class Sample(google.protobuf.message.Message):
lower_bound: builtins.float
upper_bound: builtins.float
with_replacement: builtins.bool
seed: builtins.int
@property
def seed(self) -> global___Sample.Seed: ...
def __init__(
self,
*,
input: global___Relation | None = ...,
lower_bound: builtins.float = ...,
upper_bound: builtins.float = ...,
with_replacement: builtins.bool = ...,
seed: builtins.int = ...,
seed: global___Sample.Seed | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["input", b"input"]
self, field_name: typing_extensions.Literal["input", b"input", "seed", b"seed"]
) -> builtins.bool: ...
def ClearField(
self,
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_plan_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ def test_offset(self):
offset_plan = df.offset(10)._plan.to_proto(self.connect)
self.assertEqual(offset_plan.root.offset.offset, 10)

def test_sample(self):
    """Check the Sample relation produced by ``sample`` with and without an explicit seed."""
    df = self.connect.readTable(table_name=self.tbl_name)

    # Only the fraction is given: no replacement, and the seed sub-message stays unset.
    plan = df.filter(df.col_name > 3).sample(fraction=0.3)._plan.to_proto(self.connect)
    sample = plan.root.sample
    self.assertEqual(sample.lower_bound, 0.0)
    self.assertEqual(sample.upper_bound, 0.3)
    self.assertEqual(sample.with_replacement, False)
    self.assertEqual(sample.HasField("seed"), False)

    # All keyword arguments given: the explicit seed must appear in the proto.
    sampled_df = df.filter(df.col_name > 3).sample(withReplacement=True, fraction=0.4, seed=-1)
    plan = sampled_df._plan.to_proto(self.connect)
    sample = plan.root.sample
    self.assertEqual(sample.lower_bound, 0.0)
    self.assertEqual(sample.upper_bound, 0.4)
    self.assertEqual(sample.with_replacement, True)
    self.assertEqual(sample.seed.seed, -1)

def test_relation_alias(self):
df = self.connect.readTable(table_name=self.tbl_name)
plan = df.alias("table_alias")._plan.to_proto(self.connect)
Expand Down