diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index dc4329c60324..4910a5b59273 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -553,6 +553,7 @@ VARIANT Functions
     try_variant_get
     variant_get
     try_parse_json
+    to_variant_object


 XML Functions
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index ad6dbbf58e48..031e7c22542d 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -2071,6 +2071,13 @@ def try_parse_json(col: "ColumnOrName") -> Column:
 try_parse_json.__doc__ = pysparkfuncs.try_parse_json.__doc__


+def to_variant_object(col: "ColumnOrName") -> Column:
+    return _invoke_function("to_variant_object", _to_col(col))
+
+
+to_variant_object.__doc__ = pysparkfuncs.to_variant_object.__doc__
+
+
 def parse_json(col: "ColumnOrName") -> Column:
     return _invoke_function("parse_json", _to_col(col))
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 1bdd2dbd8f01..b6499eb1546e 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -16308,6 +16308,55 @@ def try_parse_json(
     return _invoke_function("try_parse_json", _to_java_column(col))


+@_try_remote_functions
+def to_variant_object(
+    col: "ColumnOrName",
+) -> Column:
+    """
+    Converts a column containing nested inputs (array/map/struct) into a variant. Maps and
+    structs are converted to variant objects, which are unordered, unlike SQL structs. Input
+    maps can only have string keys.
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    col : :class:`~pyspark.sql.Column` or str
+        a column with a nested schema or column name
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        a new column of VariantType.
+
+    Examples
+    --------
+    Example 1: Converting an array containing a nested struct into a variant
+
+    >>> from pyspark.sql import functions as sf
+    >>> from pyspark.sql.types import ArrayType, StructType, StructField, StringType, MapType
+    >>> schema = StructType([
+    ...     StructField("i", StringType(), True),
+    ...     StructField("v", ArrayType(StructType([
+    ...         StructField("a", MapType(StringType(), StringType()), True)
+    ...     ]), True))
+    ... ])
+    >>> data = [("1", [{"a": {"b": 2}}])]
+    >>> df = spark.createDataFrame(data, schema)
+    >>> df.select(sf.to_variant_object(df.v))
+    DataFrame[to_variant_object(v): variant]
+    >>> df.select(sf.to_variant_object(df.v)).show(truncate=False)
+    +--------------------+
+    |to_variant_object(v)|
+    +--------------------+
+    |[{"a":{"b":"2"}}]   |
+    +--------------------+
+    """
+    from pyspark.sql.classic.column import _to_java_column
+
+    return _invoke_function("to_variant_object", _to_java_column(col))
+
+
 @_try_remote_functions
 def parse_json(
     col: "ColumnOrName",
@@ -16467,7 +16516,7 @@ def schema_of_variant(v: "ColumnOrName") -> Column:
     --------
     >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
     >>> df.select(schema_of_variant(parse_json(df.json)).alias("r")).collect()
-    [Row(r='STRUCT<a: BIGINT>')]
+    [Row(r='OBJECT<a: BIGINT>')]
     """
     from pyspark.sql.classic.column import _to_java_column

@@ -16495,7 +16544,7 @@ def schema_of_variant_agg(v: "ColumnOrName") -> Column:
     --------
     >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ])
     >>> df.select(schema_of_variant_agg(parse_json(df.json)).alias("r")).collect()
-    [Row(r='STRUCT<a: BIGINT>')]
+    [Row(r='OBJECT<a: BIGINT>')]
     """
     from pyspark.sql.classic.column import _to_java_column

diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index f7f2485a43e1..a0ab9bc9c7d4 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1326,8 +1326,8 @@ def check(resultDf, expected):
             self.assertEqual([r[0] for r in resultDf.collect()], expected)

         check(df.select(F.is_variant_null(v)), [False, False])
-        check(df.select(F.schema_of_variant(v)), ["STRUCT<a: BIGINT>", "STRUCT<b: BIGINT>"])
-        check(df.select(F.schema_of_variant_agg(v)), ["STRUCT<a: BIGINT, b: BIGINT>"])
+        check(df.select(F.schema_of_variant(v)), ["OBJECT<a: BIGINT>", "OBJECT<b: BIGINT>"])
+        check(df.select(F.schema_of_variant_agg(v)), ["OBJECT<a: BIGINT, b: BIGINT>"])
         check(df.select(F.variant_get(v, "$.a", "int")), [1, None])
         check(df.select(F.variant_get(v, "$.b", "int")), [None, 2])

@@ -1365,6 +1365,13 @@ def test_try_parse_json(self):
         self.assertEqual("""{"a":1}""", actual[0]["var"])
         self.assertEqual(None, actual[1]["var"])

+    def test_to_variant_object(self):
+        df = self.spark.createDataFrame([(1, {"a": 1})], "i int, v struct<a int>")
+        actual = df.select(
+            F.to_json(F.to_variant_object(df.v)).alias("var"),
+        ).collect()
+        self.assertEqual("""{"a":1}""", actual[0]["var"])
+
     def test_schema_of_csv(self):
         with self.assertRaises(PySparkTypeError) as pe:
             F.schema_of_csv(1)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index 771fb824a70c..a8b2044ba8a4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -6851,6 +6851,18 @@ object functions {
    */
  def parse_json(json: Column): Column = Column.fn("parse_json", json)

+  /**
+   * Converts a column containing nested inputs (array/map/struct) into a variant. Maps and
+   * structs are converted to variant objects, which are unordered, unlike SQL structs. Input
+   * maps can only have string keys.
+   *
+   * @param col
+   *   a column with a nested schema or column name.
+   * @group variant_funcs
+   * @since 4.0.0
+   */
+  def to_variant_object(col: Column): Column = Column.fn("to_variant_object", col)
+
  /**
   * Check if a variant value is a variant null. Returns true if and only if the input is a
   * variant null and false otherwise (including in the case of SQL NULL).
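
The Python and Scala additions above give `to_variant_object` the same shape on both APIs. For context, a minimal end-to-end sketch of how the new function is meant to be used (assuming a Spark 4.0.0 build that carries this patch; the sample data, column names, and object name are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, to_variant_object}

object ToVariantObjectSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._

    // A map column ("m") and a struct column ("s"): nested types whose plain
    // CAST(... AS VARIANT) is rejected after this change (see Cast.scala below).
    val df = Seq((1, Map("a" -> 1, "b" -> 2), ("x", 10))).toDF("i", "m", "s")

    df.select(
      to_variant_object(col("m")).alias("vm"), // map    -> variant object (unordered)
      to_variant_object(col("s")).alias("vs")  // struct -> variant object (unordered)
    ).show(truncate = false)
    // Variant values render as JSON, e.g. {"a":1,"b":2} and {"_1":"x","_2":10}.

    spark.stop()
  }
}

Routing maps and structs through `to_variant_object` rather than a cast keeps the lossy struct-to-object conversion explicit in user code.
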
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index dfe1bd12bb7f..75e1ab86f177 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -839,6 +839,7 @@ object FunctionRegistry {
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
     expression[SchemaOfVariant]("schema_of_variant"),
     expression[SchemaOfVariantAgg]("schema_of_variant_agg"),
+    expression[ToVariantObject]("to_variant_object"),

     // cast
     expression[Cast]("cast"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 4a2b4b28e690..7a2799e99fe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -128,7 +128,10 @@ object Cast extends QueryErrorsBase {
     case (TimestampType, _: NumericType) => true

     case (VariantType, _) => variant.VariantGet.checkDataType(to)
-    case (_, VariantType) => variant.VariantGet.checkDataType(from)
+    // Structs and Maps can't be cast to Variants since the Variant spec does not yet contain
+    // lossless equivalents for these types. The `to_variant_object` expression can be used instead
+    // to convert data of these types to Variant Objects.
+    case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)

     case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
       canAnsiCast(fromType, toType) && resolvableNullability(fn, tn)
@@ -237,7 +240,10 @@ object Cast extends QueryErrorsBase {
     case (_: NumericType, _: NumericType) => true

     case (VariantType, _) => variant.VariantGet.checkDataType(to)
-    case (_, VariantType) => variant.VariantGet.checkDataType(from)
+    // Structs and Maps can't be cast to Variants since the Variant spec does not yet contain
+    // lossless equivalents for these types. The `to_variant_object` expression can be used instead
+    // to convert data of these types to Variant Objects.
+    case (_, VariantType) => variant.VariantGet.checkDataType(from, allowStructsAndMaps = false)

     case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
       canCast(fromType, toType) &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
index bbf554d384b1..487985b4770e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala
@@ -126,7 +126,7 @@ object VariantExpressionEvalUtils {
           buildVariant(builder, element, elementType)
         }
         builder.finishWritingArray(start, offsets)
-      case MapType(StringType, valueType, _) =>
+      case MapType(_: StringType, valueType, _) =>
         val data = input.asInstanceOf[MapData]
         val keys = data.keyArray()
         val values = data.valueArray()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index bd956fa5c00e..2c8ca1e8bb2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.json.JsonInferSchema
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
@@ -117,6 +117,73 @@ case class IsVariantNull(child: Expression) extends UnaryExpression
     copy(child = newChild)
 }

+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Converts a nested input (array/map/struct) into a variant. Maps and structs are converted to variant objects, which are unordered, unlike SQL structs. Input maps can only have string keys.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(named_struct('a', 1, 'b', 2));
+       {"a":1,"b":2}
+      > SELECT _FUNC_(array(1, 2, 3));
+       [1,2,3]
+      > SELECT _FUNC_(array(named_struct('a', 1)));
+       [{"a":1}]
+      > SELECT _FUNC_(array(map("a", 2)));
+       [{"a":2}]
+  """,
+  since = "4.0.0",
+  group = "variant_funcs")
+// scalastyle:on line.size.limit
+case class ToVariantObject(child: Expression)
+    extends UnaryExpression
+    with NullIntolerant
+    with QueryErrorsBase {
+
+  override val dataType: DataType = VariantType
+
+  // Only nested types (array/map/struct) are accepted at the root, but any castable type may be
+  // nested inside.
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val checkResult: Boolean = child.dataType match {
+      case _: StructType | _: ArrayType | _: MapType =>
+        VariantGet.checkDataType(child.dataType, allowStructsAndMaps = true)
+      case _ => false
+    }
+    if (!checkResult) {
+      DataTypeMismatch(
+        errorSubClass = "CAST_WITHOUT_SUGGESTION",
+        messageParameters =
+          Map("srcType" -> toSQLType(child.dataType), "targetType" -> toSQLType(VariantType)))
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def prettyName: String = "to_variant_object"
+
+  override protected def withNewChildInternal(newChild: Expression): ToVariantObject =
+    copy(child = newChild)
+
+  protected override def nullSafeEval(input: Any): Any =
+    VariantExpressionEvalUtils.castToVariant(input, child.dataType)
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childCode = child.genCode(ctx)
+    val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$")
+    val fromArg = ctx.addReferenceObj("from", child.dataType)
+    val javaType = JavaCode.javaType(VariantType)
+    val code =
+      code"""
+        ${childCode.code}
+        boolean ${ev.isNull} = ${childCode.isNull};
+        $javaType ${ev.value} = ${CodeGenerator.defaultValue(VariantType)};
+        if (!${childCode.isNull}) {
+          ${ev.value} = $cls.castToVariant(${childCode.value}, $fromArg);
+        }
+      """
+    ev.copy(code = code)
+  }
+}
+
 object VariantPathParser extends RegexParsers {
   // A path segment in the `VariantGet` expression represents either an object key access or an
   // array index access.
@@ -260,13 +327,16 @@ case object VariantGet {
    * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset
    * of them. For nested types, we reject map types with a non-string key type.
    */
-  def checkDataType(dataType: DataType): Boolean = dataType match {
+  def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean =
+    dataType match {
       case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType |
           VariantType | _: DayTimeIntervalType | _: YearMonthIntervalType =>
         true
-    case ArrayType(elementType, _) => checkDataType(elementType)
-    case MapType(_: StringType, valueType, _) => checkDataType(valueType)
-    case StructType(fields) => fields.forall(f => checkDataType(f.dataType))
+      case ArrayType(elementType, _) => checkDataType(elementType, allowStructsAndMaps)
+      case MapType(_: StringType, valueType, _) if allowStructsAndMaps =>
+        checkDataType(valueType, allowStructsAndMaps)
+      case StructType(fields) if allowStructsAndMaps =>
+        fields.forall(f => checkDataType(f.dataType, allowStructsAndMaps))
       case _ => false
   }

@@ -635,7 +705,7 @@ object VariantExplode {
      > SELECT _FUNC_(parse_json('null'));
       VOID
      > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
-       ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
+       ARRAY<OBJECT<a: BIGINT, b: BOOLEAN>>
  """,
  since = "4.0.0",
  group = "variant_funcs"
)
@@ -666,7 +736,24 @@ object SchemaOfVariant {
  /** The actual implementation of the `SchemaOfVariant` expression. */
  def schemaOfVariant(input: VariantVal): UTF8String = {
    val v = new Variant(input.getValue, input.getMetadata)
-    UTF8String.fromString(schemaOf(v).sql)
+    UTF8String.fromString(printSchema(schemaOf(v)))
+  }
+
+  /**
+   * Similar to `dataType.sql`. The only difference is that `StructType` is shown as
+   * `OBJECT<...>` rather than `STRUCT<...>`.
+   * SchemaOfVariant expressions use the Struct DataType to denote the Object type in the variant
+   * spec. However, the Object type is not equivalent to the struct type as an Object represents an
+   * unordered bag of key-value pairs while the Struct type is ordered.
+   */
+  def printSchema(dataType: DataType): String = dataType match {
+    case StructType(fields) =>
+      def printField(f: StructField): String =
+        s"${QuotingUtils.quoteIfNeeded(f.name)}: ${printSchema(f.dataType)}"
+
+      s"OBJECT<${fields.map(printField).mkString(", ")}>"
+    case ArrayType(elementType, _) => s"ARRAY<${printSchema(elementType)}>"
+    case _ => dataType.sql
  }

  /**
@@ -731,7 +818,7 @@ object SchemaOfVariant {
      > SELECT _FUNC_(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j);
       BIGINT
      > SELECT _FUNC_(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j);
-       STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>
+       OBJECT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>
  """,
  since = "4.0.0",
  group = "variant_funcs")
@@ -767,7 +854,8 @@ case class SchemaOfVariantAgg(
  override def merge(buffer: DataType, input: DataType): DataType =
    SchemaOfVariant.mergeSchema(buffer, input)

-  override def eval(buffer: DataType): Any = UTF8String.fromString(buffer.sql)
+  override def eval(buffer: DataType): Any =
+    UTF8String.fromString(SchemaOfVariant.printSchema(buffer))

  override def serialize(buffer: DataType): Array[Byte] = buffer.json.getBytes("UTF-8")

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index fb0bf63c0112..a2c22b9a0c1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -949,12 +949,24 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     )
   }

-  test("cast to variant") {
-    def check[T : TypeTag](input: T, expectedJson: String): Unit = {
-      val cast = Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI)
-      checkEvaluation(StructsToJson(Map.empty, cast), expectedJson)
+  test("cast to variant/to_variant_object") {
+    def check[T : TypeTag](input: T, expectedJson: String,
+        toVariantObject: Boolean = false): Unit = {
+      val expr =
+        if (toVariantObject) ToVariantObject(Literal.create(input))
+        else Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI)
+      checkEvaluation(StructsToJson(Map.empty, expr), expectedJson)
     }

+    def checkFailure[T: TypeTag](input: T, toVariantObject: Boolean = false): Unit = {
+      val expr =
+        if (toVariantObject) ToVariantObject(Literal.create(input))
+        else Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI)
+      val resolvedExpr = ResolveTimeZone.resolveTimeZones(expr)
+      assert(!resolvedExpr.resolved)
+    }
+
+    // cast to variant - success cases
     check(null.asInstanceOf[String], null)
     // The following tests cover all allowed scalar types.
    for (input <- Seq[Any](false, true, 0.toByte, 1.toShort, 2, 3L, 4.0F, 5.0D)) {
@@ -1023,17 +1035,52 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     }

     check(Array(null, "a", "b", "c"), """[null,"a","b","c"]""")
-    check(Map("z" -> 1, "y" -> 2, "x" -> 3), """{"x":3,"y":2,"z":1}""")
     check(Array(parseJson("""{"a": 1,"b": [1, 2, 3]}"""),
       parseJson("""{"c": true,"d": {"e": "str"}}""")),
       """[{"a":1,"b":[1,2,3]},{"c":true,"d":{"e":"str"}}]""")
-    val struct = Literal.create(
+
+    // cast to variant - failure cases - struct and map types
+    val mp = Map("z" -> 1, "y" -> 2, "x" -> 3)
+    val arrayMp = Array(Map("z" -> 1, "y" -> 2, "x" -> 3))
+    val arrayArrayMp = Array(Array(Map("z" -> 1, "y" -> 2, "x" -> 3)))
+    checkFailure(mp)
+    checkFailure(arrayMp)
+    checkFailure(arrayArrayMp)
+    val struct = Literal.create(create_row(1),
+      StructType(Array(StructField("a", IntegerType))))
+    checkFailure(struct)
+    val arrayStruct = Literal.create(
+      Array(create_row(1)),
+      ArrayType(StructType(Array(StructField("a", IntegerType)))))
+    checkFailure(arrayStruct)
+
+    // to_variant_object - success cases - nested types
+    check(Array(1, 2, 3), "[1,2,3]", toVariantObject = true)
+    check(mp, """{"x":3,"y":2,"z":1}""", toVariantObject = true)
+    check(arrayMp, """[{"x":3,"y":2,"z":1}]""", toVariantObject = true)
+    check(arrayArrayMp, """[[{"x":3,"y":2,"z":1}]]""", toVariantObject = true)
+    check(struct, """{"a":1}""", toVariantObject = true)
+    check(arrayStruct, """[{"a":1}]""", toVariantObject = true)
+    val complexStruct = Literal.create(
       Row(
         Seq("123", "true", "f"),
         Map("a" -> "123", "b" -> "true", "c" -> "f"),
+        Map("a" -> Row(132)),
         Row(0)),
-      StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: INT>"))
-    check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""")
+      StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,d MAP<STRING, STRUCT<i: INT>>," +
+        "a STRUCT<i: INT>"))
+    check(complexStruct,
+      """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"],""" +
+        """"d":{"a":{"i":132}}}""",
+      toVariantObject = true)
+    check(ymArrLit, """["INTERVAL '0' MONTH","INTERVAL""" +
+      """ '2147483647' MONTH","INTERVAL '-2147483647' MONTH"]""", toVariantObject = true)
+
+    // to_variant_object - failure cases - non-nested types or map with non-string key
+    checkFailure(1, toVariantObject = true)
+    checkFailure(true, toVariantObject = true)
+    checkFailure(Literal.create(Literal.create(Period.ofMonths(0))), toVariantObject = true)
+    checkFailure(Map(1 -> 1), toVariantObject = true)
   }

   test("schema_of_variant - unknown type") {
@@ -1092,7 +1139,7 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     val results = mutable.HashMap.empty[(Literal, Literal), String]
     for (i <- inputs) {
-      val inputType = if (i.value == null) "VOID" else i.dataType.sql
+      val inputType = if (i.value == null) "VOID" else SchemaOfVariant.printSchema(i.dataType)
       results.put((nul, i), inputType)
       results.put((i, i), inputType)
     }
@@ -1106,14 +1153,24 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     results.put((timestamp, timestampNtz), "TIMESTAMP")
     results.put((float, decimal), "DOUBLE")
     results.put((array1, array2), "ARRAY")
-    results.put((struct1, struct2), "STRUCT")
+    results.put((struct1, struct2), "OBJECT")
     results.put((dtInterval1, dtInterval2), "INTERVAL DAY TO SECOND")
     results.put((ymInterval1, ymInterval2), "INTERVAL YEAR TO MONTH")
     for (i1 <- inputs) {
       for (i2 <- inputs) {
         val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), "VARIANT"))
-        val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, VariantType)))
+        val elem1 =
+          if (i1.dataType.isInstanceOf[ArrayType] || i1.dataType.isInstanceOf[MapType] ||
+            i1.dataType.isInstanceOf[StructType]) {
+            ToVariantObject(i1)
+          } else Cast(i1, VariantType)
+        val elem2 =
+          if (i2.dataType.isInstanceOf[ArrayType] || i2.dataType.isInstanceOf[MapType] ||
+            i2.dataType.isInstanceOf[StructType]) {
+            ToVariantObject(i2)
+          } else Cast(i2, VariantType)
+        val array = CreateArray(Seq(elem1, elem2))
         checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, s"ARRAY<$expected>")
       }
     }
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 316e5e967672..f53b3874e6b8 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -451,6 +451,7 @@
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJsonExpressionBuilder | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> |
 | org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariantAgg | schema_of_variant_agg | SELECT schema_of_variant_agg(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j) | struct<schema_of_variant_agg(parse_json(j)):string> |
+| org.apache.spark.sql.catalyst.expressions.variant.ToVariantObject | to_variant_object | SELECT to_variant_object(named_struct('a', 1, 'b', 2)) | struct<to_variant_object(named_struct(a, 1, b, 2)):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryParseJsonExpressionBuilder | try_parse_json | SELECT try_parse_json('{"a":1,"b":0.8}') | struct<try_parse_json({"a":1,"b":0.8}):variant> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 7d0f6c401c0d..d31a281b7cdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -1576,9 +1576,9 @@ class CollationSQLExpressionsSuite
       SchemaOfVariantTestCase("null", "UTF8_BINARY", "VOID"),
       SchemaOfVariantTestCase("[]", "UTF8_LCASE", "ARRAY<VOID>"),
       SchemaOfVariantTestCase("[{\"a\":true,\"b\":0}]", "UNICODE",
-        "ARRAY<STRUCT<a: BOOLEAN, b: BIGINT>>"),
+        "ARRAY<OBJECT<a: BOOLEAN, b: BIGINT>>"),
       SchemaOfVariantTestCase("[{\"A\":\"x\",\"B\":-1.00}]", "UNICODE_CI",
-        "ARRAY<STRUCT<A: STRING, B: DECIMAL(3,2)>>")
+        "ARRAY<OBJECT<A: STRING, B: DECIMAL(3,2)>>")
     )

     // Supported collations
@@ -1607,9 +1607,9 @@ class CollationSQLExpressionsSuite
       SchemaOfVariantAggTestCase("('1'), ('2'), ('3')", "UTF8_BINARY", "BIGINT"),
       SchemaOfVariantAggTestCase("('true'), ('false'), ('true')", "UTF8_LCASE", "BOOLEAN"),
       SchemaOfVariantAggTestCase("('{\"a\": 1}'), ('{\"b\": true}'), ('{\"c\": 1.23}')",
-        "UNICODE", "STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>"),
+        "UNICODE", "OBJECT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)>"),
       SchemaOfVariantAggTestCase("('{\"A\": \"x\"}'), ('{\"B\": 9.99}'), ('{\"C\": 0}')",
-        "UNICODE_CI", "STRUCT<A: STRING, B: DECIMAL(3,2), C: BIGINT>")
+        "UNICODE_CI", "OBJECT<A: STRING, B: DECIMAL(3,2), C: BIGINT>")
     )

     // Supported collations
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index 4a20ec4af7e6..3224baf42f3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -16,7 +16,10 @@
  */
 package org.apache.spark.sql

+import org.apache.spark.sql.QueryTest.sameRows
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
+import org.apache.spark.sql.catalyst.expressions.variant.{ToVariantObject, VariantExpressionEvalUtils}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
@@ -158,6 +161,34 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     checkAnswer(variantDF, Seq(Row(expected)))
   }

+  test("to_variant_object - Codegen Support") {
+    Seq("CODEGEN_ONLY", "NO_CODEGEN").foreach { codegenMode =>
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+        val schema = StructType(Array(
+          StructField("v", StructType(Array(StructField("a", IntegerType))))
+        ))
+        val data = Seq(Row(Row(1)), Row(Row(2)), Row(Row(3)), Row(null))
+        val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+        val variantDF = df.select(to_variant_object(col("v")))
+        val plan = variantDF.queryExecution.executedPlan
+        assert(plan.isInstanceOf[WholeStageCodegenExec] == (codegenMode == "CODEGEN_ONLY"))
+        val v1 = VariantExpressionEvalUtils.castToVariant(InternalRow(1),
+          StructType(Array(StructField("a", IntegerType))))
+        val v2 = VariantExpressionEvalUtils.castToVariant(InternalRow(2),
+          StructType(Array(StructField("a", IntegerType))))
+        val v3 = VariantExpressionEvalUtils.castToVariant(InternalRow(3),
+          StructType(Array(StructField("a", IntegerType))))
+        val v4 = VariantExpressionEvalUtils.castToVariant(null,
+          StructType(Array(StructField("a", IntegerType))))
+        val expected = Seq(Row(new VariantVal(v1.getValue, v1.getMetadata)),
+          Row(new VariantVal(v2.getValue, v2.getMetadata)),
+          Row(new VariantVal(v3.getValue, v3.getMetadata)),
+          Row(new VariantVal(v4.getValue, v4.getMetadata)))
+        sameRows(variantDF.collect().toSeq, expected)
+      }
+    }
+  }
+
   test("schema_of_variant") {
     def check(json: String, expected: String): Unit = {
       val df = Seq(json).toDF("j").selectExpr("schema_of_variant(parse_json(j))")
@@ -181,8 +212,8 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     check("1E0", "DOUBLE")
     check("true", "BOOLEAN")
     check("\"2000-01-01\"", "STRING")
-    check("""{"a":0}""", "STRUCT<a: BIGINT>")
+    check("""{"a":0}""", "OBJECT<a: BIGINT>")
-    check("""{"b": {"c": "c"}, "a":["a"]}""", "STRUCT<a: ARRAY<STRING>, b: STRUCT<c: STRING>>")
+    check("""{"b": {"c": "c"}, "a":["a"]}""", "OBJECT<a: ARRAY<STRING>, b: OBJECT<c: STRING>>")
     check("[]", "ARRAY<VOID>")
     check("[false]", "ARRAY<BOOLEAN>")
     check("[null, 1, 1.0]", "ARRAY<DECIMAL(21,1)>")
@@ -192,11 +223,11 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     check("[1.1, 11111111111111111111111111111111111111]", "ARRAY<DOUBLE>")
     check("[1, \"1\"]", "ARRAY<VARIANT>")
     check("[{}, true]", "ARRAY<VARIANT>")
-    check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: VOID, b: BIGINT, c: STRING>>")
+    check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY<OBJECT<a: VOID, b: BIGINT, c: STRING>>")
-    check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: STRING, b: BIGINT>>")
+    check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY<OBJECT<a: STRING, b: BIGINT>>")
     check(
       """[{"a": 1, "b": null}, {"b": true, "a": 1E0}]""",
-      "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>"
+      "ARRAY<OBJECT<a: DOUBLE, b: BOOLEAN>>"
     )
   }

@@ -233,7 +264,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     // Literal input.
    checkAnswer(
      sql("""SELECT schema_of_variant_agg(parse_json('{"a": [1, 2, 3]}'))"""),
-      Seq(Row("STRUCT<a: ARRAY<BIGINT>>")))
+      Seq(Row("OBJECT<a: ARRAY<BIGINT>>")))

     // Non-grouping aggregation.
     def checkNonGrouping(input: Seq[String], expected: String): Unit = {
       checkAnswer(input.toDF("json").selectExpr("schema_of_variant_agg(parse_json(json))"),
         Seq(Row(expected)))
     }

-    checkNonGrouping(Seq("""{"a": [1, 2, 3]}"""), "STRUCT<a: ARRAY<BIGINT>>")
-    checkNonGrouping((0 to 100).map(i => s"""{"a": [$i]}"""), "STRUCT<a: ARRAY<BIGINT>>")
-    checkNonGrouping(Seq("""[{"a": 1}, {"b": 2}]"""), "ARRAY<STRUCT<a: BIGINT, b: BIGINT>>")
-    checkNonGrouping(Seq("""{"a": [1, 2, 3]}""", """{"a": "banana"}"""), "STRUCT<a: VARIANT>")
+    checkNonGrouping(Seq("""{"a": [1, 2, 3]}"""), "OBJECT<a: ARRAY<BIGINT>>")
+    checkNonGrouping((0 to 100).map(i => s"""{"a": [$i]}"""), "OBJECT<a: ARRAY<BIGINT>>")
+    checkNonGrouping(Seq("""[{"a": 1}, {"b": 2}]"""), "ARRAY<OBJECT<a: BIGINT, b: BIGINT>>")
+    checkNonGrouping(Seq("""{"a": [1, 2, 3]}""", """{"a": "banana"}"""), "OBJECT<a: VARIANT>")
     checkNonGrouping(Seq("""{"a": "banana"}""", """{"b": "apple"}"""),
-      "STRUCT<a: STRING, b: STRING>")
-    checkNonGrouping(Seq("""{"a": "data"}""", null), "STRUCT<a: STRING>")
+      "OBJECT<a: STRING, b: STRING>")
+    checkNonGrouping(Seq("""{"a": "data"}""", null), "OBJECT<a: STRING>")
     checkNonGrouping(Seq(null, null), "VOID")
-    checkNonGrouping(Seq("""{"a": null}""", """{"a": null}"""), "STRUCT<a: VOID>")
+    checkNonGrouping(Seq("""{"a": null}""", """{"a": null}"""), "OBJECT<a: VOID>")
     checkNonGrouping(Seq(
       """{"hi":[]}""",
       """{"hi":[{},{}]}""",
       """{"hi":[{"it's":[{"me":[{"a": 1}]}]}]}"""),
-      "STRUCT<hi: ARRAY<STRUCT<`it's`: ARRAY<STRUCT<me: ARRAY<STRUCT<a: BIGINT>>>>>>>>")
+      "OBJECT<hi: ARRAY<OBJECT<`it's`: ARRAY<OBJECT<me: ARRAY<OBJECT<a: BIGINT>>>>>>>>")

     // Grouping aggregation.
     withView("v") {
       (id, json)
       }.toDF("id", "json").createTempView("v")
       checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 2"),
-        Seq(Row("STRUCT>"), Row("STRUCT>")))
+        Seq(Row("OBJECT>"), Row("OBJECT>")))
       checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 3"),
-        Seq.fill(3)(Row("STRUCT>")))
+        Seq.fill(3)(Row("OBJECT>")))
       checkAnswer(sql("select schema_of_variant_agg(parse_json(json)) from v group by id % 4"),
-        Seq.fill(3)(Row("STRUCT>")) ++ Seq(Row("STRUCT>")))
+        Seq.fill(3)(Row("OBJECT>")) ++ Seq(Row("OBJECT>")))
     }
   }

@@ -279,22 +310,33 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     dataVector.appendLong(456)
     val array = new ColumnarArray(dataVector, 0, 4)
     val variant = Cast(Literal(array, ArrayType(LongType)), VariantType).eval()
+    val variant2 = ToVariantObject(Literal(array, ArrayType(LongType))).eval()
     assert(variant.toString == "[null,123,null,456]")
+    assert(variant2.toString == "[null,123,null,456]")
     dataVector.close()
   }

-  test("cast to variant with scan input") {
-    withTempPath { dir =>
-      val path = dir.getAbsolutePath
-      val input = Seq(Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str")))
-      val schema = StructType.fromDDL(
-        "a array<int>, m map<string, boolean>, s struct<i1 int, i2 string>")
-      spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path)
-      val df = spark.read.parquet(path).selectExpr(
-        s"cast(cast(a as variant) as ${schema(0).dataType.sql})",
-        s"cast(cast(m as variant) as ${schema(1).dataType.sql})",
-        s"cast(cast(s as variant) as ${schema(2).dataType.sql})")
-      checkAnswer(df, input)
+  test("cast to variant/to_variant_object with scan input") {
+    Seq("NO_CODEGEN", "CODEGEN_ONLY").foreach { codegenMode =>
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
+        withTempPath { dir =>
+          val path = dir.getAbsolutePath
+          val input = Seq(
+            Row(Array(1, null), Map("k1" -> null, "k2" -> false), Row(null, "str")),
+            Row(null, null, null)
+          )
+          val schema = StructType.fromDDL(
+            "a array<int>, m map<string, boolean>, s struct<i1 int, i2 string>")
+          spark.createDataFrame(spark.sparkContext.parallelize(input), schema).write.parquet(path)
+          val df = spark.read.parquet(path).selectExpr(
+            s"cast(cast(a as variant) as ${schema(0).dataType.sql})",
+            s"cast(to_variant_object(m) as ${schema(1).dataType.sql})",
+            s"cast(to_variant_object(s) as ${schema(2).dataType.sql})")
+          checkAnswer(df, input)
+          val plan = df.queryExecution.executedPlan
+          assert(plan.isInstanceOf[WholeStageCodegenExec] == (codegenMode == "CODEGEN_ONLY"))
+        }
+      }
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 0c8b0b501951..7ef30894c71a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -87,8 +87,8 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     def rows(results: Any*): Seq[Row] = results.map(Row(_))

     checkAnswer(df.select(is_variant_null(v)), rows(false, false))
-    checkAnswer(df.select(schema_of_variant(v)), rows("STRUCT<a: BIGINT>", "STRUCT<b: BIGINT>"))
-    checkAnswer(df.select(schema_of_variant_agg(v)), rows("STRUCT<a: BIGINT, b: BIGINT>"))
+    checkAnswer(df.select(schema_of_variant(v)), rows("OBJECT<a: BIGINT>", "OBJECT<b: BIGINT>"))
+    checkAnswer(df.select(schema_of_variant_agg(v)), rows("OBJECT<a: BIGINT, b: BIGINT>"))
     checkAnswer(df.select(variant_get(v, "$.a", "int")), rows(1, null))
     checkAnswer(df.select(variant_get(v, "$.b", "int")), rows(null, 2))

@@ -806,4 +806,11 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     checkSize(structResult.getAs[VariantVal](0), 5, 10, 5, 10)
     checkSize(structResult.getAs[VariantVal](1), 2, 4, 2, 4)
   }
+
+  test("schema_of_variant(object)") {
+    for (expr <- Seq("schema_of_variant", "schema_of_variant_agg")) {
+      val q = s"""select $expr(parse_json('{"STRUCT": {"!special!": true}}'))"""
+      checkAnswer(sql(q), Row("""OBJECT<STRUCT: OBJECT<`!special!`: BOOLEAN>>"""))
+    }
+  }
 }
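
Taken together, the suites above pin down the new rendering contract: `schema_of_variant` and `schema_of_variant_agg` print variant objects as `OBJECT<...>`, reserving `STRUCT<...>` for ordered SQL structs, and backtick-quote field names that are not plain identifiers. A small spark-shell sketch of that behavior (assuming a Spark 4.0.0 build with this patch applied; the sample JSON is illustrative):

// Assumes a spark-shell running a build with this patch applied.
// parse_json produces a VARIANT; schema_of_variant reports its merged schema,
// rendering the unordered variant object type as OBJECT<...>.
val df = spark.sql(
  """SELECT schema_of_variant(parse_json('{"a": 1, "b": [true]}')) AS s""")
df.show(truncate = false)
// Expected, per the tests above: OBJECT<a: BIGINT, b: ARRAY<BOOLEAN>>
// A field name like "!special!" would render backtick-quoted, as exercised in
// the new VariantSuite test.
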