1 change: 1 addition & 0 deletions connector/connect/README.md
@@ -70,3 +70,4 @@ When contributing a new client please be aware that we strive to have a common
user experience across all languages. Please follow the below guidelines:

* [Connection string configuration](docs/client-connection-string.md)
* [Adding new messages](docs/adding-proto-messages.md) in the Spark Connect protocol.
40 changes: 40 additions & 0 deletions connector/connect/docs/adding-proto-messages.md
@@ -0,0 +1,40 @@
# Required, Optional and default values

Spark Connect adopts proto3, which no longer supports the `required` field constraint.
For non-message proto fields, there are also no `has_field_name` methods to easily tell
whether a field is set or unset. (Read [proto3 field rules](https://developers.google.com/protocol-buffers/docs/proto3#specifying_field_rules).)


### Required fields

When adding fields that have required semantics, developers must follow the process
outlined below. Fields that are semantically required for the server to
correctly process the incoming message must be documented with `(Required)`. For scalar
fields the server will not perform any additional input validation. For compound fields,
the server will perform minimal checks to avoid null pointer exceptions but will not
perform any semantic validation.

Example:
```protobuf
message DataSource {
// (Required) Supported formats include: parquet, orc, text, json, csv, avro.
string format = 1;
}
```
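
For a required *compound* (message-typed) field, the minimal check mentioned above typically amounts to a presence test before descending into the nested message. The sketch below is illustration only, not code from this patch; it assumes the `DataSource` message above is nested inside a `Read` relation and that a `transformDataSource` helper exists (see the sketch in the next section):

```scala
// Minimal sketch, not the actual planner implementation.
private def transformRead(rel: proto.Read): LogicalPlan = {
  // Presence check only, to avoid a NullPointerException downstream; the
  // contents of the data source are not semantically validated at this point.
  if (!rel.hasDataSource) {
    throw InvalidPlanInput("Read requires a DataSource")
  }
  transformDataSource(rel.getDataSource) // see the sketch in the next section
}
```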


### Optional fields

Semantically optional fields must be marked with `optional`. The server side will
then use this information to branch into different behaviors based on the presence or absence of the field.

Because proto3 does not allow configurable default values for scalar types, the mere presence of
an optional field does not define its default value. The server-side implementation will interpret the observed value based on its own rules.

Example:
```protobuf
message DataSource {
// (Optional) If not set, Spark will infer the schema.
optional string schema = 2;
}
```
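
On the server side, the `has` method generated by the proto3 `optional` keyword is what enables this branching. The following is a minimal sketch, not the actual planner code; the handler name, the enclosing planner `session`, and the use of `DataFrameReader` are assumptions for illustration:

```scala
// Minimal sketch, assuming this method lives inside the server-side planner
// where `session: SparkSession` is in scope.
private def transformDataSource(ds: proto.Read.DataSource): LogicalPlan = {
  // `format` is (Required) but scalar, so its value is used as-is without extra validation.
  val reader = session.read.format(ds.getFormat)
  val df = if (ds.hasSchema) {
    // The client set an explicit schema (proto3 `optional` generates `hasSchema`).
    reader.schema(ds.getSchema).load()
  } else {
    // Field not set: fall back to Spark's schema inference.
    reader.load()
  }
  df.queryExecution.analyzed
}
```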
@@ -170,6 +170,8 @@ message Expression

message Alias {
Expression expr = 1;
string name = 2;
repeated string name = 2;
// Alias metadata expressed as a JSON map.
optional string metadata = 3;
}
}
@@ -79,7 +79,28 @@ package object dsl {
implicit class DslExpression(val expr: Expression) {
def as(alias: String): Expression = Expression
.newBuilder()
.setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr))
.setAlias(Expression.Alias.newBuilder().addName(alias).setExpr(expr))
.build()

def as(alias: String, metadata: String): Expression = Expression
.newBuilder()
.setAlias(
Expression.Alias
.newBuilder()
.setExpr(expr)
.addName(alias)
.setMetadata(metadata)
.build())
.build()

def as(alias: Seq[String]): Expression = Expression
.newBuilder()
.setAlias(
Expression.Alias
.newBuilder()
.setExpr(expr)
.addAllName(alias.asJava)
.build())
.build()

def <(other: Expression): Expression =
@@ -101,6 +122,13 @@ package object dsl {
Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e))
.build()

def proto_explode(e: Expression): Expression =
Expression
.newBuilder()
.setUnresolvedFunction(
Expression.UnresolvedFunction.newBuilder().addParts("explode").addArguments(e))
.build()

/**
* Create an unresolved function from name parts.
*
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._

import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType}

/**
* This object offers methods to convert to/from connect proto to catalyst types.
@@ -32,6 +32,7 @@ object DataTypeProtoConverter {
case proto.DataType.KindCase.I32 => IntegerType
case proto.DataType.KindCase.STRING => StringType
case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct)
case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
}
@@ -44,6 +45,10 @@ object DataTypeProtoConverter {
StructType.apply(structFields)
}

private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = {
MapType(toCatalystType(t.getKey), toCatalystType(t.getValue))
}

def toConnectProtoType(t: DataType): proto.DataType = {
t match {
case IntegerType =>
@@ -54,11 +59,24 @@ object DataTypeProtoConverter {
proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
case struct: StructType =>
toConnectProtoStructType(struct)
case map: MapType => toConnectProtoMapType(map)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
}

def toConnectProtoMapType(schema: MapType): proto.DataType = {
proto.DataType
.newBuilder()
.setMap(
proto.DataType.Map
.newBuilder()
.setKey(toConnectProtoType(schema.keyType))
.setValue(toConnectProtoType(schema.valueType))
.build())
.build()
}

def toConnectProtoStructType(schema: StructType): proto.DataType = {
val struct = proto.DataType.Struct.newBuilder()
for (structField <- schema.fields) {
@@ -27,7 +27,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.WriteOperation
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
@@ -338,7 +338,9 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
case proto.Expression.ExprTypeCase.EXPRESSION_STRING =>
transformExpressionString(exp.getExpressionString)
case _ => throw InvalidPlanInput()
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
}
}

@@ -412,7 +414,20 @@ class SparkConnectPlanner(session: SparkSession) {
}

private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
Alias(transformExpression(alias.getExpr), alias.getName)()
if (alias.getNameCount == 1) {
val md = if (alias.hasMetadata()) {
Some(Metadata.fromJson(alias.getMetadata))
} else {
None
}
Alias(transformExpression(alias.getExpr), alias.getName(0))(explicitMetadata = md)
} else {
if (alias.hasMetadata) {
throw new InvalidPlanInput(
"Alias expressions with more than 1 identifier must not use optional metadata.")
}
MultiAlias(transformExpression(alias.getExpr), alias.getNameList.asScala.toSeq)
}
}

private def transformExpressionString(expr: proto.Expression.ExpressionString): Expression = {
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.dsl.commands._
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{IntegerType, MapType, Metadata, StringType, StructField, StructType}

/**
* This suite is based on connect DSL and test that given same dataframe operations, whether
@@ -50,6 +50,9 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
createLocalRelationProto(
Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))

lazy val connectTestRelationMap =
createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))()))

lazy val sparkTestRelation: DataFrame =
spark.createDataFrame(
new java.util.ArrayList[Row](),
@@ -60,6 +63,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
new java.util.ArrayList[Row](),
StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))

lazy val sparkTestRelationMap: DataFrame =
spark.createDataFrame(
new java.util.ArrayList[Row](),
StructType(Seq(StructField("id", MapType(StringType, StringType)))))

lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()))

test("Basic select") {
@@ -140,10 +148,35 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}

test("column alias") {
test("SPARK-40809: column alias") {
// Simple Test.
val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))
comparePlans(connectPlan, sparkPlan)

// Scalar columns with metadata
val mdJson = "{\"max\": 99}"
comparePlans(
connectTestRelation.select("id".protoAttr.as("id2", mdJson)),
sparkTestRelation.select(Column("id").as("id2", Metadata.fromJson(mdJson))))

comparePlans(
connectTestRelationMap.select(proto_explode("id".protoAttr).as(Seq("a", "b"))),
sparkTestRelationMap.select(explode(Column("id")).as(Seq("a", "b"))))

// Metadata must only be specified for regular Aliases.
assertThrows[InvalidPlanInput] {
val attr = proto_explode("id".protoAttr)
val alias = proto.Expression.Alias
.newBuilder()
.setExpr(attr)
.addName("a")
.addName("b")
.setMetadata(mdJson)
.build()
transform(
connectTestRelationMap.select(proto.Expression.newBuilder().setAlias(alias).build()))
}
}

test("Aggregate with more than 1 grouping expressions") {
66 changes: 66 additions & 0 deletions python/pyspark/sql/connect/column.py
@@ -17,6 +17,7 @@
import uuid
from typing import cast, get_args, TYPE_CHECKING, Callable, Any

import json
import decimal
import datetime

@@ -82,6 +83,71 @@ def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
def __str__(self) -> str:
...

def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
"""
Returns this column aliased with a new name or names (in the case of expressions that
return more than one column, such as explode).

.. versionadded:: 3.4.0

Parameters
----------
alias : str
desired column names (collects all positional arguments passed)

Other Parameters
----------------
metadata: dict
a dict of information to be stored in ``metadata`` attribute of the
corresponding :class:`StructField <pyspark.sql.types.StructField>` (optional, keyword
only argument)

Returns
-------
:class:`Column`
Column representing the original column aliased with the new name or names.

Examples
--------
>>> df = spark.createDataFrame(
... [(2, "Alice"), (5, "Bob")], ["age", "name"])
>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
>>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max']
99
"""
metadata = kwargs.pop("metadata", None)
assert not kwargs, "Unexpected kwargs were passed: %s" % kwargs
return ColumnAlias(self, list(alias), metadata)


class ColumnAlias(Expression):
def __init__(self, parent: Expression, alias: list[str], metadata: Any):

self._alias = alias
self._metadata = metadata
self._parent = parent

def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
if len(self._alias) == 1:
exp = proto.Expression()
exp.alias.name.append(self._alias[0])
exp.alias.expr.CopyFrom(self._parent.to_plan(session))

if self._metadata:
exp.alias.metadata = json.dumps(self._metadata)
return exp
else:
if self._metadata:
raise ValueError("metadata can only be provided for a single column")
exp = proto.Expression()
exp.alias.name.extend(self._alias)
exp.alias.expr.CopyFrom(self._parent.to_plan(session))
return exp

def __str__(self) -> str:
return f"Alias({self._parent}, ({','.join(self._alias)}))"


class LiteralExpression(Expression):
"""A literal expression.
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py
@@ -120,7 +120,7 @@ def withPlan(cls, plan: plan.LogicalPlan, session: "RemoteSparkSession") -> "Dat
new_frame._plan = plan
return new_frame

def select(self, *cols: ColumnOrName) -> "DataFrame":
def select(self, *cols: "ExpressionOrString") -> "DataFrame":
return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session)

def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":