diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 32fdb3e5faf2..11edce8140f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -25,6 +25,7 @@ import javax.xml.stream.events._
 import javax.xml.transform.stream.StreamSource
 import javax.xml.validation.Schema
 
+import scala.annotation.tailrec
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
 import scala.util.Try
@@ -35,7 +36,21 @@ import org.apache.spark.SparkUpgradeException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter}
+import org.apache.spark.sql.catalyst.util.{
+  ArrayBasedMapData,
+  BadRecordException,
+  CaseInsensitiveMap,
+  DateFormatter,
+  DropMalformedMode,
+  FailureSafeParser,
+  GenericArrayData,
+  MapData,
+  ParseMode,
+  PartialResultArrayException,
+  PartialResultException,
+  PermissiveMode,
+  TimestampFormatter
+}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -69,6 +84,7 @@ class StaxXmlParser(
 
   private val decimalParser = ExprUtils.getDecimalParser(options.locale)
 
+  private val caseSensitive = SQLConf.get.caseSensitiveAnalysis
 
   /**
    * Parses a single XML string and turns it into either one resulting row or no row (if the
@@ -85,7 +101,7 @@ class StaxXmlParser(
   }
 
   private def getFieldNameToIndex(schema: StructType): Map[String, Int] = {
-    if (SQLConf.get.caseSensitiveAnalysis) {
+    if (caseSensitive) {
       schema.map(_.name).zipWithIndex.toMap
     } else {
       CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap)
@@ -201,27 +217,30 @@ class StaxXmlParser(
       case (_: EndElement, _: DataType) => null
       case (c: Characters, ArrayType(st, _)) =>
         // For `ArrayType`, it needs to return the type of element. The values are merged later.
+        parser.next
        convertTo(c.getData, st)
      case (c: Characters, st: StructType) =>
-        // If a value tag is present, this can be an attribute-only element whose values is in that
-        // value tag field. Or, it can be a mixed-type element with both some character elements
-        // and other complex structure. Character elements are ignored.
-        val attributesOnly = st.fields.forall { f =>
-          f.name == options.valueTag || f.name.startsWith(options.attributePrefix)
-        }
-        if (attributesOnly) {
-          // If everything else is an attribute column, there's no complex structure.
-          // Just return the value of the character element, or null if we don't have a value tag
-          st.find(_.name == options.valueTag).map(
-            valueTag => convertTo(c.getData, valueTag.dataType)).orNull
-        } else {
-          // Otherwise, ignore this character element, and continue parsing the following complex
-          // structure
-          parser.next
-          parser.peek match {
-            case _: EndElement => null // no struct here at all; done
-            case _ => convertObject(parser, st)
-          }
+        parser.next
+        parser.peek match {
+          case _: EndElement =>
+            // This can't be an array of value tags, because the opening tag is
+            // immediately followed by a closing tag.
+            if (c.isWhiteSpace) {
+              return null
+            }
+            val indexOpt = getFieldNameToIndex(st).get(options.valueTag)
+            indexOpt match {
+              case Some(index) =>
+                convertTo(c.getData, st.fields(index).dataType)
+              case None => null
+            }
+          case _ =>
+            val row = convertObject(parser, st)
+            if (!c.isWhiteSpace) {
+              addOrUpdate(row.toSeq(st).toArray, st, options.valueTag, c.getData, addToTail = false)
+            } else {
+              row
+            }
         }
       case (_: Characters, _: StringType) =>
         convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType)
@@ -237,6 +256,7 @@ class StaxXmlParser(
         case _ => convertField(parser, dataType, attributes)
       }
     case (c: Characters, dt: DataType) =>
+      parser.next
       convertTo(c.getData, dt)
     case (e: XMLEvent, dt: DataType) =>
       throw new IllegalArgumentException(
@@ -262,7 +282,12 @@ class StaxXmlParser(
         case e: StartElement =>
           kvPairs +=
             (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) ->
-             convertField(parser, valueType))
+            convertField(parser, valueType))
+        case c: Characters if !c.isWhiteSpace =>
+          // Create a value tag field for the interspersed character data.
+          kvPairs +=
+            // TODO: We don't support arrays of value tags in maps yet.
+            (UTF8String.fromString(options.valueTag) -> convertTo(c.getData, valueType))
         case _: EndElement =>
           shouldStop = StaxXmlParserUtils.checkEndElement(parser)
         case _ => // do nothing
@@ -343,8 +368,9 @@ class StaxXmlParser(
     val row = new Array[Any](schema.length)
     val nameToIndex = getFieldNameToIndex(schema)
     // If there are attributes, then we process them first.
-    convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) =>
-      nameToIndex.get(f).foreach { row(_) = v }
+    convertAttributes(rootAttributes, schema).toSeq.foreach {
+      case (f, v) =>
+        nameToIndex.get(f).foreach { row(_) = v }
     }
 
     val wildcardColName = options.wildcardColName
@@ -405,15 +431,11 @@ class StaxXmlParser(
               badRecordException = badRecordException.orElse(Some(e))
           }
 
-        case c: Characters if !c.isWhiteSpace && isRootAttributesOnly =>
-          nameToIndex.get(options.valueTag) match {
-            case Some(index) =>
-              row(index) = convertTo(c.getData, schema(index).dataType)
-            case None => // do nothing
-          }
+        case c: Characters if !c.isWhiteSpace =>
+          addOrUpdate(row, schema, options.valueTag, c.getData)
 
         case _: EndElement =>
-          shouldStop = StaxXmlParserUtils.checkEndElement(parser)
+          shouldStop = parseAndCheckEndElement(row, schema, parser)
 
         case _ => // do nothing
       }
@@ -576,6 +598,54 @@ class StaxXmlParser(
       castTo(data, FloatType).asInstanceOf[Float]
     }
   }
+
+  /**
+   * Called after a child element has been consumed: folds any trailing character data
+   * into the value tag of `row` and reports whether the enclosing element has ended.
+   */
+  @tailrec
+  private def parseAndCheckEndElement(
+      row: Array[Any],
+      schema: StructType,
+      parser: XMLEventReader): Boolean = {
+    parser.peek match {
+      case _: EndElement | _: EndDocument => true
+      case _: StartElement => false
+      case c: Characters if !c.isWhiteSpace =>
+        parser.nextEvent()
+        addOrUpdate(row, schema, options.valueTag, c.getData)
+        parseAndCheckEndElement(row, schema, parser)
+      case _ =>
+        parser.nextEvent()
+        parseAndCheckEndElement(row, schema, parser)
+    }
+  }
+
+  /**
+   * Sets or merges `data` into the field `name` of `row`. For an array field the converted
+   * value is appended (or prepended when `addToTail` is false); other fields are overwritten.
+   */
+  private def addOrUpdate(
+      row: Array[Any],
+      schema: StructType,
+      name: String,
+      data: String,
+      addToTail: Boolean = true): InternalRow = {
+    schema.getFieldIndex(name) match {
+      case Some(index) =>
+        schema(index).dataType match {
+          case ArrayType(elementType, _) =>
+            val value = convertTo(data, elementType)
+            val result = if (row(index) == null) {
+              ArrayBuffer(value)
+            } else {
+              val genericArrayData = row(index).asInstanceOf[GenericArrayData]
+              if (addToTail) {
+                genericArrayData.toArray(elementType) :+ value
+              } else {
+                value +: genericArrayData.toArray(elementType)
+              }
+            }
+            row(index) = new GenericArrayData(result)
+          case dataType =>
+            row(index) = convertTo(data, dataType)
+        }
+      case None => // do nothing
+    }
+    InternalRow.fromSeq(row.toIndexedSeq)
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index de8ec33de0ce..9d0c16d95e46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -164,7 +164,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     }
   }
 
-  @tailrec
   private def inferField(parser: XMLEventReader): DataType = {
     parser.peek match {
       case _: EndElement => NullType
@@ -182,18 +181,25 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
         case _ => inferField(parser)
       }
     case c: Characters if !c.isWhiteSpace =>
-      // This could be the characters of a character-only element, or could have mixed
-      // characters and other complex structure
       val characterType = inferFrom(c.getData)
       parser.nextEvent()
       parser.peek match {
         case _: StartElement =>
-          // Some more elements follow; so ignore the characters.
-          // Use the schema of the rest
-          inferObject(parser).asInstanceOf[StructType]
+          // More elements follow, so this is a mix of character values and
+          // child elements. Merge the character type into the object's schema.
+          val innerType = inferObject(parser).asInstanceOf[StructType]
+          addOrUpdateValueTagType(innerType, characterType)
         case _ =>
-          // That's all, just the character-only body; use that as the type
-          characterType
+          val fieldType = inferField(parser)
+          fieldType match {
+            case st: StructType => addOrUpdateValueTagType(st, characterType)
+            case _: NullType => characterType
+            case _: DataType =>
+              // inferField never returns an array type here, so wrap the character
+              // type and the field type together under the value tag.
+              new StructType()
+                .add(options.valueTag, addOrUpdateType(Some(characterType), fieldType))
+          }
       }
     case e: XMLEvent =>
       throw new IllegalArgumentException(s"Failed to parse data with unexpected event $e")
@@ -229,17 +235,19 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     val nameToDataType =
       collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering)
 
-    def addOrUpdateType(fieldName: String, newType: DataType): Unit = {
-      val oldTypeOpt = nameToDataType.get(fieldName)
-      oldTypeOpt match {
-        // If the field name exists in the map,
-        // merge the type and infer the combined field as an array type if necessary
-        case Some(oldType) if !oldType.isInstanceOf[ArrayType] =>
-          nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType)))
-        case Some(oldType) =>
-          nameToDataType.update(fieldName, compatibleType(oldType, newType))
-        case None =>
-          nameToDataType.put(fieldName, newType)
+    // Counterpart of StaxXmlParser.parseAndCheckEndElement for schema inference:
+    // capture trailing character data as value tag types while deciding whether
+    // the enclosing element has ended.
+    @tailrec
+    def inferAndCheckEndElement(parser: XMLEventReader): Boolean = {
+      parser.peek match {
+        case _: EndElement | _: EndDocument => true
+        case _: StartElement => false
+        case c: Characters if !c.isWhiteSpace =>
+          val characterType = inferFrom(c.getData)
+          parser.nextEvent()
+          addOrUpdateType(nameToDataType, options.valueTag, characterType)
+          inferAndCheckEndElement(parser)
+        case _ =>
+          parser.nextEvent()
+          inferAndCheckEndElement(parser)
       }
     }
 
@@ -248,7 +256,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options)
 
     rootValuesMap.foreach {
       case (f, v) =>
-        addOrUpdateType(f, inferFrom(v))
+        addOrUpdateType(nameToDataType, f, inferFrom(v))
     }
     var shouldStop = false
     while (!shouldStop) {
@@ -281,29 +289,19 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
           }
           // Add the field and datatypes so that we can check if this is ArrayType.
           val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options)
-          addOrUpdateType(field, inferredType)
+          addOrUpdateType(nameToDataType, field, inferredType)
 
         case c: Characters if !c.isWhiteSpace =>
           // This can be an attribute-only object
           val valueTagType = inferFrom(c.getData)
-          addOrUpdateType(options.valueTag, valueTagType)
+          addOrUpdateType(nameToDataType, options.valueTag, valueTagType)
 
         case _: EndElement =>
-          shouldStop = StaxXmlParserUtils.checkEndElement(parser)
+          shouldStop = inferAndCheckEndElement(parser)
 
         case _ => // do nothing
       }
     }
-    // A structure object is an attribute-only element
-    // if it only consists of attributes and valueTags.
-    // If not, we will remove the valueTag field from the schema
-    val attributesOnly = nameToDataType.forall {
-      case (fieldName, _) =>
-        fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix)
-    }
-    if (!attributesOnly) {
-      nameToDataType -= options.valueTag
-    }
     // Note: other code relies on this sorting for correctness, so don't remove it!
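+    // The value tag field added above participates in this same sorted ordering,
+    // so it is emitted in name order along with the element and attribute fields.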
     StructType(nameToDataType.map{
@@ -534,4 +532,75 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       }
     }
   }
+
+  /**
+   * This helper function merges the data type of value tags and inner elements.
+   * The merged result can only be a struct type. Consider the following case:
+   *
+   * <a>
+   *   value1
+   *   <b>1</b>
+   *   value2
+   * </a>
+   *
+   * Input: ''a struct<b: long>'' and ''_VALUE string''
+   * Return: ''a struct<b: long, _VALUE: array<string>>''
+   * @param objectType inner elements' type
+   * @param valueTagType value tag's type
+   */
+  private[xml] def addOrUpdateValueTagType(
+      objectType: DataType,
+      valueTagType: DataType): DataType = {
+    (objectType, valueTagType) match {
+      case (st: StructType, _) =>
+        val valueTagIndexOpt = st.getFieldIndex(options.valueTag)
+
+        valueTagIndexOpt match {
+          // If the value tag field already exists among the inner elements,
+          // merge the types and infer the combined field as an array type if necessary
+          case Some(index) if !st(index).dataType.isInstanceOf[ArrayType] =>
+            updateStructField(
+              st,
+              index,
+              ArrayType(compatibleType(st(index).dataType, valueTagType)))
+          case Some(index) =>
+            updateStructField(st, index, compatibleType(st(index).dataType, valueTagType))
+          case None =>
+            st.add(options.valueTag, valueTagType)
+        }
+      case _ =>
+        throw new IllegalStateException(
+          "Illegal state when merging value tag types in schema inference"
+        )
+    }
+  }
+
+  private def updateStructField(
+      structType: StructType,
+      index: Int,
+      newType: DataType): StructType = {
+    val newFields: Array[StructField] =
+      structType.fields.updated(index, structType.fields(index).copy(dataType = newType))
+    StructType(newFields)
+  }
+
+  private def addOrUpdateType(
+      nameToDataType: collection.mutable.TreeMap[String, DataType],
+      fieldName: String,
+      newType: DataType): Unit = {
+    val oldTypeOpt = nameToDataType.get(fieldName)
+    val mergedType = addOrUpdateType(oldTypeOpt, newType)
+    nameToDataType.put(fieldName, mergedType)
+  }
+
+  private def addOrUpdateType(oldTypeOpt: Option[DataType], newType: DataType): DataType = {
+    oldTypeOpt match {
+      // If the field name already exists,
+      // merge the type and infer the combined field as an array type if necessary
+      case Some(oldType) if !oldType.isInstanceOf[ArrayType] && !newType.isInstanceOf[NullType] =>
+        ArrayType(compatibleType(oldType, newType))
+      case Some(oldType) =>
+        compatibleType(oldType, newType)
+      case None =>
+        newType
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index b3e8e3c79384..4b9a95856afb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import java.time.{Instant, LocalDateTime}
 import java.util.TimeZone
 
+import scala.collection.immutable.ArraySeq
 import scala.collection.mutable
 import scala.io.Source
 import scala.jdk.CollectionConverters._
@@ -1145,7 +1146,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       .option("inferSchema", true)
       .xml(getTestResourcePath(resDir + "mixed_children.xml"))
     val mixedRow = mixedDF.head()
-    assert(mixedRow.getAs[Row](0).toSeq === Seq(" lorem "))
+    assert(mixedRow.getAs[Row](0) === Row(List(" issue ", " text ignored "), " lorem "))
     assert(mixedRow.getString(1) === " ipsum ")
   }
 
@@ -1729,9 +1730,15 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val TAG_NAME = "tag"
     val VALUETAG_NAME = "_VALUE"
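+    // The value tag column now leads the schema below; interspersed character
+    // data is expected to land in that column.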
     val schema = buildSchema(
+      field(VALUETAG_NAME),
       field(ATTRIBUTE_NAME),
-      field(TAG_NAME, LongType),
-      field(VALUETAG_NAME))
+      field(TAG_NAME, LongType))
+    val expectedAns = Seq(
+      Row("value1", null, null),
+      Row("value2", "attr1", null),
+      Row("4", null, 5L),
+      Row("7", null, 6L),
+      Row(null, "8", null))
     val dfs = Seq(
       // user specified schema
       spark.read
@@ -1744,25 +1751,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
         .xml(getTestResourcePath(resDir + "root-level-value-none.xml"))
     )
     dfs.foreach { df =>
-      val result = df.collect()
-      assert(result.length === 5)
-      assert(result(0).get(0) == null && result(0).get(1) == null)
-      assert(
-        result(1).getAs[String](ATTRIBUTE_NAME) == "attr1"
-          && result(1).getAs[Any](TAG_NAME) == null
-      )
-      assert(
-        result(2).getAs[Long](TAG_NAME) == 5L
-          && result(2).getAs[Any](ATTRIBUTE_NAME) == null
-      )
-      assert(
-        result(3).getAs[Long](TAG_NAME) == 6L
-          && result(3).getAs[Any](ATTRIBUTE_NAME) == null
-      )
-      assert(
-        result(4).getAs[String](ATTRIBUTE_NAME) == "8"
-          && result(4).getAs[Any](TAG_NAME) == null
-      )
+      checkAnswer(df, expectedAns)
     }
   }
 
@@ -2371,4 +2360,248 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("capture values interspersed between elements - simple") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    value1
+        |    <a>
+        |        value2
+        |        <b>1</b>
+        |        value3
+        |    </a>
+        |    value4
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", true)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(df, Seq(Row(Array("value1", "value4"), Row(Array("value2", "value3"), 1))))
+  }
+
+  test("capture values interspersed between elements - array") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    value1
+        |    <array>
+        |        value2
+        |        <b>1</b>
+        |        value3
+        |    </array>
+        |    <array>
+        |        value4
+        |        <b>2</b>
+        |        value5
+        |        <c>3</c>
+        |        value6
+        |    </array>
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val expectedAns = Seq(
+      Row(
+        "value1",
+        Array(
+          Row(List("value2", "value3"), 1, null),
+          Row(List("value4", "value5", "value6"), 2, 3))))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", true)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(df, expectedAns)
+  }
+
+  test("capture values interspersed between elements - long and double") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    <a>
+        |        1
+        |        <b>2</b>
+        |        3
+        |        <b>4</b>
+        |        5.0
+        |    </a>
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", true)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(df, Seq(Row(Row(Array(1.0, 3.0, 5.0), Array(2, 4)))))
+  }
+
+  test("capture values interspersed between elements - comments") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    <a>1<!-- this is a comment --> 2</a>
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", true)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(df, Seq(Row(Row(Array(1, 2)))))
+  }
+
+  test("capture values interspersed between elements - whitespaces with quotes") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    <a>" "</a>
+        |    <b>" "<c>1</c></b>
+        |    <d><e attr=" "></e></d>
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", false)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(df, Seq(
+      Row("\" \"", Row(1, "\" \""), Row(Row(null, " ")))))
+  }
+
+  test("capture values interspersed between elements - nested comments") {
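+    // Comments should split adjacent character chunks, so each chunk below is
+    // captured as a separate element of the value-tag array.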
+    val xmlString =
+      s"""
+        |<ROW>
+        |    <a>1
+        |    <!-- this is a comment -->2
+        |    <b>1</b>
+        |    3
+        |    <b>2</b>
+        |    </a>
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", true)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(df, Seq(Row(Row(Array(1, 2, 3), Array(1, 2)))))
+  }
+
+  test("capture values interspersed between elements - nested struct") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    <a>
+        |        <b>
+        |            <c>1</c>
+        |            value1
+        |            <c>2</c>
+        |            value2
+        |            <d>3</d>
+        |        </b>
+        |        value4
+        |    </a>
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .option("ignoreSurroundingSpaces", true)
+      .option("multiLine", "true")
+      .xml(input)
+
+    checkAnswer(
+      df,
+      Seq(
+        Row(
+          Row(
+            "value4",
+            Row(
+              Array("value1", "value2"),
+              Array(1, 2),
+              3)))))
+  }
+
+  test("capture values interspersed between elements - deeply nested") {
+    val xmlString =
+      s"""
+        |<ROW>
+        |    value1
+        |    <a>
+        |        value2
+        |        <b>
+        |            value3
+        |            <c>
+        |                value4
+        |                <f>
+        |                    value5
+        |                    <g>1</g>
+        |                    value6
+        |                    <g>2</g>
+        |                    value7
+        |                </f>
+        |                value8
+        |                <e>string</e>
+        |                value9
+        |            </c>
+        |            value10
+        |            <c>
+        |                <f>
+        |                    <g>3</g>
+        |                    value11
+        |                    <g>4</g>
+        |                </f>
+        |                <e>string</e>
+        |                value12
+        |            </c>
+        |            value13
+        |            <d>3</d>
+        |            value14
+        |        </b>
+        |        value15
+        |    </a>
+        |    value16
+        |</ROW>
+        |""".stripMargin
+    val input = spark.createDataset(Seq(xmlString))
+    val df = spark.read
+      .option("ignoreSurroundingSpaces", true)
+      .option("rowTag", "ROW")
+      .option("multiLine", "true")
+      .xml(input)
+
+    val expectedAns = Seq(Row(
+      ArraySeq("value1", "value16"),
+      Row(
+        ArraySeq("value2", "value15"),
+        Row(
+          ArraySeq("value3", "value10", "value13", "value14"),
+          Array(
+            Row(
+              ArraySeq("value4", "value8", "value9"),
+              "string",
+              Row(ArraySeq("value5", "value6", "value7"), ArraySeq(1, 2))),
+            Row(
+              ArraySeq("value12"),
+              "string",
+              Row(ArraySeq("value11"), ArraySeq(3, 4)))),
+          3))))
+
+    checkAnswer(df, expectedAns)
+  }
 }