130 changes: 70 additions & 60 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

66 changes: 57 additions & 9 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -474,27 +474,51 @@ class Expression(google.protobuf.message.Message):

ELEMENT_TYPE_FIELD_NUMBER: builtins.int
ELEMENTS_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
@property
def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The element type of the array.

This field is deprecated since Spark 4.1+ and should only be set
if the data_type field is not set. Use data_type field instead.
"""
@property
def elements(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___Expression.Literal
]: ...
]:
"""The literal values that make up the array elements."""
@property
def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array:
"""The type of the array.

If the element type can be inferred from the first element of the elements field,
then you don't need to set data_type.element_type to save space. On the other hand,
redundant type information is also acceptable.
"""
def __init__(
self,
*,
element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
data_type: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["element_type", b"element_type"]
self,
field_name: typing_extensions.Literal[
"data_type", b"data_type", "element_type", b"element_type"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"element_type", b"element_type", "elements", b"elements"
"data_type",
b"data_type",
"element_type",
b"element_type",
"elements",
b"elements",
],
) -> None: ...

@@ -505,39 +529,63 @@ class Expression(google.protobuf.message.Message):
VALUE_TYPE_FIELD_NUMBER: builtins.int
KEYS_FIELD_NUMBER: builtins.int
VALUES_FIELD_NUMBER: builtins.int
DATA_TYPE_FIELD_NUMBER: builtins.int
@property
def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The key type of the map.

This field is deprecated since Spark 4.1+ and should only be set
if the data_type field is not set. Use data_type field instead.
"""
@property
def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The value type of the map.

This field is deprecated since Spark 4.1+ and should only be set
if the data_type field is not set. Use data_type field instead.
"""
@property
def keys(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___Expression.Literal
]: ...
]:
"""The literal keys that make up the map."""
@property
def values(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___Expression.Literal
]: ...
]:
"""The literal values that make up the map."""
@property
def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map:
"""The type of the map.

If the key/value types can be inferred from the first element of the keys/values fields,
then you don't need to set data_type.key_type/data_type.value_type to save space.
On the other hand, redundant type information is also acceptable.
"""
def __init__(
self,
*,
key_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
value_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
keys: collections.abc.Iterable[global___Expression.Literal] | None = ...,
values: collections.abc.Iterable[global___Expression.Literal] | None = ...,
data_type: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"key_type", b"key_type", "value_type", b"value_type"
"data_type", b"data_type", "key_type", b"key_type", "value_type", b"value_type"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"data_type",
b"data_type",
"key_type",
b"key_type",
"keys",
@@ -1687,6 +1687,35 @@ class ClientE2ETestSuite
assert(df.count() == 100)
}
}

test("SPARK-52930: the nullability of arrays should be preserved using typedlit") {
val arrays = Seq(
(typedlit(Array[Int]()), false),
(typedlit(Array[Int](1)), false),
(typedlit(Array[Integer]()), true),
(typedlit(Array[Integer](1)), true))
for ((array, containsNull) <- arrays) {
val df = spark.sql("select 1").select(array)
df.createOrReplaceTempView("test_array_nullability")
val schema = spark.sql("select * from test_array_nullability").schema
assert(schema.fields.head.dataType.asInstanceOf[ArrayType].containsNull === containsNull)
}
}

test("SPARK-52930: the nullability of map values should be preserved using typedlit") {
val maps = Seq(
(typedlit(Map[String, Int]()), false),
(typedlit(Map[String, Int]("a" -> 1)), false),
(typedlit(Map[String, Integer]()), true),
(typedlit(Map[String, Integer]("a" -> 1)), true))
for ((map, valueContainsNull) <- maps) {
val df = spark.sql("select 1").select(map)
df.createOrReplaceTempView("test_map_nullability")
val schema = spark.sql("select * from test_map_nullability").schema
assert(
schema.fields.head.dataType.asInstanceOf[MapType].valueContainsNull === valueContainsNull)
}
}
}

private[sql] case class ClassData(a: String, b: Int)
@@ -215,15 +215,48 @@ message Expression {
}

message Array {
Review comment (Contributor):
I personally feel it might be better to introduce new messages for this purpose, so that we can minimize the code changes and make the conversion logic on the server side clearer.

Reply (heyihong, Contributor Author, Aug 25, 2025):
@zhengruifeng FYI, I have simplified the implementation a bit. It is possible to minimize the code changes and make the conversion logic on the server side clearer without introducing new messages.

DataType element_type = 1;
// (Deprecated) The element type of the array.
//
// This field is deprecated since Spark 4.1+ and should only be set
// if the data_type field is not set. Use data_type field instead.
DataType element_type = 1 [deprecated = true];

// The literal values that make up the array elements.
repeated Literal elements = 2;

// The type of the array.
//
// If the element type can be inferred from the first element of the elements field,
// then you don't need to set data_type.element_type to save space. On the other hand,
// redundant type information is also acceptable.
DataType.Array data_type = 3;
}
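
To make the new field concrete, here is a minimal sketch of how a client could populate it, assuming the Java classes generated from these protos under org.apache.spark.connect.proto (builder and setter names follow standard protobuf codegen; treat them as illustrative rather than authoritative):

import org.apache.spark.connect.proto

// Sketch: build the literal array [1, 2] carrying its full type through the new
// data_type field; the deprecated element_type field is left unset.
val intType = proto.DataType
  .newBuilder()
  .setInteger(proto.DataType.Integer.newBuilder())
  .build()

val arrayLiteral = proto.Expression.Literal.Array
  .newBuilder()
  .setDataType(
    proto.DataType.Array
      .newBuilder()
      .setElementType(intType)
      .setContainsNull(false)) // nullability is preserved, unlike with element_type alone
  .addElements(proto.Expression.Literal.newBuilder().setInteger(1))
  .addElements(proto.Expression.Literal.newBuilder().setInteger(2))
  .build()

Since the element type is also inferable from the first element, setting data_type.element_type here is technically redundant, which the comment above explicitly allows.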

message Map {
DataType key_type = 1;
DataType value_type = 2;
// (Deprecated) The key type of the map.
//
// This field is deprecated since Spark 4.1+ and should only be set
// if the data_type field is not set. Use data_type field instead.
DataType key_type = 1 [deprecated = true];

// (Deprecated) The value type of the map.
//
// This field is deprecated since Spark 4.1+ and should only be set
// if the data_type field is not set. Use data_type field instead.
DataType value_type = 2 [deprecated = true];

// The literal keys that make up the map.
repeated Literal keys = 3;

// The literal values that make up the map.
repeated Literal values = 4;

// The type of the map.
//
// If the key/value types can be inferred from the first element of the keys/values fields,
// then you don't need to set data_type.key_type/data_type.value_type to save space.
// On the other hand, redundant type information is also acceptable.
DataType.Map data_type = 5;
}
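
The same pattern applies to maps. A hedged sketch under the same assumptions as the array example above, building Map("a" -> 1) with value_contains_null = false carried in data_type:

import org.apache.spark.connect.proto

// Sketch: build the literal map Map("a" -> 1) carrying its full type through the new
// data_type field; the deprecated key_type/value_type fields are left unset.
val stringType = proto.DataType
  .newBuilder()
  .setString(proto.DataType.String.newBuilder())
  .build()
val integerType = proto.DataType
  .newBuilder()
  .setInteger(proto.DataType.Integer.newBuilder())
  .build()

val mapLiteral = proto.Expression.Literal.Map
  .newBuilder()
  .setDataType(
    proto.DataType.Map
      .newBuilder()
      .setKeyType(stringType)
      .setValueType(integerType)
      .setValueContainsNull(false)) // valueContainsNull now survives the round trip
  .addKeys(proto.Expression.Literal.newBuilder().setString("a"))
  .addValues(proto.Expression.Literal.newBuilder().setInteger(1))
  .build()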

message Struct {
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoBuilderWithOptions, ToLiteralProtoOptions}
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}

@@ -65,11 +65,12 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
val builder = proto.Expression.newBuilder()
val n = additionalTransformation.map(_(node)).getOrElse(node)
n match {
case Literal(value, None, _) =>
builder.setLiteral(toLiteralProtoBuilder(value))

case Literal(value, Some(dataType), _) =>
builder.setLiteral(toLiteralProtoBuilder(value, dataType))
case Literal(value, dataTypeOpt, _) =>
builder.setLiteral(
toLiteralProtoBuilderWithOptions(
value,
dataTypeOpt,
ToLiteralProtoOptions(useDeprecatedDataTypeFields = false)))
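        // Illustrative note (not part of this diff): ToLiteralProtoOptions is assumed to be a
        // simple switch; with useDeprecatedDataTypeFields = true the helper would presumably keep
        // populating the legacy element_type/key_type/value_type fields for older servers, e.g.
        //   toLiteralProtoBuilderWithOptions(
        //     Array(1, 2),
        //     Some(ArrayType(IntegerType, containsNull = false)),
        //     ToLiteralProtoOptions(useDeprecatedDataTypeFields = true))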

case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
val escapedName = u.sql