-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-49443][SQL][PYTHON] Implement to_variant_object expression and make schema_of_variant expressions print OBJECT for for Variant Objects #47907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7d817d2
0f56cc9
1e585a0
015c429
0d8dc2b
7b93564
692b4fa
4bea027
6cd2644
5c6da07
aedb70c
a0bc9cc
591c234
3d35da5
8093084
de045d2
1eb3c20
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -553,6 +553,7 @@ VARIANT Functions | |
| try_variant_get | ||
| variant_get | ||
| try_parse_json | ||
| to_variant_object | ||
|
|
||
|
|
||
| XML Functions | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6851,6 +6851,18 @@ object functions { | |
| */ | ||
| def parse_json(json: Column): Column = Column.fn("parse_json", json) | ||
|
|
||
| /** | ||
| * Converts a column containing nested inputs (array/map/struct) into a variants where maps and | ||
| * structs are converted to variant objects which are unordered unlike SQL structs. Input maps can | ||
| * only have string keys. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the input has no array/map/struct, this function is noop?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nvm, the type check will fail for it. |
||
| * | ||
| * @param col | ||
| * a column with a nested schema or column name. | ||
| * @group variant_funcs | ||
| * @since 4.0.0 | ||
| */ | ||
| def to_variant_object(col: Column): Column = Column.fn("to_variant_object", col) | ||
|
|
||
| /** | ||
| * Check if a variant value is a variant null. Returns true if and only if the input is a | ||
| * variant null and false otherwise (including in the case of SQL NULL). | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke | |
| import org.apache.spark.sql.catalyst.json.JsonInferSchema | ||
| import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET} | ||
| import org.apache.spark.sql.catalyst.trees.UnaryLike | ||
| import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} | ||
| import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, QuotingUtils} | ||
| import org.apache.spark.sql.catalyst.util.DateTimeConstants._ | ||
| import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
@@ -117,6 +117,73 @@ case class IsVariantNull(child: Expression) extends UnaryExpression | |
| copy(child = newChild) | ||
| } | ||
|
|
||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(expr) - Convert a nested input (array/map/struct) into a variant where maps and structs are converted to variant objects which are unordered unlike SQL structs. Input maps can only have string keys.", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How do we define the element order in the resulting variant object? Random?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes and no. From a logical perspective, the keys should be thought of as random and the users should not assume anything about the order. |
||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); | ||
| {"a":1,"b":2} | ||
| > SELECT _FUNC_(array(1, 2, 3)); | ||
| [1,2,3] | ||
| > SELECT _FUNC_(array(named_struct('a', 1))); | ||
| [{"a":1}] | ||
| > SELECT _FUNC_(array(map("a", 2))); | ||
| [{"a":2}] | ||
| """, | ||
| since = "4.0.0", | ||
| group = "variant_funcs") | ||
| // scalastyle:on line.size.limit | ||
| case class ToVariantObject(child: Expression) | ||
harshmotw-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| extends UnaryExpression | ||
| with NullIntolerant | ||
| with QueryErrorsBase { | ||
|
|
||
| override val dataType: DataType = VariantType | ||
|
|
||
| // Only accept nested types at the root but any types can be nested inside. | ||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| val checkResult: Boolean = child.dataType match { | ||
| case _: StructType | _: ArrayType | _: MapType => | ||
| VariantGet.checkDataType(child.dataType, allowStructsAndMaps = true) | ||
| case _ => false | ||
| } | ||
| if (!checkResult) { | ||
| DataTypeMismatch( | ||
| errorSubClass = "CAST_WITHOUT_SUGGESTION", | ||
| messageParameters = | ||
| Map("srcType" -> toSQLType(child.dataType), "targetType" -> toSQLType(VariantType))) | ||
| } else { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } | ||
| } | ||
|
|
||
| override def prettyName: String = "to_variant_object" | ||
|
|
||
| override protected def withNewChildInternal(newChild: Expression): ToVariantObject = | ||
| copy(child = newChild) | ||
|
|
||
| protected override def nullSafeEval(input: Any): Any = | ||
| VariantExpressionEvalUtils.castToVariant(input, child.dataType) | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
harshmotw-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| val childCode = child.genCode(ctx) | ||
| val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$") | ||
| val fromArg = ctx.addReferenceObj("from", child.dataType) | ||
| val javaType = JavaCode.javaType(VariantType) | ||
| val code = | ||
| code""" | ||
| ${childCode.code} | ||
| boolean ${ev.isNull} = ${childCode.isNull}; | ||
| $javaType ${ev.value} = ${CodeGenerator.defaultValue(VariantType)}; | ||
| if (!${childCode.isNull}) { | ||
| ${ev.value} = $cls.castToVariant(${childCode.value}, $fromArg); | ||
| } | ||
| """ | ||
| ev.copy(code = code) | ||
| } | ||
| } | ||
|
|
||
| object VariantPathParser extends RegexParsers { | ||
| // A path segment in the `VariantGet` expression represents either an object key access or an | ||
| // array index access. | ||
|
|
@@ -260,13 +327,16 @@ case object VariantGet { | |
| * Returns whether a data type can be cast into/from variant. For scalar types, we allow a subset | ||
| * of them. For nested types, we reject map types with a non-string key type. | ||
| */ | ||
| def checkDataType(dataType: DataType): Boolean = dataType match { | ||
| def checkDataType(dataType: DataType, allowStructsAndMaps: Boolean = true): Boolean = | ||
| dataType match { | ||
| case _: NumericType | BooleanType | _: StringType | BinaryType | _: DatetimeType | | ||
| VariantType | _: DayTimeIntervalType | _: YearMonthIntervalType => | ||
| true | ||
| case ArrayType(elementType, _) => checkDataType(elementType) | ||
| case MapType(_: StringType, valueType, _) => checkDataType(valueType) | ||
| case StructType(fields) => fields.forall(f => checkDataType(f.dataType)) | ||
| case ArrayType(elementType, _) => checkDataType(elementType, allowStructsAndMaps) | ||
| case MapType(_: StringType, valueType, _) if allowStructsAndMaps => | ||
| checkDataType(valueType, allowStructsAndMaps) | ||
| case StructType(fields) if allowStructsAndMaps => | ||
| fields.forall(f => checkDataType(f.dataType, allowStructsAndMaps)) | ||
| case _ => false | ||
| } | ||
|
|
||
|
|
@@ -635,7 +705,7 @@ object VariantExplode { | |
| > SELECT _FUNC_(parse_json('null')); | ||
| VOID | ||
| > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]')); | ||
| ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>> | ||
| ARRAY<OBJECT<a: BIGINT, b: BOOLEAN>> | ||
| """, | ||
| since = "4.0.0", | ||
| group = "variant_funcs" | ||
|
|
@@ -666,7 +736,24 @@ object SchemaOfVariant { | |
| /** The actual implementation of the `SchemaOfVariant` expression. */ | ||
| def schemaOfVariant(input: VariantVal): UTF8String = { | ||
| val v = new Variant(input.getValue, input.getMetadata) | ||
| UTF8String.fromString(schemaOf(v).sql) | ||
| UTF8String.fromString(printSchema(schemaOf(v))) | ||
| } | ||
|
|
||
| /** | ||
| * Similar to `dataType.sql`. The only difference is that `StructType` is shown as | ||
| * `OBJECT<...>` rather than `STRUCT<...>`. | ||
harshmotw-db marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| * SchemaOfVariant expressions use the Struct DataType to denote the Object type in the variant | ||
| * spec. However, the Object type is not equivalent to the struct type as an Object represents an | ||
| * unordered bag of key-value pairs while the Struct type is ordered. | ||
| */ | ||
| def printSchema(dataType: DataType): String = dataType match { | ||
| case StructType(fields) => | ||
| def printField(f: StructField): String = | ||
| s"${QuotingUtils.quoteIfNeeded(f.name)}: ${printSchema(f.dataType)}" | ||
|
|
||
| s"OBJECT<${fields.map(printField).mkString(", ")}>" | ||
| case ArrayType(elementType, _) => s"ARRAY<${printSchema(elementType)}>" | ||
| case _ => dataType.sql | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -731,7 +818,7 @@ object SchemaOfVariant { | |
| > SELECT _FUNC_(parse_json(j)) FROM VALUES ('1'), ('2'), ('3') AS tab(j); | ||
| BIGINT | ||
| > SELECT _FUNC_(parse_json(j)) FROM VALUES ('{"a": 1}'), ('{"b": true}'), ('{"c": 1.23}') AS tab(j); | ||
| STRUCT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)> | ||
| OBJECT<a: BIGINT, b: BOOLEAN, c: DECIMAL(3,2)> | ||
| """, | ||
| since = "4.0.0", | ||
| group = "variant_funcs") | ||
|
|
@@ -767,7 +854,8 @@ case class SchemaOfVariantAgg( | |
| override def merge(buffer: DataType, input: DataType): DataType = | ||
| SchemaOfVariant.mergeSchema(buffer, input) | ||
|
|
||
| override def eval(buffer: DataType): Any = UTF8String.fromString(buffer.sql) | ||
| override def eval(buffer: DataType): Any = | ||
| UTF8String.fromString(SchemaOfVariant.printSchema(buffer)) | ||
|
|
||
| override def serialize(buffer: DataType): Array[Byte] = buffer.json.getBytes("UTF-8") | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.