@@ -92,6 +92,44 @@ package object dsl {
.build()
}

/**
* Creates an unresolved function from a sequence of name parts.
*
* @param nameParts
*   Sequence of name parts identifying the function, e.g. Seq("default", "hex").
* @param args
*   Arguments to pass to the function.
* @return
* Expression wrapping the unresolved function.
*/
def callFunction(nameParts: Seq[String], args: Seq[proto.Expression]): proto.Expression = {
proto.Expression
.newBuilder()
.setUnresolvedFunction(
proto.Expression.UnresolvedFunction
.newBuilder()
.addAllParts(nameParts.asJava)
.addAllArguments(args.asJava))
.build()
}

/**
* Creates an UnresolvedFunction from a single, unqualified identifier.
*
* @param name
*   Name of the function; qualified names must be passed as name parts instead.
* @param args
*   Arguments to pass to the function.
* @return
* Expression wrapping the unresolved function.
*/
def callFunction(name: String, args: Seq[proto.Expression]): proto.Expression = {
proto.Expression
.newBuilder()
.setUnresolvedFunction(
proto.Expression.UnresolvedFunction
.newBuilder()
.addParts(name)
.addAllArguments(args.asJava))
.build()
}
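For illustration, a rough Python counterpart of these helpers (a sketch, not part of this change): it builds the same UnresolvedFunction message through the pyspark.sql.connect.proto bindings, whose parts and arguments fields are also exercised by the Python tests further down. The helper name call_function is ours.

import pyspark.sql.connect.proto as proto

def call_function(name_parts, args):
    """Sketch: wrap an UnresolvedFunction in an Expression, like callFunction above."""
    expr = proto.Expression()
    # One entry per name segment, e.g. ["default", "hex"] for default.hex.
    expr.unresolved_function.parts.extend(name_parts)
    # args are assumed to already be proto.Expression messages.
    expr.unresolved_function.arguments.extend(args)
    return expr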

implicit def intToLiteral(i: Int): proto.Expression =
proto.Expression
.newBuilder()
@@ -197,10 +197,6 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
limitExpr = expressions.Literal(limit.getLimit, IntegerType))
}

private def lookupFunction(name: String, args: Seq[Expression]): Expression = {
UnresolvedFunction(Seq(name), args, isDistinct = false)
}

/**
* Translates a scalar function from proto to the Catalyst expression.
*
@@ -211,21 +207,14 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
* @return The translated Catalyst expression.
*/
private def transformScalarFunction(fun: proto.Expression.UnresolvedFunction): Expression = {
val funName = fun.getPartsList.asScala.mkString(".")
funName match {
case "gt" =>
assert(fun.getArgumentsCount == 2, "`gt` function must have two arguments.")
expressions.GreaterThan(
transformExpression(fun.getArguments(0)),
transformExpression(fun.getArguments(1)))
case "eq" =>
assert(fun.getArgumentsCount == 2, "`eq` function must have two arguments.")
expressions.EqualTo(
transformExpression(fun.getArguments(0)),
transformExpression(fun.getArguments(1)))
case _ =>
lookupFunction(funName, fun.getArgumentsList.asScala.map(transformExpression).toSeq)
if (fun.getPartsCount == 1 && fun.getParts(0).contains(".")) {
throw new IllegalArgumentException(
"Function identifier must be passed as sequence of name parts.")
}
UnresolvedFunction(
fun.getPartsList.asScala.toSeq,
fun.getArgumentsList.asScala.map(transformExpression).toSeq,
isDistinct = false)
}
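To make the guard above concrete, a small sketch (again assuming the pyspark.sql.connect.proto bindings) of the encoding the planner accepts versus the one it rejects:

import pyspark.sql.connect.proto as proto

# Accepted: a qualified name arrives as one part per segment.
ok = proto.Expression()
ok.unresolved_function.parts.extend(["default", "hex"])

# Rejected by the check above: a dot inside a single name part.
bad = proto.Expression()
bad.unresolved_function.parts.append("default.hex")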

private def transformAlias(alias: proto.Expression.Alias): Expression = {
@@ -197,7 +197,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {

val joinCondition = proto.Expression.newBuilder.setUnresolvedFunction(
proto.Expression.UnresolvedFunction.newBuilder
.addAllParts(Seq("eq").asJava)
.addAllParts(Seq("==").asJava)
.addArguments(unresolvedAttribute)
.addArguments(unresolvedAttribute)
.build())
@@ -50,6 +50,34 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan.analyze, sparkPlan.analyze, false)
}

test("UnresolvedFunction resolution.") {
{
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
Contributor (on lines +55 to +56): nit: only import them once in test("UnresolvedFunction resolution.")?

assertThrows[IllegalArgumentException] {
transform(connectTestRelation.select(callFunction("default.hex", Seq("id".protoAttr))))
}
}

val connectPlan = {
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
transform(
connectTestRelation.select(callFunction(Seq("default", "hex"), Seq("id".protoAttr))))
}

assertThrows[UnsupportedOperationException] {
connectPlan.analyze
}

val validPlan = {
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
transform(connectTestRelation.select(callFunction(Seq("hex"), Seq("id".protoAttr))))
}
assert(validPlan.analyze != null)
Contributor: it's better to compare it with the catalyst plan
Contributor Author: This is not to validate that the catalyst plan exists, but really just that existing functions are actually resolved. The != null is mostly to have any assertion and not throw.

}

test("Basic filter") {
val connectPlan = {
import org.apache.spark.sql.connect.dsl.expressions._
80 changes: 40 additions & 40 deletions python/pyspark/sql/connect/column.py
@@ -26,11 +26,51 @@
import pyspark.sql.connect.proto as proto


def _bin_op(
name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["ColumnRef", Any], "Expression"]:
def _(self: "ColumnRef", other: Any) -> "Expression":
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
if not reverse:
return ScalarFunctionExpression(name, self, other)
else:
return ScalarFunctionExpression(name, other, self)

return _


class Expression(object):
"""
Expression base class.
"""

__gt__ = _bin_op(">")
__lt__ = _bin_op("<")
Contributor: _bin_op("<")
Member: Yeah, I think this was a mistake.
__add__ = _bin_op("+")
__sub__ = _bin_op("-")
__mul__ = _bin_op("*")
__div__ = _bin_op("/")
__truediv__ = _bin_op("/")
__mod__ = _bin_op("%")
__radd__ = _bin_op("+", reverse=True)
__rsub__ = _bin_op("-", reverse=True)
__rmul__ = _bin_op("*", reverse=True)
__rdiv__ = _bin_op("/", reverse=True)
__rtruediv__ = _bin_op("/", reverse=True)
__pow__ = _bin_op("pow")
__rpow__ = _bin_op("pow", reverse=True)
__ge__ = _bin_op(">=")
__le__ = _bin_op("<=")

def __eq__(self, other: Any) -> "Expression": # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
return ScalarFunctionExpression("==", self, other)

def __init__(self) -> None:
pass
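The operator wiring above is easiest to see in isolation. Below is a minimal standalone sketch of the same factory pattern, with a toy Expr class standing in for the Expression/ScalarFunctionExpression classes of this module (all names in the sketch are hypothetical):

class Expr:
    """Toy expression node, standing in for the connect Expression classes."""

    def __init__(self, op: str = "", *children: object) -> None:
        self.op, self.children = op, children

    def __repr__(self) -> str:
        return f"Expr({self.op!r}, {self.children!r})"

def _bin_op(name: str, reverse: bool = False):
    def _(self: "Expr", other: object) -> "Expr":
        if not isinstance(other, Expr):
            other = Expr("lit", other)  # wrap plain values, as LiteralExpression does
        return Expr(name, other, self) if reverse else Expr(name, self, other)
    return _

# Attaching the generated functions as dunder methods gives infix syntax:
Expr.__gt__ = _bin_op(">")
Expr.__radd__ = _bin_op("+", reverse=True)

# Expr() > 3 builds Expr('>', Expr(), Expr('lit', 3)), mirroring
# ScalarFunctionExpression(">", self, LiteralExpression(3)).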

@@ -73,20 +113,6 @@ def __str__(self) -> str:
return f"Literal({self._value})"


def _bin_op(
name: str, doc: str = "binary function", reverse: bool = False
) -> Callable[["ColumnRef", Any], Expression]:
def _(self: "ColumnRef", other: Any) -> Expression:
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
if not reverse:
return ScalarFunctionExpression(name, self, other)
else:
return ScalarFunctionExpression(name, other, self)

return _


class ColumnRef(Expression):
Member: I think we should rename this to AttributeReference to avoid confusion, e.g. it is a bit mixed up with the Column interface that is the user-facing interface (@amaliujia and @cloud-fan). Better to keep it matched with either the Catalyst internal types or the user-facing Spark SQL interface classes.
Contributor Author: Let's have a discussion about this, but it is unrelated to this change. I think we should probably call Expression -> Column and ColumnRef -> AttributeReference, but it will require some more digging into what the right name should be. However, as said, that's independent of this change.
Contributor (@amaliujia, Oct 17, 2022): +1, I have been thinking about this ColumnRef thing. Let's revisit the naming etc. in the future.

"""Represents a column reference. There is no guarantee that this column
actually exists. In the context of this project, we refer by its name and
@@ -105,32 +131,6 @@ def name(self) -> str:
"""Returns the qualified name of the column reference."""
return ".".join(self._parts)

__gt__ = _bin_op("gt")
__lt__ = _bin_op("lt")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
__div__ = _bin_op("divide")
__truediv__ = _bin_op("divide")
__mod__ = _bin_op("modulo")
__radd__ = _bin_op("plus", reverse=True)
__rsub__ = _bin_op("minus", reverse=True)
__rmul__ = _bin_op("multiply", reverse=True)
__rdiv__ = _bin_op("divide", reverse=True)
__rtruediv__ = _bin_op("divide", reverse=True)
__pow__ = _bin_op("pow")
__rpow__ = _bin_op("pow", reverse=True)
__ge__ = _bin_op("greterEquals")
__le__ = _bin_op("lessEquals")

def __eq__(self, other: Any) -> Expression: # type: ignore[override]
"""Returns a binary expression with the current column as the left
side and the other expression as the right side.
"""
if isinstance(other, get_args(PrimitiveType)):
other = LiteralExpression(other)
return ScalarFunctionExpression("eq", self, other)

def to_plan(self, session: Optional["RemoteSparkSession"]) -> proto.Expression:
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -19,9 +19,12 @@
import unittest
import tempfile

import pandas
Member (@HyukjinKwon, Oct 16, 2022): Hm, we gotta fix this or do something. pandas isn't a required library for the SQL package. We should probably skip these tests when pandas is not installed, until we have a clear way to handle this (see pyspark.testing.sqlutils.have_pandas and pyspark.sql.tests.test_arrow_map).
Contributor Author: Interestingly, nothing in Spark Connect will work at the moment without pandas, because we always call toPandas in the collection of the result. Let me know what you want to do.


from pyspark.sql import SparkSession, Row
from pyspark.sql.connect.client import RemoteSparkSession
from pyspark.sql.connect.function_builder import udf
from pyspark.sql.connect.functions import lit
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.utils import ReusedPySparkTestCase

Expand Down Expand Up @@ -79,6 +82,15 @@ def test_simple_explain_string(self):
result = df.explain()
self.assertGreater(len(result), 0)

def test_simple_binary_expressions(self):
"""Test complex expression"""
df = self.connect.read.table(self.tbl_name)
pd = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas()
self.assertEqual(len(pd.index), 4)

res = pandas.DataFrame(data={"id": [0, 30, 60, 90]})
self.assertTrue(pd.equals(res), f"{pd.to_string()} != {res.to_string()}")


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401
@@ -16,6 +16,7 @@
#

from pyspark.testing.connectutils import PlanOnlyTestFixture
from pyspark.sql.connect.proto import Expression as ProtoExpression
import pyspark.sql.connect as c
import pyspark.sql.connect.plan as p
import pyspark.sql.connect.column as col
@@ -51,6 +52,34 @@ def test_column_literals(self):
plan = fun.lit(10).to_plan(None)
self.assertIs(plan.literal.i32, 10)

def test_column_expressions(self):
"""Test a more complex combination of expressions and their translation into
the protobuf structure."""
df = c.DataFrame.withPlan(p.Read("table"))

expr = df.id % fun.lit(10) == fun.lit(10)
expr_plan = expr.to_plan(None)
self.assertIsNotNone(expr_plan.unresolved_function)
self.assertEqual(expr_plan.unresolved_function.parts[0], "==")

lit_fun = expr_plan.unresolved_function.arguments[1]
self.assertIsInstance(lit_fun, ProtoExpression)
self.assertIsInstance(lit_fun.literal, ProtoExpression.Literal)
self.assertEqual(lit_fun.literal.i32, 10)

mod_fun = expr_plan.unresolved_function.arguments[0]
self.assertIsInstance(mod_fun, ProtoExpression)
self.assertIsInstance(mod_fun.unresolved_function, ProtoExpression.UnresolvedFunction)
self.assertEqual(len(mod_fun.unresolved_function.arguments), 2)
self.assertIsInstance(mod_fun.unresolved_function.arguments[0], ProtoExpression)
self.assertIsInstance(
mod_fun.unresolved_function.arguments[0].unresolved_attribute,
ProtoExpression.UnresolvedAttribute,
)
self.assertEqual(
mod_fun.unresolved_function.arguments[0].unresolved_attribute.parts, ["id"]
)


if __name__ == "__main__":
import unittest
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -40,7 +40,7 @@ def test_filter(self):
plan.root.filter.condition.unresolved_function, proto.Expression.UnresolvedFunction
)
)
self.assertEqual(plan.root.filter.condition.unresolved_function.parts, ["gt"])
self.assertEqual(plan.root.filter.condition.unresolved_function.parts, [">"])
self.assertEqual(len(plan.root.filter.condition.unresolved_function.arguments), 2)

def test_relation_alias(self):