@@ -43,7 +43,11 @@ message Expression {
Expression expr = 1;

// (Required) the data type that the expr should be cast to.
DataType cast_to_type = 2;
oneof cast_to_type {
Contributor:
FYI: this is a breaking change in the proto message. Ideally, we would use

reserved 2;
oneof cast_to_type {
  DataType type = 3;
  string type_str = 4;
}

Contributor Author (@amaliujia, Dec 8, 2022):
Ah thanks. I didn't know this way to evolve the proto.

That said, since Spark Connect is still an alpha component, we are not required to maintain backwards compatibility for now.

But once we are ready to move the component out of alpha, we should follow this approach.

Contributor:
I think it's fine, since this message hasn't actually been used yet?

Contributor Author:
I think we are OK for now. But later we will indeed need to establish a process/good practices for evolving the proto without breaking older versions (where possible).

Contributor:
We're OK for now. We can use Buf to check for breaking changes.

DataType type = 2;
// If this is set, the server will use the Catalyst parser to parse this string into a DataType.
string type_str = 3;
Contributor Author:
I added this to follow the design principle so far: move implementation that would otherwise be repeated across clients to the server side, to reduce redundant client-side work.

Otherwise every client would need to implement its own string-to-DataType conversion, and each would be doing the same thing.

Contributor:
This comes back to the discussion in my draft PR about whether the second argument to the cast() function should be modeled as an expression or not. Right now it takes a value, but it could be modeled as a string as well.

}
}
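Not part of the PR itself — a minimal Python sketch of how a client populates the new cast_to_type oneof and how the receiving side can tell which variant was set. It assumes the generated bindings under pyspark.sql.connect.proto and the pyspark_types_to_proto_types helper that the Python changes below rely on.

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import StringType

# Variant 1: the client builds the DataType message itself.
by_type = proto.Expression()
by_type.cast.type.CopyFrom(pyspark_types_to_proto_types(StringType()))

# Variant 2: the client forwards a DDL string and lets the server parse it.
by_str = proto.Expression()
by_str.cast.type_str = "string"

# Setting one member of a oneof clears the other, so the server can
# dispatch on whichever field is present.
assert by_type.cast.WhichOneof("cast_to_type") == "type"
assert by_str.cast.WhichOneof("cast_to_type") == "type_str"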

message Literal {
@@ -96,7 +96,17 @@ package object dsl {
Expression.Cast
.newBuilder()
.setExpr(expr)
.setCastToType(dataType))
.setType(dataType))
.build()

def cast(dataType: String): Expression =
Expression
.newBuilder()
.setCast(
Expression.Cast
.newBuilder()
.setExpr(expr)
.setTypeStr(dataType))
.build()
}

@@ -518,9 +518,16 @@ class SparkConnectPlanner(session: SparkSession) {
}

private def transformCast(cast: proto.Expression.Cast): Expression = {
Cast(
transformExpression(cast.getExpr),
DataTypeProtoConverter.toCatalystType(cast.getCastToType))
cast.getCastToTypeCase match {
case proto.Expression.Cast.CastToTypeCase.TYPE =>
Cast(
transformExpression(cast.getExpr),
DataTypeProtoConverter.toCatalystType(cast.getType))
case _ =>
Cast(
transformExpression(cast.getExpr),
session.sessionState.sqlParser.parseDataType(cast.getTypeStr))
Comment on lines +527 to +529
Contributor:

Wouldn't it be nice if we could just add a second string argument to the Cast expression so it could resolve the expression automatically? Would this make the design easier? Because then you wouldn't even need a custom expression for cast.

Contributor Author (@amaliujia, Dec 8, 2022):
I am not following this — can you give an example proto / some sample code?

Are you saying we should not use oneof, but instead add a third string field?

Contributor:

When I looked at it, cast is a unicorn in the context of expressions: it's one of the few where not all argument types resolve to expressions. I was wondering if we could simplify the approach by making the type argument of cast an expression that can resolve to a string; then we could do the matching of the expression in the analyzer.

This is very similar to why you added the oneof to the proto message.

}
}

private def transformSetOperation(u: proto.SetOperation): LogicalPlan = {
@@ -550,6 +550,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
connectTestRelation.select("id".protoAttr.cast(
proto.DataType.newBuilder().setString(proto.DataType.String.getDefaultInstance).build())),
sparkTestRelation.select(col("id").cast(StringType)))

comparePlans(
connectTestRelation.select("id".protoAttr.cast("string")),
sparkTestRelation.select(col("id").cast("string")))
}

test("Test Hint") {
87 changes: 2 additions & 85 deletions python/pyspark/sql/connect/client.py
@@ -26,30 +26,13 @@

import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
import pyspark.sql.types
from pyspark import cloudpickle
from pyspark.sql.types import (
DataType,
ByteType,
ShortType,
IntegerType,
FloatType,
DateType,
TimestampType,
DayTimeIntervalType,
MapType,
StringType,
CharType,
VarcharType,
StructType,
StructField,
ArrayType,
DoubleType,
LongType,
DecimalType,
BinaryType,
BooleanType,
NullType,
)


@@ -350,73 +333,7 @@ def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
return self._execute_and_fetch(req)

def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
if schema.HasField("null"):
return NullType()
elif schema.HasField("boolean"):
return BooleanType()
elif schema.HasField("binary"):
return BinaryType()
elif schema.HasField("byte"):
return ByteType()
elif schema.HasField("short"):
return ShortType()
elif schema.HasField("integer"):
return IntegerType()
elif schema.HasField("long"):
return LongType()
elif schema.HasField("float"):
return FloatType()
elif schema.HasField("double"):
return DoubleType()
elif schema.HasField("decimal"):
p = schema.decimal.precision if schema.decimal.HasField("precision") else 10
s = schema.decimal.scale if schema.decimal.HasField("scale") else 0
return DecimalType(precision=p, scale=s)
elif schema.HasField("string"):
return StringType()
elif schema.HasField("char"):
return CharType(schema.char.length)
elif schema.HasField("var_char"):
return VarcharType(schema.var_char.length)
elif schema.HasField("date"):
return DateType()
elif schema.HasField("timestamp"):
return TimestampType()
elif schema.HasField("day_time_interval"):
start: Optional[int] = (
schema.day_time_interval.start_field
if schema.day_time_interval.HasField("start_field")
else None
)
end: Optional[int] = (
schema.day_time_interval.end_field
if schema.day_time_interval.HasField("end_field")
else None
)
return DayTimeIntervalType(startField=start, endField=end)
elif schema.HasField("array"):
return ArrayType(
self._proto_schema_to_pyspark_schema(schema.array.element_type),
schema.array.contains_null,
)
elif schema.HasField("struct"):
fields = [
StructField(
f.name,
self._proto_schema_to_pyspark_schema(f.data_type),
f.nullable,
)
for f in schema.struct.fields
]
return StructType(fields)
elif schema.HasField("map"):
return MapType(
self._proto_schema_to_pyspark_schema(schema.map.key_type),
self._proto_schema_to_pyspark_schema(schema.map.value_type),
schema.map.value_contains_null,
)
else:
raise Exception(f"Unsupported data type {schema}")
return types.proto_schema_to_pyspark_data_type(schema)
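Not part of the diff — the long mapping removed above now lives in pyspark.sql.connect.types. A minimal, hypothetical sketch of that helper is shown here; the real version presumably mirrors all of the branches deleted from client.py.

# Hypothetical sketch of the helper in pyspark/sql/connect/types.py; only a few
# branches are shown, the rest follow the logic removed from client.py above.
import pyspark.sql.connect.proto as pb2
from pyspark.sql.types import ArrayType, DataType, IntegerType, NullType, StringType

def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
    if schema.HasField("null"):
        return NullType()
    elif schema.HasField("integer"):
        return IntegerType()
    elif schema.HasField("string"):
        return StringType()
    elif schema.HasField("array"):
        return ArrayType(
            proto_schema_to_pyspark_data_type(schema.array.element_type),
            schema.array.contains_null,
        )
    # ... decimal, struct, map, intervals, etc. follow the removed branches ...
    else:
        raise Exception(f"Unsupported data type {schema}")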

def schema(self, plan: pb2.Plan) -> StructType:
proto_schema = self._analyze(plan).schema
48 changes: 47 additions & 1 deletion python/pyspark/sql/connect/column.py
@@ -21,9 +21,10 @@
import decimal
import datetime

from pyspark.sql.types import TimestampType, DayTimeIntervalType, DateType
from pyspark.sql.types import TimestampType, DayTimeIntervalType, DataType, DateType

import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.types import pyspark_types_to_proto_types

if TYPE_CHECKING:
from pyspark.sql.connect._typing import ColumnOrName
@@ -355,6 +356,29 @@ def __repr__(self) -> str:
return f"{self._name}({', '.join([str(arg) for arg in self._args])})"


class CastExpression(Expression):
def __init__(
self,
col: "Column",
data_type: Union[DataType, str],
) -> None:
super().__init__()
self._col = col
self._data_type = data_type

def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
fun = proto.Expression()
fun.cast.expr.CopyFrom(self._col.to_plan(session))
if isinstance(self._data_type, str):
fun.cast.type_str = self._data_type
else:
fun.cast.type.CopyFrom(pyspark_types_to_proto_types(self._data_type))
return fun

def __repr__(self) -> str:
return f"({self._col} ({self._data_type}))"


class Column:
"""
A column in a DataFrame. Column can refer to different things based on the
@@ -733,6 +757,28 @@ def desc_nulls_last(self) -> "Column":
def name(self) -> str:
return self._expr.name()

def cast(self, dataType: Union[DataType, str]) -> "Column":
"""
Casts the column into type ``dataType``.

.. versionadded:: 3.4.0

Parameters
----------
dataType : :class:`DataType` or str
a DataType or Python string literal with a DDL-formatted string
to use when parsing the column to the same type.

Returns
-------
:class:`Column`
Column with each element cast into the new type.
"""
if isinstance(dataType, (DataType, str)):
return Column(CastExpression(col=self, data_type=dataType))
else:
raise TypeError("unexpected type: %s" % type(dataType))
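Not part of the diff — a small usage sketch of the new API. Here col stands for whichever column reference the Connect client exposes (e.g. pyspark.sql.connect.functions.col — an assumption); casting with a DataType populates the proto's type field, while casting with a DDL string populates type_str and defers parsing to the server.

from pyspark.sql.types import StringType

c = col("id")
c.cast(StringType())   # CastExpression -> Cast.type (a DataType message)
c.cast("string")       # CastExpression -> Cast.type_str; the server parses the DDL string
c.cast(123)            # would raise TypeError("unexpected type: <class 'int'>")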

# TODO(SPARK-41329): solve the circular import between functions.py and
# this class if we want to reuse functions.lit
def _lit(self, x: Any) -> "Column":