@@ -261,6 +261,9 @@ message Expression {
// If set, it should end with '.*' and will be parsed by 'parseAttributeName'
// in the server side.
optional string unparsed_target = 1;

// (Optional) The id of corresponding connect plan.
optional int64 plan_id = 2;
}

// Represents all of the input attributes to a given relational operator, for example in
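
A minimal sketch of how a client could set the new field through the regenerated Python bindings; in practice df["*"] goes through UnresolvedStar.to_plan (shown later in this diff) rather than building the message by hand, and the plan id value here is purely hypothetical.

import pyspark.sql.connect.proto as proto

expr = proto.Expression()
expr.unresolved_star.SetInParent()  # mark the star expression as present even if no field is set
expr.unresolved_star.plan_id = 3    # hypothetical id of the plan node the DataFrame refers to
assert expr.unresolved_star.HasField("plan_id")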
@@ -45,7 +45,7 @@ import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -2109,19 +2109,28 @@ class SparkConnectPlanner(
parser.parseExpression(expr.getExpression)
}

private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): UnresolvedStar = {
if (star.hasUnparsedTarget) {
val target = star.getUnparsedTarget
if (!target.endsWith(".*")) {
throw InvalidPlanInput(
s"UnresolvedStar requires a unparsed target ending with '.*', " +
s"but got $target.")
}
private def transformUnresolvedStar(star: proto.Expression.UnresolvedStar): Expression = {
(star.hasUnparsedTarget, star.hasPlanId) match {
case (false, false) =>
// functions.col("*")
UnresolvedStar(None)

UnresolvedStar(
Some(UnresolvedAttribute.parseAttributeName(target.substring(0, target.length - 2))))
} else {
UnresolvedStar(None)
case (true, false) =>
// functions.col("s.*")
val target = star.getUnparsedTarget
if (!target.endsWith(".*")) {
throw InvalidPlanInput(
s"UnresolvedStar requires a unparsed target ending with '.*', but got $target.")
}
val parts = UnresolvedAttribute.parseAttributeName(target.dropRight(2))
UnresolvedStar(Some(parts))

case (false, true) =>
// dataframe.col("*")
UnresolvedDataFrameStar(star.getPlanId)

case _ =>
throw InvalidPlanInput("UnresolvedStar with both target and plan id is not supported.")
}
}
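
For reference, a rough sketch of which Python client expressions produce each of the three cases handled above (assuming a reachable Spark Connect server; the actual client wiring is in the dataframe.py and functions changes further down):

from pyspark.sql import SparkSession
from pyspark.sql.connect import functions as CF

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df = spark.createDataFrame([{"a": 1}])

CF.col("*")    # neither target nor plan id   -> UnresolvedStar(None)
CF.col("s.*")  # unparsed_target = "s.*"      -> UnresolvedStar(Some(parts))
df["*"]        # plan_id of df's plan only    -> UnresolvedDataFrameStar(planId)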

48 changes: 35 additions & 13 deletions python/pyspark/sql/connect/dataframe.py
@@ -71,9 +71,12 @@
from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedRegex
from pyspark.sql.connect.expressions import (
ColumnReference,
UnresolvedRegex,
UnresolvedStar,
)
from pyspark.sql.connect.functions.builtin import (
_to_col_with_plan_id,
_to_col,
_invoke_function,
col,
@@ -1702,9 +1705,11 @@ def __getattr__(self, name: str) -> "Column":
error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": name}
)

return _to_col_with_plan_id(
col=name,
plan_id=self._plan._plan_id,
return Column(
ColumnReference(
unparsed_identifier=name,
plan_id=self._plan._plan_id,
)
)

__getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__
Expand All @@ -1719,14 +1724,31 @@ def __getitem__(self, item: Union[Column, List, Tuple]) -> "DataFrame":

def __getitem__(self, item: Union[int, str, Column, List, Tuple]) -> Union[Column, "DataFrame"]:
if isinstance(item, str):
# validate the column name
if not hasattr(self._session, "is_mock_session"):
self.select(item).isLocal()

return _to_col_with_plan_id(
col=item,
plan_id=self._plan._plan_id,
)
if item == "*":
return Column(
UnresolvedStar(
unparsed_target=None,
plan_id=self._plan._plan_id,
)
)
else:
# TODO: revisit vanilla Spark's Dataset.col
Review thread on this line:
Contributor Author: TODO for myself, should revisit the implementation of colRegex
Contributor: We can probably skip it in spark connect. It's really a weird feature and non-standard.
Contributor: It's off by default anyway, so we can throw a proper error if it's enabled in spark connect.

# if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
# colRegex(colName)
# } else {
# Column(addDataFrameIdToCol(resolve(colName)))
# }

# validate the column name
if not hasattr(self._session, "is_mock_session"):
self.select(item).isLocal()

return Column(
ColumnReference(
unparsed_identifier=item,
plan_id=self._plan._plan_id,
)
)
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, (list, tuple)):
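A sketch of the user-facing behaviour this __getitem__ change enables, mirroring the new test_dataframe_star test further down (assumes a reachable Spark Connect server):

from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df1 = spark.createDataFrame([{"a": 1}])
df2 = spark.createDataFrame([{"a": 1, "b": "v"}])
joined = df1.join(df2)

joined.columns                    # ['a', 'a', 'b']
joined.select(df1["*"]).columns   # ['a']       -- only df1's side of the join
joined.select(df2["*"]).columns   # ['a', 'b']  -- only df2's side of the join
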
8 changes: 6 additions & 2 deletions python/pyspark/sql/connect/expressions.py
@@ -494,19 +494,23 @@ def __eq__(self, other: Any) -> bool:


class UnresolvedStar(Expression):
def __init__(self, unparsed_target: Optional[str]):
def __init__(self, unparsed_target: Optional[str], plan_id: Optional[int] = None):
super().__init__()

if unparsed_target is not None:
assert isinstance(unparsed_target, str) and unparsed_target.endswith(".*")

self._unparsed_target = unparsed_target

assert plan_id is None or isinstance(plan_id, int)
self._plan_id = plan_id

def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.unresolved_star.SetInParent()
if self._unparsed_target is not None:
expr.unresolved_star.unparsed_target = self._unparsed_target
if self._plan_id is not None:
expr.unresolved_star.plan_id = self._plan_id
return expr

def __repr__(self) -> str:
16 changes: 6 additions & 10 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -76,15 +76,6 @@
from pyspark.sql.connect.udtf import UserDefinedTableFunction


def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
Review thread on this line:
zhengruifeng (Contributor Author), Jan 11, 2024: delete this helper function due to the behavior difference between Dataset#col and functions#col

def col(colName: String): Column = colName match {
case "*" =>
Column(ResolvedStar(queryExecution.analyzed.output))
case _ =>
if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
colRegex(colName)
} else {
Column(addDataFrameIdToCol(resolve(colName)))
}
}

def this(name: String) = this(withOrigin {
name match {
case "*" => UnresolvedStar(None)
case _ if name.endsWith(".*") =>
val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2))
UnresolvedStar(Some(parts))
case _ => UnresolvedAttribute.quotedString(name)
}
})

if col == "*":
return Column(UnresolvedStar(unparsed_target=None))
elif col.endswith(".*"):
return Column(UnresolvedStar(unparsed_target=col))
else:
return Column(ColumnReference(unparsed_identifier=col, plan_id=plan_id))


def _to_col(col: "ColumnOrName") -> Column:
assert isinstance(col, (Column, str))
return col if isinstance(col, Column) else column(col)
Expand Down Expand Up @@ -224,7 +215,12 @@ def _options_to_col(options: Dict[str, Any]) -> Column:


def col(col: str) -> Column:
return _to_col_with_plan_id(col=col, plan_id=None)
if col == "*":
return Column(UnresolvedStar(unparsed_target=None))
elif col.endswith(".*"):
return Column(UnresolvedStar(unparsed_target=col))
else:
return Column(ColumnReference(unparsed_identifier=col))


col.__doc__ = pysparkfuncs.col.__doc__
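
Note that the Connect-side functions.col now mirrors the classic Column(name) constructor quoted in the review comment above and never attaches a plan id; only DataFrame.__getattr__/__getitem__ do. A short self-contained sketch of the resulting difference (assumes a reachable Spark Connect server):

from pyspark.sql import SparkSession
from pyspark.sql.connect import functions as CF

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df1 = spark.createDataFrame([{"a": 1}])
df2 = spark.createDataFrame([{"a": 2, "b": 0}])
joined = df1.join(df2)

joined.select(CF.col("*")).columns   # ['a', 'a', 'b'] -- expands the whole input of the select
joined.select(df2["*"]).columns      # ['a', 'b']      -- pinned to df2's plan via plan_id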
54 changes: 27 additions & 27 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

27 changes: 25 additions & 2 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -880,29 +880,52 @@ class Expression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

UNPARSED_TARGET_FIELD_NUMBER: builtins.int
PLAN_ID_FIELD_NUMBER: builtins.int
unparsed_target: builtins.str
"""(Optional) The target of the expansion.

If set, it should end with '.*' and will be parsed by 'parseAttributeName'
in the server side.
"""
plan_id: builtins.int
"""(Optional) The id of corresponding connect plan."""
def __init__(
self,
*,
unparsed_target: builtins.str | None = ...,
plan_id: builtins.int | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target"
"_plan_id",
b"_plan_id",
"_unparsed_target",
b"_unparsed_target",
"plan_id",
b"plan_id",
"unparsed_target",
b"unparsed_target",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"_unparsed_target", b"_unparsed_target", "unparsed_target", b"unparsed_target"
"_plan_id",
b"_plan_id",
"_unparsed_target",
b"_unparsed_target",
"plan_id",
b"plan_id",
"unparsed_target",
b"unparsed_target",
],
) -> None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
) -> typing_extensions.Literal["plan_id"] | None: ...
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_unparsed_target", b"_unparsed_target"]
) -> typing_extensions.Literal["unparsed_target"] | None: ...
38 changes: 38 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -558,6 +558,44 @@ def test_invalid_column(self):
):
cdf1.select(cdf2.a).schema

def test_invalid_star(self):
Review thread on this line:
Contributor Author: CI run this test only in Connect
Contributor: what's the difference between Connect and Classic for this test?
Contributor Author, replying with the Classic (non-Connect) behaviour:

In [4]: cdf1 = spark.createDataFrame([Row(a=1, b=2, c=3)])

In [5]: cdf2 = spark.createDataFrame([Row(a=2, b=0)])

In [6]: cdf3 = cdf1.select(cdf1.a)

In [7]: cdf3.select(cdf1["*"]).schema
...
AnalysisException: [MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT] Resolved attribute(s) "b", "c" missing from "a" in operator !Project [a#0L, b#1L, c#2L].  SQLSTATE: XX000;
!Project [a#0L, b#1L, c#2L]
+- Project [a#0L]
   +- LogicalRDD [a#0L, b#1L, c#2L], false


In [8]: cdf1.select(cdf2["*"]).schema
...
AnalysisException: [MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_APPEAR_IN_OPERATION] Resolved attribute(s) "a", "b" missing from "a", "b", "c" in operator !Project [a#6L, b#7L]. Attribute(s) with the same name appear in the operation: "a", "b".
Please check if the right attribute(s) are used. SQLSTATE: XX000;
!Project [a#6L, b#7L]
+- LogicalRDD [a#0L, b#1L, c#2L], false


In [9]: cdf1.join(cdf1).select(cdf1["*"]).schema
Out[9]: StructType([StructField('a', LongType(), True), StructField('b', LongType(), True), StructField('c', LongType(), True)])

In Classic, cdf1.join(cdf1).select(cdf1["*"]) won't fail with AMBIGUOUS_COLUMN_REFERENCE (Out[9] above), whereas the Connect test below expects that error.

data1 = [Row(a=1, b=2, c=3)]
cdf1 = self.connect.createDataFrame(data1)

data2 = [Row(a=2, b=0)]
cdf2 = self.connect.createDataFrame(data2)

# Can find the target plan node, but fail to resolve with it
with self.assertRaisesRegex(
AnalysisException,
"CANNOT_RESOLVE_DATAFRAME_COLUMN",
):
cdf3 = cdf1.select(cdf1.a)
cdf3.select(cdf1["*"]).schema

# Can find the target plan node, but fail to resolve with it
with self.assertRaisesRegex(
AnalysisException,
"CANNOT_RESOLVE_DATAFRAME_COLUMN",
):
# column 'a has been replaced
cdf3 = cdf1.withColumn("a", CF.lit(0))
cdf3.select(cdf1["*"]).schema

# Can not find the target plan node by plan id
with self.assertRaisesRegex(
AnalysisException,
"CANNOT_RESOLVE_DATAFRAME_COLUMN",
):
cdf1.select(cdf2["*"]).schema

# cdf1["*"] exists on both side
with self.assertRaisesRegex(
AnalysisException,
"AMBIGUOUS_COLUMN_REFERENCE",
):
cdf1.join(cdf1).select(cdf1["*"]).schema

def test_collect(self):
cdf = self.connect.read.table(self.tbl_name)
sdf = self.spark.read.table(self.tbl_name)
35 changes: 35 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -69,6 +69,41 @@ def test_range(self):
self.assertEqual(self.spark.range(-2).count(), 0)
self.assertEqual(self.spark.range(3).count(), 3)

def test_dataframe_star(self):
Review thread on this line:
Contributor Author: CI run this test in both connect and vanilla

df1 = self.spark.createDataFrame([{"a": 1}])
df2 = self.spark.createDataFrame([{"a": 1, "b": "v"}])
df3 = df2.withColumnsRenamed({"a": "x", "b": "y"})

df = df1.join(df2)
self.assertEqual(df.columns, ["a", "a", "b"])
self.assertEqual(df.select(df1["*"]).columns, ["a"])
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])

df = df1.join(df2).withColumn("c", lit(0))
self.assertEqual(df.columns, ["a", "a", "b", "c"])
self.assertEqual(df.select(df1["*"]).columns, ["a"])
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])

df = df1.join(df2, "a")
self.assertEqual(df.columns, ["a", "b"])
self.assertEqual(df.select(df1["*"]).columns, ["a"])
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])

df = df1.join(df2, "a").withColumn("c", lit(0))
self.assertEqual(df.columns, ["a", "b", "c"])
self.assertEqual(df.select(df1["*"]).columns, ["a"])
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])

df = df2.join(df3)
self.assertEqual(df.columns, ["a", "b", "x", "y"])
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])

df = df2.join(df3).withColumn("c", lit(0))
self.assertEqual(df.columns, ["a", "b", "x", "y", "c"])
self.assertEqual(df.select(df2["*"]).columns, ["a", "b"])
self.assertEqual(df.select(df3["*"]).columns, ["x", "y"])

def test_self_join(self):
df1 = self.spark.range(10).withColumn("a", lit(0))
df2 = df1.withColumnRenamed("a", "b")
@@ -490,7 +490,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
q: Seq[LogicalPlan]): Expression = e match {
case u: UnresolvedAttribute =>
resolveDataFrameColumn(u, q).getOrElse(u)
case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
case u: UnresolvedDataFrameStar =>
resolveDataFrameStar(u, q)
case _ if e.containsAnyPattern(UNRESOLVED_ATTRIBUTE, UNRESOLVED_DF_STAR) =>
e.mapChildren(c => tryResolveDataFrameColumns(c, q))
case _ => e
}
@@ -510,7 +512,7 @@
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2.a) <- illegal reference df2.a
throw QueryCompilationErrors.cannotResolveColumn(u)
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
}
resolved
}
@@ -588,4 +590,45 @@
}
(filtered, matched)
}

private def resolveDataFrameStar(
u: UnresolvedDataFrameStar,
q: Seq[LogicalPlan]): ResolvedStar = {
resolveDataFrameStarByPlanId(u, u.planId, q).getOrElse(
// Can not find the target plan node with plan id, e.g.
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2["*"]) <- illegal reference df2.a
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
)
}

private def resolveDataFrameStarByPlanId(
u: UnresolvedDataFrameStar,
id: Long,
q: Seq[LogicalPlan]): Option[ResolvedStar] = {
q.iterator.map(resolveDataFrameStarRecursively(u, id, _))
.foldLeft(Option.empty[ResolvedStar]) {
case (r1, r2) =>
if (r1.nonEmpty && r2.nonEmpty) {
throw QueryCompilationErrors.ambiguousColumnReferences(u)
}
if (r1.nonEmpty) r1 else r2
}
}

private def resolveDataFrameStarRecursively(
u: UnresolvedDataFrameStar,
id: Long,
p: LogicalPlan): Option[ResolvedStar] = {
val resolved = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
Some(ResolvedStar(p.output))
} else {
resolveDataFrameStarByPlanId(u, id, p.children)
}
resolved.filter { r =>
val outputSet = AttributeSet(p.output ++ p.metadataOutput)
r.expressions.forall(_.references.subsetOf(outputSet))
}
}
}
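
A sketch of why the final subsetOf(outputSet) filter in resolveDataFrameStarRecursively matters: the tagged plan node can still be found in the DataFrame's lineage while the attributes it would expand to are no longer part of the visible output. This is the withColumn case in the new test_invalid_star test (assumes a reachable Spark Connect server):

from pyspark.sql import SparkSession
from pyspark.sql.connect import functions as CF

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
df1 = spark.createDataFrame([{"a": 1, "b": 2}])
df3 = df1.withColumn("a", CF.lit(0))  # projects a new 'a' on top of df1's plan

# df1's plan node is still a descendant of df3's plan, so the plan id is found and a
# ResolvedStar over df1's original output is built; but df1's original attribute 'a'
# is no longer in df3's output, so the candidate is filtered out and the query fails
# with CANNOT_RESOLVE_DATAFRAME_COLUMN.
df3.select(df1["*"]).schema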