1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/functions.rst
@@ -553,6 +553,7 @@ VARIANT Functions
try_variant_get
variant_get
try_parse_json
to_variant_object


XML Functions
7 changes: 7 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -2071,6 +2071,13 @@ def try_parse_json(col: "ColumnOrName") -> Column:
try_parse_json.__doc__ = pysparkfuncs.try_parse_json.__doc__


def to_variant_object(col: "ColumnOrName") -> Column:
return _invoke_function("to_variant_object", _to_col(col))


to_variant_object.__doc__ = pysparkfuncs.to_variant_object.__doc__


def parse_json(col: "ColumnOrName") -> Column:
return _invoke_function("parse_json", _to_col(col))

53 changes: 51 additions & 2 deletions python/pyspark/sql/functions/builtin.py
@@ -16308,6 +16308,55 @@ def try_parse_json(
return _invoke_function("try_parse_json", _to_java_column(col))


@_try_remote_functions
def to_variant_object(
col: "ColumnOrName",
) -> Column:
"""
Converts a column containing nested inputs (array/map/struct) into a variant. Maps and
structs are converted to variant objects, which are unordered, unlike SQL structs. Input
maps can only have string keys.

.. versionadded:: 4.0.0

Parameters
----------
col : :class:`~pyspark.sql.Column` or str
a column with a nested schema or column name

Returns
-------
:class:`~pyspark.sql.Column`
a new column of VariantType.

Examples
--------
Example 1: Converting an array containing a nested struct into a variant

>>> from pyspark.sql import functions as sf
>>> from pyspark.sql.types import ArrayType, StructType, StructField, StringType, MapType
>>> schema = StructType([
... StructField("i", StringType(), True),
... StructField("v", ArrayType(StructType([
... StructField("a", MapType(StringType(), StringType()), True)
... ]), True))
... ])
>>> data = [("1", [{"a": {"b": 2}}])]
>>> df = spark.createDataFrame(data, schema)
>>> df.select(sf.to_variant_object(df.v))
DataFrame[to_variant_object(v): variant]
>>> df.select(sf.to_variant_object(df.v)).show(truncate=False)
+--------------------+
|to_variant_object(v)|
+--------------------+
|[{"a":{"b":"2"}}] |
+--------------------+
"""
from pyspark.sql.classic.column import _to_java_column

return _invoke_function("to_variant_object", _to_java_column(col))
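A supplementary hedged sketch (assuming a live `spark` session, as in the doctests above): maps with string keys are likewise accepted and serialize as variant objects.

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([({"x": 1, "y": 2},)], "m map<string,int>")
>>> df.select(sf.to_json(sf.to_variant_object(df.m)).alias("j")).collect()
[Row(j='{"x":1,"y":2}')]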


@_try_remote_functions
def parse_json(
col: "ColumnOrName",
@@ -16467,7 +16516,7 @@ def schema_of_variant(v: "ColumnOrName") -> Column:
--------
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
>>> df.select(schema_of_variant(parse_json(df.json)).alias("r")).collect()
[Row(r='STRUCT<a: BIGINT>')]
[Row(r='OBJECT<a: BIGINT>')]
"""
from pyspark.sql.classic.column import _to_java_column

@@ -16495,7 +16544,7 @@ def schema_of_variant_agg(v: "ColumnOrName") -> Column:
--------
>>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
>>> df.select(schema_of_variant_agg(parse_json(df.json)).alias("r")).collect()
[Row(r='STRUCT<a: BIGINT>')]
[Row(r='OBJECT<a: BIGINT>')]
"""
from pyspark.sql.classic.column import _to_java_column

11 changes: 9 additions & 2 deletions python/pyspark/sql/tests/test_functions.py
@@ -1326,8 +1326,8 @@ def check(resultDf, expected):
self.assertEqual([r[0] for r in resultDf.collect()], expected)

check(df.select(F.is_variant_null(v)), [False, False])
check(df.select(F.schema_of_variant(v)), ["STRUCT<a: BIGINT>", "STRUCT<b: BIGINT>"])
check(df.select(F.schema_of_variant_agg(v)), ["STRUCT<a: BIGINT, b: BIGINT>"])
check(df.select(F.schema_of_variant(v)), ["OBJECT<a: BIGINT>", "OBJECT<b: BIGINT>"])
check(df.select(F.schema_of_variant_agg(v)), ["OBJECT<a: BIGINT, b: BIGINT>"])

check(df.select(F.variant_get(v, "$.a", "int")), [1, None])
check(df.select(F.variant_get(v, "$.b", "int")), [None, 2])
@@ -1365,6 +1365,13 @@ def test_try_parse_json(self):
self.assertEqual("""{"a":1}""", actual[0]["var"])
self.assertEqual(None, actual[1]["var"])

def test_to_variant_object(self):
df = self.spark.createDataFrame([(1, {"a": 1})], "i int, v struct<a int>")
actual = df.select(
F.to_json(F.to_variant_object(df.v)).alias("var"),
).collect()
self.assertEqual("""{"a":1}""", actual[0]["var"])

def test_schema_of_csv(self):
with self.assertRaises(PySparkTypeError) as pe:
F.schema_of_csv(1)
12 changes: 12 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -6851,6 +6851,18 @@ object functions {
*/
def parse_json(json: Column): Column = Column.fn("parse_json", json)

/**
* Converts a column containing nested inputs (array/map/struct) into a variant. Maps and
* structs are converted to variant objects, which are unordered, unlike SQL structs. Input
* maps can only have string keys.
Contributor: if the input has no array/map/struct, is this function a no-op?

Contributor: nvm, the type check will fail for it.
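A hedged sketch of that type check (assuming a live `spark` session): a scalar root type is rejected at analysis time rather than passed through.

from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException

df = spark.range(1)
try:
    # Scalar root types fail the input type check (DATATYPE_MISMATCH).
    df.select(sf.to_variant_object(df.id))
except AnalysisException:
    print("scalar input rejected")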

*
* @param col
* a column with a nested schema or column name.
* @group variant_funcs
* @since 4.0.0
*/
def to_variant_object(col: Column): Column = Column.fn("to_variant_object", col)

/**
* Check if a variant value is a variant null. Returns true if and only if the input is a
* variant null and false otherwise (including in the case of SQL NULL).
@@ -839,6 +839,7 @@ object FunctionRegistry {
expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
expression[SchemaOfVariant]("schema_of_variant"),
expression[SchemaOfVariantAgg]("schema_of_variant_agg"),
expression[ToVariantObject]("to_variant_object"),

// cast
expression[Cast]("cast"),
@@ -128,7 +128,10 @@ object Cast extends QueryErrorsBase {
case (TimestampType, _: NumericType) => true

case (VariantType, _) => variant.VariantGet.checkDataType(to)
case (_, VariantType) => variant.VariantGet.checkDataType(from)
// Structs and Maps can't be cast to Variants since the Variant spec does not yet contain
// lossless equivalents for these types. The `to_variant_object` expression can be used instead
// to convert data of these types to Variant Objects.
case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canAnsiCast(fromType, toType) && resolvableNullability(fn, tn)
@@ -237,7 +240,10 @@
case (_: NumericType, _: NumericType) => true

case (VariantType, _) => variant.VariantGet.checkDataType(to)
case (_, VariantType) => variant.VariantGet.checkDataType(from)
// Structs and Maps can't be cast to Variants since the Variant spec does not yet contain
// lossless equivalents for these types. The `to_variant_object` expression can be used instead
// to convert data of these types to Variant Objects.
case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canCast(fromType, toType) &&
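A hedged sketch of the user-visible effect of this cast change (assuming a live `spark` session): an explicit struct-to-variant cast is rejected, and `to_variant_object` is the supported path.

from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException

df = spark.createDataFrame([(1, {"a": 1})], "i int, v struct<a int>")
try:
    # Struct -> variant via CAST is no longer allowed.
    df.select(df.v.cast("variant"))
except AnalysisException:
    print("cast rejected")
# Supported path; prints {"a":1}.
df.select(sf.to_json(sf.to_variant_object(df.v)).alias("var")).show()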
@@ -126,7 +126,7 @@ object VariantExpressionEvalUtils {
buildVariant(builder, element, elementType)
}
builder.finishWritingArray(start, offsets)
case MapType(StringType, valueType, _) =>
case MapType(_: StringType, valueType, _) =>
val data = input.asInstanceOf[MapData]
val keys = data.keyArray()
val values = data.valueArray()
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.json.JsonInferSchema
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -117,6 +117,73 @@ case class IsVariantNull(child: Expression) extends UnaryExpression
copy(child = newChild)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr) - Converts a nested input (array/map/struct) into a variant. Maps and structs become variant objects, which are unordered, unlike SQL structs. Input maps can only have string keys.",
Contributor: How do we define the element order in the resulting variant object? Random?

Contributor (author): Yes and no. From a logical perspective, the keys should be thought of as random, and users should not assume anything about the order. However, in the spec, the field IDs are sorted in the lexicographic order of the keys, which makes it possible to binary-search for a given key.
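A hedged sketch of that ordering (assuming a live `spark` session): fields declared out of order come back lexicographically sorted when serialized, though callers should treat the order as unspecified.

from pyspark.sql import functions as sf

df = spark.sql("SELECT named_struct('b', 2, 'a', 1) AS s")
# The current builder writes object fields sorted by key: {"a":1,"b":2}.
df.select(sf.to_json(sf.to_variant_object(df.s))).show(truncate=False)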

examples = """
Examples:
> SELECT _FUNC_(named_struct('a', 1, 'b', 2));
{"a":1,"b":2}
> SELECT _FUNC_(array(1, 2, 3));
[1,2,3]
> SELECT _FUNC_(array(named_struct('a', 1)));
[{"a":1}]
> SELECT _FUNC_(array(map("a", 2)));
[{"a":2}]
""",
since = "4.0.0",
group = "variant_funcs")
// scalastyle:on line.size.limit
case class ToVariantObject(child: Expression)
extends UnaryExpression
with NullIntolerant
with QueryErrorsBase {

override val dataType: DataType = VariantType

// Only accept nested types at the root; any supported type can be nested inside.
override def checkInputDataTypes(): TypeCheckResult = {
val checkResult: Boolean = child.dataType match {
case _: StructType | _: ArrayType | _: MapType =>
VariantGet.checkDataType(child.dataType, allowStructsAndMaps = true)
case _ => false
}
if (!checkResult) {
DataTypeMismatch(
errorSubClass = "CAST_WITHOUT_SUGGESTION",
messageParameters =
Map("srcType" -> toSQLType(child.dataType), "targetType" -> toSQLType(VariantType)))
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def prettyName: String = "to_variant_object"

override protected def withNewChildInternal(newChild: Expression): ToVariantObject =
copy(child = newChild)

protected override def nullSafeEval(input: Any): Any =
VariantExpressionEvalUtils.castToVariant(input, child.dataType)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childCode = child.genCode(ctx)
val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
val fromArg = ctx.addReferenceObj("from", child.dataType)
val javaType = JavaCode.javaType(VariantType)
val code =
code"""
${childCode.code}
boolean ${ev.isNull} = ${childCode.isNull};
$javaType ${ev.value} = ${CodeGenerator.defaultValue(VariantType)};
if (!${childCode.isNull}) {
${ev.value} = $cls.castToVariant(${childCode.value}, $fromArg);
}
"""
ev.copy(code = code)
}
}

object VariantPathParser extends RegexParsers {
// A path segment in the `VariantGet` expression represents either an object key access or an
// array index access.
@@ -260,13 +327,16 @@ case object VariantGet {
* Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
* of them. For nested types, we reject map types with a non-string key type.
*/
def checkDataType(dataType: DataType): Boolean = dataType match {
def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean =
dataType match {
case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType |
VariantType | _: DayTimeIntervalType | _: YearMonthIntervalType =>
true
case ArrayType(elementType, _) => checkDataType(elementType)
case MapType(_: StringType, valueType, _) => checkDataType(valueType)
case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
case ArrayType(elementType, _) => checkDataType(elementType, allowStructsAndMaps)
case MapType(_: StringType, valueType, _) if allowStructsAndMaps =>
checkDataType(valueType, allowStructsAndMaps)
case StructType(fields) if allowStructsAndMaps =>
fields.forall(f => checkDataType(f.dataType, allowStructsAndMaps))
case _ => false
}
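A hedged illustration of the map-key restriction (assuming a live `spark` session): non-string keys fail this check.

from pyspark.sql import functions as sf
from pyspark.errors import AnalysisException

df = spark.createDataFrame([({1: "x"},)], "m map<int,string>")
try:
    # Map keys must be strings, so analysis rejects this column.
    df.select(sf.to_variant_object(df.m))
except AnalysisException:
    print("non-string map keys rejected")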

@@ -635,7 +705,7 @@ object VariantExplode {
> SELECT _FUNC_(parse_json('null'));
VOID
> SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
ARRAY<OBJECT<a: BIGINT, b: BOOLEAN>>
""",
since = "4.0.0",
group = "variant_funcs"
@@ -666,7 +736,24 @@ object SchemaOfVariant {
/** The actual implementation of the `SchemaOfVariant` expression. */
def schemaOfVariant(input: VariantVal): UTF8String = {
val v = new Variant(input.getValue, input.getMetadata)
UTF8String.fromString(schemaOf(v).sql)
UTF8String.fromString(printSchema(schemaOf(v)))
}

/**
* Similar to `dataType.sql`. The only difference is that `StructType` is shown as
* `OBJECT<...>` rather than `STRUCT<...>`.
* SchemaOfVariant expressions use the Struct DataType to denote the Object type in the variant
* spec. However, the Object type is not equivalent to the struct type, since an Object
* represents an unordered bag of key-value pairs while the Struct type is ordered.
*/
def printSchema(dataType: DataType): String = dataType match {
case StructType(fields) =>
def printField(f: StructField): String =
s"${QuotingUtils.quoteIfNeeded(f.name)}: ${printSchema(f.dataType)}"

s"OBJECT<${fields.map(printField).mkString(", ")}>"
case ArrayType(elementType, _) => s"ARRAY<${printSchema(elementType)}>"
case _ => dataType.sql
}
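A hedged usage sketch of the resulting format (assuming a live `spark` session): objects print as OBJECT<...> while arrays keep their ARRAY<...> wrapper.

>>> from pyspark.sql.functions import schema_of_variant
>>> df = spark.sql("""SELECT parse_json('[{"a": 1}]') AS v""")
>>> df.select(schema_of_variant(df.v).alias("s")).collect()
[Row(s='ARRAY<OBJECT<a: BIGINT>>')]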

/**
@@ -731,7 +818,7 @@
> SELECT _FUNC_(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j);
BIGINT
> SELECT _FUNC_(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j);
STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>
OBJECT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>
""",
since = "4.0.0",
group = "variant_funcs")
@@ -767,7 +854,8 @@ case class SchemaOfVariantAgg(
override def merge(buffer: DataType, input: DataType): DataType =
SchemaOfVariant.mergeSchema(buffer, input)

override def eval(buffer: DataType): Any = UTF8String.fromString(buffer.sql)
override def eval(buffer: DataType): Any =
UTF8String.fromString(SchemaOfVariant.printSchema(buffer))

override def serialize(buffer: DataType): Array[Byte] = buffer.json.getBytes("UTF-8")
