From 52bb404885ab18ce1400172d1d57b8ca7bc60d93 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 11 Dec 2023 15:48:38 -0800 Subject: [PATCH 01/19] init --- .../sql/catalyst/xml/StaxXmlParser.scala | 174 ++++++++++++------ .../sql/catalyst/xml/StaxXmlParserUtils.scala | 2 - .../sql/catalyst/xml/XmlInferSchema.scala | 97 ++++++++-- .../spark/sql/catalyst/xml/XmlOptions.scala | 2 +- .../test-data/xml-resources/values-array.xml | 18 ++ .../test-data/xml-resources/values-nested.xml | 15 ++ .../test-data/xml-resources/values-simple.xml | 11 ++ .../execution/datasources/xml/XmlSuite.scala | 85 ++++++--- 8 files changed, 302 insertions(+), 102 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-array.xml create mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-nested.xml create mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-simple.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 567074bbf126..d68a121d3e01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -24,18 +24,30 @@ import javax.xml.stream.{XMLEventReader, XMLStreamException} import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource import javax.xml.validation.Schema - import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try import scala.util.control.NonFatal import scala.xml.SAXException - import org.apache.spark.SparkUpgradeException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.ExprUtils -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException, CaseInsensitiveMap, DateFormatter, DropMalformedMode, FailureSafeParser, GenericArrayData, MapData, ParseMode, PartialResultArrayException, PartialResultException, PermissiveMode, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{ + ArrayBasedMapData, + BadRecordException, + CaseInsensitiveMap, + DateFormatter, + DropMalformedMode, + FailureSafeParser, + GenericArrayData, + MapData, + ParseMode, + PartialResultArrayException, + PartialResultException, + PermissiveMode, + TimestampFormatter +} import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream import org.apache.spark.sql.errors.QueryExecutionErrors @@ -43,25 +55,28 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -class StaxXmlParser( - schema: StructType, - val options: XmlOptions) extends Logging { +import scala.annotation.tailrec + +class StaxXmlParser(schema: StructType, val options: XmlOptions) extends Logging { private lazy val timestampFormatter = TimestampFormatter( options.timestampFormatInRead, options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - isParsing = true) + isParsing = true + ) private lazy val dateFormatter = DateFormatter( options.dateFormatInRead, options.locale, legacyFormat = FAST_DATE_FORMAT, - isParsing = true) + isParsing = true + ) private val decimalParser = ExprUtils.getDecimalParser(options.locale) + private val caseSensitive = SQLConf.get.caseSensitiveAnalysis /** * Parses a single XML string and turns it into either one resulting row or no row (if the @@ -69,8 +84,8 @@ class StaxXmlParser( */ val parse: String => Option[InternalRow] = { // This is intentionally a val to create a function once and reuse. - if (schema.isEmpty) { - (_: String) => Some(InternalRow.empty) + if (schema.isEmpty) { (_: String) => + Some(InternalRow.empty) } else { val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) (input: String) => doParseColumn(input, options.parseMode, xsdSchema) @@ -78,7 +93,7 @@ class StaxXmlParser( } private def getFieldNameToIndex(schema: StructType): Map[String, Int] = { - if (SQLConf.get.caseSensitiveAnalysis) { + if (caseSensitive) { schema.map(_.name).zipWithIndex.toMap } else { CaseInsensitiveMap(schema.map(_.name).zipWithIndex.toMap) @@ -149,18 +164,11 @@ class StaxXmlParser( |""".stripMargin + e.getMessage val wrappedCharException = new CharConversionException(msg) wrappedCharException.initCause(e) - throw BadRecordException(() => xmlRecord, () => Array.empty, - wrappedCharException) + throw BadRecordException(() => xmlRecord, () => Array.empty, wrappedCharException) case PartialResultException(row, cause) => - throw BadRecordException( - record = () => xmlRecord, - partialResults = () => Array(row), - cause) + throw BadRecordException(record = () => xmlRecord, partialResults = () => Array(row), cause) case PartialResultArrayException(rows, cause) => - throw BadRecordException( - record = () => xmlRecord, - partialResults = () => rows, - cause) + throw BadRecordException(record = () => xmlRecord, partialResults = () => rows, cause) } } @@ -192,30 +200,34 @@ class StaxXmlParser( case (_: EndElement, _: DataType) => null case (c: Characters, ArrayType(st, _)) => // For `ArrayType`, it needs to return the type of element. The values are merged later. + parser.next convertTo(c.getData, st) case (c: Characters, st: StructType) => - // If a value tag is present, this can be an attribute-only element whose values is in that - // value tag field. Or, it can be a mixed-type element with both some character elements - // and other complex structure. Character elements are ignored. - val attributesOnly = st.fields.forall { f => - f.name == options.valueTag || f.name.startsWith(options.attributePrefix) - } - if (attributesOnly) { - // If everything else is an attribute column, there's no complex structure. - // Just return the value of the character element, or null if we don't have a value tag - st.find(_.name == options.valueTag).map( - valueTag => convertTo(c.getData, valueTag.dataType)).orNull - } else { - // Otherwise, ignore this character element, and continue parsing the following complex - // structure - parser.next - parser.peek match { - case _: EndElement => null // no struct here at all; done - case _ => convertObject(parser, st) - } + parser.next + parser.peek match { + case _: EndElement => + // TODO: optimize it + // TODO: array of value tag + if (!isEmptyString(c)) { + val indexOpt = getFieldNameToIndex(st).get(options.valueTag) + indexOpt.map { index => + // TODO: optimize it + convertTo(c.getData, st.fields(index).dataType) + }.orNull + } else { + null + } + case _ => + val row = convertObject(parser, st) + if (!isEmptyString(c)) { + addOrUpdate(row.toSeq(st).toArray, st, options.valueTag, c.getData, addToTail = false) + } else { + row + } } case (_: Characters, _: StringType) => convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType) + // TODO: can we remove it? case (c: Characters, _: DataType) if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. @@ -228,10 +240,12 @@ class StaxXmlParser( case _ => convertField(parser, dataType, attributes) } case (c: Characters, dt: DataType) => + parser.next convertTo(c.getData, dt) case (e: XMLEvent, dt: DataType) => throw new IllegalArgumentException( - s"Failed to parse a value for data type $dt with event ${e.toString}") + s"Failed to parse a value for data type $dt with event ${e.toString}" + ) } } @@ -244,16 +258,21 @@ class StaxXmlParser( attributes: Array[Attribute]): MapData = { val kvPairs = ArrayBuffer.empty[(UTF8String, Any)] attributes.foreach { attr => - kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart) - -> convertTo(attr.getValue, valueType)) + kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart) + -> convertTo(attr.getValue, valueType)) } var shouldStop = false while (!shouldStop) { parser.nextEvent match { case e: StartElement => kvPairs += - (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> - convertField(parser, valueType)) + (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> + convertField(parser, valueType)) + case c: Characters if !isEmptyString(c) => + // Create a value tag field for it + kvPairs += + // TODO: potential mismatch? + (UTF8String.fromString(options.valueTag) -> convertTo(c.getData, valueType)) case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser) case _ => // do nothing @@ -334,8 +353,9 @@ class StaxXmlParser( val row = new Array[Any](schema.length) val nameToIndex = getFieldNameToIndex(schema) // If there are attributes, then we process them first. - convertAttributes(rootAttributes, schema).toSeq.foreach { case (f, v) => - nameToIndex.get(f).foreach { row(_) = v } + convertAttributes(rootAttributes, schema).toSeq.foreach { + case (f, v) => + nameToIndex.get(f).foreach { row(_) = v } } val wildcardColName = options.wildcardColName @@ -396,15 +416,11 @@ class StaxXmlParser( badRecordException = badRecordException.orElse(Some(e)) } - case c: Characters if !c.isWhiteSpace && isRootAttributesOnly => - nameToIndex.get(options.valueTag) match { - case Some(index) => - row(index) = convertTo(c.getData, schema(index).dataType) - case None => // do nothing - } + case c: Characters if !isEmptyString(c) => + addOrUpdate(row, schema, options.valueTag, c.getData) case _: EndElement => - shouldStop = StaxXmlParserUtils.checkEndElement(parser) + shouldStop = parseAndCheckEndElement(row, schema, parser) case _ => // do nothing } @@ -565,6 +581,56 @@ class StaxXmlParser( castTo(data, FloatType).asInstanceOf[Float] } } + private[xml] def isEmptyString(c: Characters): Boolean = c.getData.trim.isEmpty + + @tailrec + private def parseAndCheckEndElement( + row: Array[Any], + schema: StructType, + parser: XMLEventReader): Boolean = { + parser.peek match { + case _: EndElement | _: EndDocument => true + case _: StartElement => false + case c: Characters if !isEmptyString(c) => + parser.nextEvent() + addOrUpdate(row, schema, options.valueTag, c.getData) + parseAndCheckEndElement(row, schema, parser) + case _ => + parser.nextEvent() + parseAndCheckEndElement(row, schema, parser) + } + } + + private def addOrUpdate( + row: Array[Any], + schema: StructType, + name: String, + string: String, + addToTail: Boolean = true): InternalRow = { + schema.getFieldIndex(name) match { + case Some(index) => + schema(index).dataType match { + case arr @ ArrayType(elementType, _) => + val value = convertTo(string, elementType) + val result = if (row(index) == null) { + ArrayBuffer(value) + } else { + // TODO(shujing): optimization? + if (addToTail) { + row(index).asInstanceOf[GenericArrayData].toArray(elementType) :+ value + } else { + value +: row(index).asInstanceOf[GenericArrayData].toArray(elementType) + } + } + row(index) = new GenericArrayData(result) + case dataType => + row(index) = convertTo(string, dataType) + } + case None => // do nothing + } + // TODO(shujing): optimization? + InternalRow.fromSeq(row.toIndexedSeq) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala index 0471cb310d89..110514718187 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala @@ -21,7 +21,6 @@ import javax.xml.namespace.QName import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants} import javax.xml.stream.events._ -import scala.annotation.tailrec import scala.jdk.CollectionConverters._ object StaxXmlParserUtils { @@ -70,7 +69,6 @@ object StaxXmlParserUtils { /** * Checks if current event points the EndElement. */ - @tailrec def checkEndElement(parser: XMLEventReader): Boolean = { parser.peek match { case _: EndElement | _: EndDocument => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index f81c476cd385..b0aba1f2ea95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -22,13 +22,11 @@ import javax.xml.stream.XMLEventReader import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource import javax.xml.validation.Schema - import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.control.Exception._ import scala.util.control.NonFatal - import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.ExprUtils @@ -157,7 +155,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) parser.peek match { case _: EndElement => NullType case _: StartElement => inferObject(parser) - case c: Characters if c.isWhiteSpace => + case c: Characters if isEmptyString(c) => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. val data = c.getData @@ -169,16 +167,18 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) case _: EndElement => StringType case _ => inferField(parser) } - case c: Characters if !c.isWhiteSpace => + // what about new line character + case c: Characters if !isEmptyString(c) => // This could be the characters of a character-only element, or could have mixed // characters and other complex structure val characterType = inferFrom(c.getData) parser.nextEvent() parser.peek match { case _: StartElement => - // Some more elements follow; so ignore the characters. - // Use the schema of the rest - inferObject(parser).asInstanceOf[StructType] + // Some more elements follow; + // This is a mix of values and other elements + val innerType = inferObject(parser).asInstanceOf[StructType] + addOrUpdateValueTagType(innerType, characterType) case _ => // That's all, just the character-only body; use that as the type characterType @@ -231,6 +231,22 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } + @tailrec + def inferAndCheckEndElement(parser: XMLEventReader): Boolean = { + parser.peek match { + case _: EndElement | _: EndDocument => true + case _: StartElement => false + case c: Characters if !isEmptyString(c) => + val characterType = inferFrom(c.getData) + parser.nextEvent() + addOrUpdateType(options.valueTag, characterType) + inferAndCheckEndElement(parser) + case _ => + parser.nextEvent() + inferAndCheckEndElement(parser) + } + } + // If there are attributes, then we should process them first. val rootValuesMap = StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options) @@ -271,27 +287,17 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options) addOrUpdateType(field, inferredType) - case c: Characters if !c.isWhiteSpace => + case c: Characters if !isEmptyString(c) => // This can be an attribute-only object val valueTagType = inferFrom(c.getData) addOrUpdateType(options.valueTag, valueTagType) case _: EndElement => - shouldStop = StaxXmlParserUtils.checkEndElement(parser) + shouldStop = inferAndCheckEndElement(parser) case _ => // do nothing } } - // A structure object is an attribute-only element - // if it only consists of attributes and valueTags. - // If not, we will remove the valueTag field from the schema - val attributesOnly = nameToDataType.forall { - case (fieldName, _) => - fieldName == options.valueTag || fieldName.startsWith(options.attributePrefix) - } - if (!attributesOnly) { - nameToDataType -= options.valueTag - } // Note: other code relies on this sorting for correctness, so don't remove it! StructType(nameToDataType.map{ @@ -503,4 +509,57 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } } + + /** + * This helper function merges the data type of value tags and inner elements. + * It could only be structure data. Consider the following case, + * + * value1 + * 1 + * value2 + * + * Input: ''a struct'' and ''_VALUE string'' + * Return: ''a struct>'' + * @param objectType inner elements' type + * @param valueTagType value tag's type + */ + private[xml] def addOrUpdateValueTagType( + objectType: DataType, + valueTagType: DataType): DataType = { + (objectType, valueTagType) match { + case (st: StructType, _) => + // TODO(shujing): case sensitive? + val valueTagIndexOpt = st.getFieldIndex(options.valueTag) + + valueTagIndexOpt match { + // If the field name exists in the inner elements, + // merge the type and infer the combined field as an array type if necessary + case Some(index) if !st(index).dataType.isInstanceOf[ArrayType] => + updateStructField( + st, + index, + ArrayType(compatibleType(st(index).dataType, valueTagType))) + case Some(index) => + updateStructField(st, index, compatibleType(st(index).dataType, valueTagType)) + case None => + st.add(options.valueTag, valueTagType) + } + case _ => + throw new IllegalStateException( + "illegal state when merging value tags types in schema inference" + ) + } + } + + private[xml] def isEmptyString(c: Characters): Boolean = c.getData.trim.isEmpty + + private def updateStructField( + structType: StructType, + index: Int, + newType: DataType): StructType = { + val newFields: Array[StructField] = + structType.fields.updated(index, structType.fields(index).copy(dataType = newType)) + StructType(newFields) + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala index 92b156fb8f23..218d56c0f203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala @@ -95,7 +95,7 @@ class XmlOptions( val nullValue = parameters.getOrElse(NULL_VALUE, XmlOptions.DEFAULT_NULL_VALUE) val columnNameOfCorruptRecord = parameters.getOrElse(COLUMN_NAME_OF_CORRUPT_RECORD, defaultColumnNameOfCorruptRecord) - val ignoreSurroundingSpaces = getBool(IGNORE_SURROUNDING_SPACES, false) + val ignoreSurroundingSpaces = getBool(IGNORE_SURROUNDING_SPACES, true) val parseMode = ParseMode.fromString(parameters.getOrElse(MODE, PermissiveMode.name)) val inferSchema = getBool(INFER_SCHEMA, true) val rowValidationXSDPath = parameters.get(ROW_VALIDATION_XSD_PATH).orNull diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-array.xml b/sql/core/src/test/resources/test-data/xml-resources/values-array.xml new file mode 100644 index 000000000000..d53d2db31538 --- /dev/null +++ b/sql/core/src/test/resources/test-data/xml-resources/values-array.xml @@ -0,0 +1,18 @@ + + + + value1 + + value2 + 1 + value3 + + + value4 + 2 + value5 + 3 + value6 + + + diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-nested.xml b/sql/core/src/test/resources/test-data/xml-resources/values-nested.xml new file mode 100644 index 000000000000..f5d70daa76d9 --- /dev/null +++ b/sql/core/src/test/resources/test-data/xml-resources/values-nested.xml @@ -0,0 +1,15 @@ + + + + + + 1 + value1 + 2 + value2 + 3 + + value4 + + + diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml b/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml new file mode 100644 index 000000000000..469c38a25c36 --- /dev/null +++ b/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml @@ -0,0 +1,11 @@ + + + + value1 + + value2 + 1 + value3 + + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index ee970806632b..42f9937ade0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -731,7 +731,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .collect() assert(results(0) === Row("alice", "35")) - assert(results(1) === Row("bob", " ")) + assert(results(1) === Row("bob", "")) assert(results(2) === Row("coc", "24")) } @@ -817,7 +817,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { assert(result(0) === Row(Row(null))) assert(result(1) === Row(Row(Row(null, null)))) assert(result(2) === Row(Row(Row("E", null)))) - assert(result(3) === Row(Row(Row("E", " ")))) + assert(result(3) === Row(Row(Row("E", "")))) assert(result(4) === Row(Row(Row("E", "")))) } @@ -1145,8 +1145,8 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("inferSchema", true) .xml(getTestResourcePath(resDir + "mixed_children.xml")) val mixedRow = mixedDF.head() - assert(mixedRow.getAs[Row](0).toSeq === Seq(" lorem ")) - assert(mixedRow.getString(1) === " ipsum ") + assert(mixedRow.getAs[Row](0).toSeq === Seq(Array(), "lorem")) + assert(mixedRow.getString(1) === "ipsum") } test("test mixed text and complex element children") { @@ -1154,9 +1154,9 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("rowTag", "root") .option("inferSchema", true) .xml(getTestResourcePath(resDir + "mixed_children_2.xml")) - assert(mixedDF.select("foo.bar").head().getString(0) === " lorem ") + assert(mixedDF.select("foo.bar").head().getString(0) === "lorem") assert(mixedDF.select("foo.baz.bing").head().getLong(0) === 2) - assert(mixedDF.select("missing").head().getString(0) === " ipsum ") + assert(mixedDF.select("missing").head().getString(0) === "ipsum") } test("test XSD validation") { @@ -1720,7 +1720,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { assert(result(1).getAs[String]("_attr") == "attr1" && result(1).getAs[String]("_VALUE") == "value2") // comments aren't included in valueTag - assert(result(2).getAs[String]("_VALUE") == "\n value3\n ") + assert(result(2).getAs[String]("_VALUE") == "value3") } } @@ -1732,6 +1732,13 @@ class XmlSuite extends QueryTest with SharedSparkSession { field(ATTRIBUTE_NAME), field(TAG_NAME, LongType), field(VALUETAG_NAME)) + val expectedAns = Seq( + Row(null, null, "value1"), + Row("attr1", null, "value2"), + Row(null, 5L, "4"), + Row(null, 6L, "7"), + Row("8", null, null) + ) val dfs = Seq( // user specified schema spark.read @@ -1744,25 +1751,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .xml(getTestResourcePath(resDir + "root-level-value-none.xml")) ) dfs.foreach { df => - val result = df.collect() - assert(result.length === 5) - assert(result(0).get(0) == null && result(0).get(1) == null) - assert( - result(1).getAs[String](ATTRIBUTE_NAME) == "attr1" - && result(1).getAs[Any](TAG_NAME) == null - ) - assert( - result(2).getAs[Long](TAG_NAME) == 5L - && result(2).getAs[Any](ATTRIBUTE_NAME) == null - ) - assert( - result(3).getAs[Long](TAG_NAME) == 6L - && result(3).getAs[Any](ATTRIBUTE_NAME) == null - ) - assert( - result(4).getAs[String](ATTRIBUTE_NAME) == "8" - && result(4).getAs[Any](TAG_NAME) == null - ) + checkAnswer(df, expectedAns) } } @@ -2178,4 +2167,48 @@ class XmlSuite extends QueryTest with SharedSparkSession { ) testWriteReadRoundTrip(df, Map("nullValue" -> "null", "prefersDecimal" -> "true")) } + + test("capture values interspersed between elements - simple") { + val df = spark.read.format("xml") + .option("rowTag", "ROW") + .option("multiLine", "true") + .load(getTestResourcePath(resDir + "values-simple.xml")) + + checkAnswer(df, Seq(Row("value1", Row(Array("value2", "value3"), 1)))) + } + + test("capture values interspersed between elements - array") { + val expectedAns = Seq( + Row( + "value1", + Array( + Row(List("value2", "value3"), 1, null), + Row(List("value4", "value5", "value6"), 2, 3)))) + val df = spark.read + .format("xml") + .option("rowTag", "ROW") + .option("multiLine", "true") + .load(getTestResourcePath(resDir + "values-array.xml")) + + checkAnswer(df, expectedAns) + + } + + test("capture values interspersed between elements - nested struct") { + val df = spark.read + .format("xml") + .option("rowTag", "ROW") + .option("multiLine", "true") + .load(getTestResourcePath(resDir + "values-nested.xml")) + + checkAnswer( + df, + Seq( + Row( + "value4", + Row( + Array("value1", "value2"), + Array(1, 2), + 3)))) + } } From 4f636178508c74cdfd9593c873e4ca0ec299e119 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 11 Dec 2023 15:51:18 -0800 Subject: [PATCH 02/19] revert format --- .../sql/catalyst/xml/StaxXmlParser.scala | 40 +++++++++++-------- .../execution/datasources/xml/XmlSuite.scala | 3 +- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 79166da6ee8e..2d28822ac0d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -24,11 +24,13 @@ import javax.xml.stream.{XMLEventReader, XMLStreamException} import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource import javax.xml.validation.Schema + import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try import scala.util.control.NonFatal import scala.xml.SAXException + import org.apache.spark.SparkUpgradeException import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -57,22 +59,22 @@ import org.apache.spark.unsafe.types.UTF8String import scala.annotation.tailrec -class StaxXmlParser(schema: StructType, val options: XmlOptions) extends Logging { +class StaxXmlParser( + schema: StructType, + val options: XmlOptions) extends Logging { private lazy val timestampFormatter = TimestampFormatter( options.timestampFormatInRead, options.zoneId, options.locale, legacyFormat = FAST_DATE_FORMAT, - isParsing = true - ) + isParsing = true) private lazy val dateFormatter = DateFormatter( options.dateFormatInRead, options.locale, legacyFormat = FAST_DATE_FORMAT, - isParsing = true - ) + isParsing = true) private val decimalParser = ExprUtils.getDecimalParser(options.locale) @@ -84,8 +86,8 @@ class StaxXmlParser(schema: StructType, val options: XmlOptions) extends Logging */ val parse: String => Option[InternalRow] = { // This is intentionally a val to create a function once and reuse. - if (schema.isEmpty) { (_: String) => - Some(InternalRow.empty) + if (schema.isEmpty) { + (_: String) => Some(InternalRow.empty) } else { val xsdSchema = Option(options.rowValidationXSDPath).map(ValidatorUtil.getSchema) (input: String) => doParseColumn(input, options.parseMode, xsdSchema) @@ -166,11 +168,18 @@ class StaxXmlParser(schema: StructType, val options: XmlOptions) extends Logging |""".stripMargin + e.getMessage val wrappedCharException = new CharConversionException(msg) wrappedCharException.initCause(e) - throw BadRecordException(() => xmlRecord, () => Array.empty, wrappedCharException) + throw BadRecordException(() => xmlRecord, () => Array.empty, + wrappedCharException) case PartialResultException(row, cause) => - throw BadRecordException(record = () => xmlRecord, partialResults = () => Array(row), cause) + throw BadRecordException( + record = () => xmlRecord, + partialResults = () => Array(row), + cause) case PartialResultArrayException(rows, cause) => - throw BadRecordException(record = () => xmlRecord, partialResults = () => rows, cause) + throw BadRecordException( + record = () => xmlRecord, + partialResults = () => rows, + cause) } } @@ -246,8 +255,7 @@ class StaxXmlParser(schema: StructType, val options: XmlOptions) extends Logging convertTo(c.getData, dt) case (e: XMLEvent, dt: DataType) => throw new IllegalArgumentException( - s"Failed to parse a value for data type $dt with event ${e.toString}" - ) + s"Failed to parse a value for data type $dt with event ${e.toString}") } } @@ -260,16 +268,16 @@ class StaxXmlParser(schema: StructType, val options: XmlOptions) extends Logging attributes: Array[Attribute]): MapData = { val kvPairs = ArrayBuffer.empty[(UTF8String, Any)] attributes.foreach { attr => - kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart) - -> convertTo(attr.getValue, valueType)) + kvPairs += (UTF8String.fromString(options.attributePrefix + attr.getName.getLocalPart) + -> convertTo(attr.getValue, valueType)) } var shouldStop = false while (!shouldStop) { parser.nextEvent match { case e: StartElement => kvPairs += - (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> - convertField(parser, valueType)) + (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> + convertField(parser, valueType)) case c: Characters if !isEmptyString(c) => // Create a value tag field for it kvPairs += diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 42f9937ade0c..381f5ea58bc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -1737,8 +1737,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { Row("attr1", null, "value2"), Row(null, 5L, "4"), Row(null, 6L, "7"), - Row("8", null, null) - ) + Row("8", null, null)) val dfs = Seq( // user specified schema spark.read From 815859f7ff9578304ceb594ef966db16a79455e7 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 11 Dec 2023 17:14:37 -0800 Subject: [PATCH 03/19] fix --- .../sql/catalyst/xml/StaxXmlParser.scala | 8 +++++- .../execution/datasources/xml/XmlSuite.scala | 25 ++++++++++--------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 2d28822ac0d5..839ffab37f0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -591,7 +591,13 @@ class StaxXmlParser( castTo(data, FloatType).asInstanceOf[Float] } } - private[xml] def isEmptyString(c: Characters): Boolean = c.getData.trim.isEmpty + private[xml] def isEmptyString(c: Characters): Boolean = { + if (options.ignoreSurroundingSpaces) { + c.getData.trim.isEmpty + } else { + c.isWhiteSpace + } + } @tailrec private def parseAndCheckEndElement( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 381f5ea58bc9..acfb55ebd584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -1145,7 +1145,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("inferSchema", true) .xml(getTestResourcePath(resDir + "mixed_children.xml")) val mixedRow = mixedDF.head() - assert(mixedRow.getAs[Row](0).toSeq === Seq(Array(), "lorem")) + assert(mixedRow.getAs[Row](0) === Row(List("issue", "text ignored"), "lorem")) assert(mixedRow.getString(1) === "ipsum") } @@ -1729,15 +1729,15 @@ class XmlSuite extends QueryTest with SharedSparkSession { val TAG_NAME = "tag" val VALUETAG_NAME = "_VALUE" val schema = buildSchema( + field(VALUETAG_NAME), field(ATTRIBUTE_NAME), - field(TAG_NAME, LongType), - field(VALUETAG_NAME)) + field(TAG_NAME, LongType)) val expectedAns = Seq( - Row(null, null, "value1"), - Row("attr1", null, "value2"), - Row(null, 5L, "4"), - Row(null, 6L, "7"), - Row("8", null, null)) + Row("value1", null, null), + Row("value2", "attr1", null), + Row("4", null, 5L), + Row("7", null, 6L), + Row(null, "8", null)) val dfs = Seq( // user specified schema spark.read @@ -2204,10 +2204,11 @@ class XmlSuite extends QueryTest with SharedSparkSession { df, Seq( Row( - "value4", Row( - Array("value1", "value2"), - Array(1, 2), - 3)))) + "value4", + Row( + Array("value1", "value2"), + Array(1, 2), + 3))))) } } From 4be69e388d0c33e60bef3f088cd711dd1be2beb0 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Tue, 12 Dec 2023 10:59:31 -0800 Subject: [PATCH 04/19] rm todo --- .../sql/catalyst/xml/StaxXmlParser.scala | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 839ffab37f0d..c12717a73d5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -25,6 +25,7 @@ import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource import javax.xml.validation.Schema +import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.Try @@ -57,8 +58,6 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import scala.annotation.tailrec - class StaxXmlParser( schema: StructType, val options: XmlOptions) extends Logging { @@ -217,16 +216,16 @@ class StaxXmlParser( parser.next parser.peek match { case _: EndElement => - // TODO: optimize it - // TODO: array of value tag - if (!isEmptyString(c)) { - val indexOpt = getFieldNameToIndex(st).get(options.valueTag) - indexOpt.map { index => - // TODO: optimize it + // It couldn't be an array of value tags + // as the opening tag is immediately followed by a closing tag. + if (isEmptyString(c)) { + return null + } + val indexOpt = getFieldNameToIndex(st).get(options.valueTag) + indexOpt match { + case Some(index) => convertTo(c.getData, st.fields(index).dataType) - }.orNull - } else { - null + case None => null } case _ => val row = convertObject(parser, st) @@ -238,7 +237,6 @@ class StaxXmlParser( } case (_: Characters, _: StringType) => convertTo(StaxXmlParserUtils.currentStructureAsString(parser), StringType) - // TODO: can we remove it? case (c: Characters, _: DataType) if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. @@ -281,7 +279,7 @@ class StaxXmlParser( case c: Characters if !isEmptyString(c) => // Create a value tag field for it kvPairs += - // TODO: potential mismatch? + // TODO: We don't support an array value tags in map yet. (UTF8String.fromString(options.valueTag) -> convertTo(c.getData, valueType)) case _: EndElement => shouldStop = StaxXmlParserUtils.checkEndElement(parser) @@ -626,16 +624,16 @@ class StaxXmlParser( schema.getFieldIndex(name) match { case Some(index) => schema(index).dataType match { - case arr @ ArrayType(elementType, _) => + case ArrayType(elementType, _) => val value = convertTo(string, elementType) val result = if (row(index) == null) { ArrayBuffer(value) } else { - // TODO(shujing): optimization? + val genericArrayData = row(index).asInstanceOf[GenericArrayData] if (addToTail) { - row(index).asInstanceOf[GenericArrayData].toArray(elementType) :+ value + genericArrayData.toArray(elementType) :+ value } else { - value +: row(index).asInstanceOf[GenericArrayData].toArray(elementType) + value +: genericArrayData.toArray(elementType) } } row(index) = new GenericArrayData(result) @@ -644,7 +642,6 @@ class StaxXmlParser( } case None => // do nothing } - // TODO(shujing): optimization? InternalRow.fromSeq(row.toIndexedSeq) } } From 02193a8327dd1b177e9186db1d719477f5cccea9 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Tue, 12 Dec 2023 11:05:59 -0800 Subject: [PATCH 05/19] pkg --- .../org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index 51daba7ed26d..dec4a1cc3ce4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -22,11 +22,13 @@ import javax.xml.stream.XMLEventReader import javax.xml.stream.events._ import javax.xml.transform.stream.StreamSource import javax.xml.validation.Schema + import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.control.Exception._ import scala.util.control.NonFatal + import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.ExprUtils From 0e79565ca3f7d28e2731e6217f0213486f7396b9 Mon Sep 17 00:00:00 2001 From: Shujing Yang <135740748+shujingyang-db@users.noreply.github.com> Date: Mon, 18 Dec 2023 21:46:52 -0800 Subject: [PATCH 06/19] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala Co-authored-by: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com> --- .../scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index c12717a73d5c..92016ada3d2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -619,7 +619,7 @@ class StaxXmlParser( row: Array[Any], schema: StructType, name: String, - string: String, + data: String, addToTail: Boolean = true): InternalRow = { schema.getFieldIndex(name) match { case Some(index) => From f60a758c1901fc31963ce4b76ee25c95dde627f2 Mon Sep 17 00:00:00 2001 From: Shujing Yang <135740748+shujingyang-db@users.noreply.github.com> Date: Mon, 18 Dec 2023 22:18:42 -0800 Subject: [PATCH 07/19] Update sql/core/src/test/resources/test-data/xml-resources/values-simple.xml Co-authored-by: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com> --- .../src/test/resources/test-data/xml-resources/values-simple.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml b/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml index 469c38a25c36..a44a9b61e784 100644 --- a/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml +++ b/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml @@ -7,5 +7,6 @@ 1 value3 + value4 From 4f3acc0cea29edeab8a91aba0ad35844284c7660 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 18 Dec 2023 22:25:41 -0800 Subject: [PATCH 08/19] whitespace --- .../apache/spark/sql/catalyst/xml/StaxXmlParser.scala | 4 ++-- .../spark/sql/catalyst/xml/XmlInferSchema.scala | 11 ++++------- .../sql/execution/datasources/xml/XmlSuite.scala | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index 92016ada3d2b..ac88672a5823 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -625,7 +625,7 @@ class StaxXmlParser( case Some(index) => schema(index).dataType match { case ArrayType(elementType, _) => - val value = convertTo(string, elementType) + val value = convertTo(data, elementType) val result = if (row(index) == null) { ArrayBuffer(value) } else { @@ -638,7 +638,7 @@ class StaxXmlParser( } row(index) = new GenericArrayData(result) case dataType => - row(index) = convertTo(string, dataType) + row(index) = convertTo(data, dataType) } case None => // do nothing } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index dec4a1cc3ce4..1b7a65312601 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -159,7 +159,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) parser.peek match { case _: EndElement => NullType case _: StartElement => inferObject(parser) - case c: Characters if isEmptyString(c) => + case c: Characters if c.isWhiteSpace => // When `Characters` is found, we need to look further to decide // if this is really data or space between other elements. val data = c.getData @@ -171,8 +171,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) case _: EndElement => StringType case _ => inferField(parser) } - // what about new line character - case c: Characters if !isEmptyString(c) => + case c: Characters if !c.isWhiteSpace => // This could be the characters of a character-only element, or could have mixed // characters and other complex structure val characterType = inferFrom(c.getData) @@ -240,7 +239,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) parser.peek match { case _: EndElement | _: EndDocument => true case _: StartElement => false - case c: Characters if !isEmptyString(c) => + case c: Characters if !c.isWhiteSpace => val characterType = inferFrom(c.getData) parser.nextEvent() addOrUpdateType(options.valueTag, characterType) @@ -291,7 +290,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options) addOrUpdateType(field, inferredType) - case c: Characters if !isEmptyString(c) => + case c: Characters if !c.isWhiteSpace => // This can be an attribute-only object val valueTagType = inferFrom(c.getData) addOrUpdateType(options.valueTag, valueTagType) @@ -555,8 +554,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } - private[xml] def isEmptyString(c: Characters): Boolean = c.getData.trim.isEmpty - private def updateStructField( structType: StructType, index: Int, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index acfb55ebd584..b685b72cd0f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -2173,7 +2173,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("multiLine", "true") .load(getTestResourcePath(resDir + "values-simple.xml")) - checkAnswer(df, Seq(Row("value1", Row(Array("value2", "value3"), 1)))) + checkAnswer(df, Seq(Row(Array("value1", "value4"), Row(Array("value2", "value3"), 1)))) } test("capture values interspersed between elements - array") { From 3de09f46a3f682cea86c19dd2270a0586bb73820 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 18 Dec 2023 22:26:48 -0800 Subject: [PATCH 09/19] whitespace --- .../spark/sql/catalyst/xml/StaxXmlParser.scala | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index ac88672a5823..343970736dc4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -218,7 +218,7 @@ class StaxXmlParser( case _: EndElement => // It couldn't be an array of value tags // as the opening tag is immediately followed by a closing tag. - if (isEmptyString(c)) { + if (c.isWhiteSpace) { return null } val indexOpt = getFieldNameToIndex(st).get(options.valueTag) @@ -229,7 +229,7 @@ class StaxXmlParser( } case _ => val row = convertObject(parser, st) - if (!isEmptyString(c)) { + if (!c.isWhiteSpace) { addOrUpdate(row.toSeq(st).toArray, st, options.valueTag, c.getData, addToTail = false) } else { row @@ -276,7 +276,7 @@ class StaxXmlParser( kvPairs += (UTF8String.fromString(StaxXmlParserUtils.getName(e.asStartElement.getName, options)) -> convertField(parser, valueType)) - case c: Characters if !isEmptyString(c) => + case c: Characters if !c.isWhiteSpace => // Create a value tag field for it kvPairs += // TODO: We don't support an array value tags in map yet. @@ -424,7 +424,7 @@ class StaxXmlParser( badRecordException = badRecordException.orElse(Some(e)) } - case c: Characters if !isEmptyString(c) => + case c: Characters if !c.isWhiteSpace => addOrUpdate(row, schema, options.valueTag, c.getData) case _: EndElement => @@ -589,13 +589,6 @@ class StaxXmlParser( castTo(data, FloatType).asInstanceOf[Float] } } - private[xml] def isEmptyString(c: Characters): Boolean = { - if (options.ignoreSurroundingSpaces) { - c.getData.trim.isEmpty - } else { - c.isWhiteSpace - } - } @tailrec private def parseAndCheckEndElement( @@ -605,7 +598,7 @@ class StaxXmlParser( parser.peek match { case _: EndElement | _: EndDocument => true case _: StartElement => false - case c: Characters if !isEmptyString(c) => + case c: Characters if !c.isWhiteSpace => parser.nextEvent() addOrUpdate(row, schema, options.valueTag, c.getData) parseAndCheckEndElement(row, schema, parser) From 775052afd6b2ae7ccbd6c2d4b6f53479c06ee538 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 18 Dec 2023 22:40:22 -0800 Subject: [PATCH 10/19] fix test case --- .../apache/spark/sql/catalyst/xml/XmlOptions.scala | 2 +- .../sql/execution/datasources/xml/XmlSuite.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala index 218d56c0f203..92b156fb8f23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala @@ -95,7 +95,7 @@ class XmlOptions( val nullValue = parameters.getOrElse(NULL_VALUE, XmlOptions.DEFAULT_NULL_VALUE) val columnNameOfCorruptRecord = parameters.getOrElse(COLUMN_NAME_OF_CORRUPT_RECORD, defaultColumnNameOfCorruptRecord) - val ignoreSurroundingSpaces = getBool(IGNORE_SURROUNDING_SPACES, true) + val ignoreSurroundingSpaces = getBool(IGNORE_SURROUNDING_SPACES, false) val parseMode = ParseMode.fromString(parameters.getOrElse(MODE, PermissiveMode.name)) val inferSchema = getBool(INFER_SCHEMA, true) val rowValidationXSDPath = parameters.get(ROW_VALIDATION_XSD_PATH).orNull diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index b685b72cd0f9..72787dc826c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -731,7 +731,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .collect() assert(results(0) === Row("alice", "35")) - assert(results(1) === Row("bob", "")) + assert(results(1) === Row("bob", " ")) assert(results(2) === Row("coc", "24")) } @@ -817,7 +817,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { assert(result(0) === Row(Row(null))) assert(result(1) === Row(Row(Row(null, null)))) assert(result(2) === Row(Row(Row("E", null)))) - assert(result(3) === Row(Row(Row("E", "")))) + assert(result(3) === Row(Row(Row("E", " ")))) assert(result(4) === Row(Row(Row("E", "")))) } @@ -1145,8 +1145,8 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("inferSchema", true) .xml(getTestResourcePath(resDir + "mixed_children.xml")) val mixedRow = mixedDF.head() - assert(mixedRow.getAs[Row](0) === Row(List("issue", "text ignored"), "lorem")) - assert(mixedRow.getString(1) === "ipsum") + assert(mixedRow.getAs[Row](0) === Row(List(" issue ", " text ignored "), " lorem ")) + assert(mixedRow.getString(1) === " ipsum ") } test("test mixed text and complex element children") { @@ -1154,9 +1154,9 @@ class XmlSuite extends QueryTest with SharedSparkSession { .option("rowTag", "root") .option("inferSchema", true) .xml(getTestResourcePath(resDir + "mixed_children_2.xml")) - assert(mixedDF.select("foo.bar").head().getString(0) === "lorem") + assert(mixedDF.select("foo.bar").head().getString(0) === " lorem ") assert(mixedDF.select("foo.baz.bing").head().getLong(0) === 2) - assert(mixedDF.select("missing").head().getString(0) === "ipsum") + assert(mixedDF.select("missing").head().getString(0) === " ipsum ") } test("test XSD validation") { @@ -1720,7 +1720,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { assert(result(1).getAs[String]("_attr") == "attr1" && result(1).getAs[String]("_VALUE") == "value2") // comments aren't included in valueTag - assert(result(2).getAs[String]("_VALUE") == "value3") + assert(result(2).getAs[String]("_VALUE") == "\n value3\n ") } } From 306cbe655a755cd4bbf1947340717dc3fd5b367f Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 18 Dec 2023 23:06:11 -0800 Subject: [PATCH 11/19] deeply nested --- .../xml-resources/values-deeply-nested.xml | 40 +++++++++++++++++++ .../execution/datasources/xml/XmlSuite.scala | 37 +++++++++++++++-- 2 files changed, 74 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml b/sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml new file mode 100644 index 000000000000..950a1106234d --- /dev/null +++ b/sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml @@ -0,0 +1,40 @@ + + + + value1 + + value2 + + value3 + + value4 + + value5 + 1 + value6 + 2 + value7 + + value8 + string + value9 + + value10 + + + 3 + value11 + 4 + + string + value12 + + value13 + 3 + value14 + + value15 + + value16 + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 72787dc826c8..1bdfdd5beabe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -21,16 +21,13 @@ import java.nio.file.{Files, Path, Paths} import java.sql.{Date, Timestamp} import java.time.Instant import java.util.TimeZone - import scala.collection.mutable import scala.io.Source import scala.jdk.CollectionConverters._ - import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.io.compress.GzipCodec - import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoders, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.util._ @@ -44,6 +41,8 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import scala.collection.immutable.ArraySeq + class XmlSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -2211,4 +2210,36 @@ class XmlSuite extends QueryTest with SharedSparkSession { Array(1, 2), 3))))) } + + test("capture values interspersed between elements - deeply nested") { + val df = spark.read + .format("xml") + .option("ignoreSurroundingSpaces", true) + .option("rowTag", "ROW") + .option("multiLine", "true") + .load(getTestResourcePath(resDir + "values-deeply-nested.xml")) + + val expectedAns = Seq(Row( + ArraySeq("value1", "value16"), + Row( + ArraySeq("value2", "value15"), + Row( + ArraySeq("value3", "value10", "value13", "value14"), + Array( + Row( + ArraySeq("value4", "value8", "value9"), + "string", + Row(ArraySeq("value5", "value6", "value7"), ArraySeq(1, 2))), + Row( + ArraySeq("value12"), + "string", + Row(ArraySeq("value11"), ArraySeq(3, 4))), + ), + 3 + ) + ) + )) + + checkAnswer(df, expectedAns) + } } From 2b1fc93820862e2fb361bde2b5bdecedc2b585cc Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 18 Dec 2023 23:21:37 -0800 Subject: [PATCH 12/19] inline xml --- .../test-data/xml-resources/values-array.xml | 18 --- .../xml-resources/values-deeply-nested.xml | 40 ------- .../test-data/xml-resources/values-nested.xml | 15 --- .../test-data/xml-resources/values-simple.xml | 12 -- .../execution/datasources/xml/XmlSuite.scala | 113 +++++++++++++++--- 5 files changed, 99 insertions(+), 99 deletions(-) delete mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-array.xml delete mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml delete mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-nested.xml delete mode 100644 sql/core/src/test/resources/test-data/xml-resources/values-simple.xml diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-array.xml b/sql/core/src/test/resources/test-data/xml-resources/values-array.xml deleted file mode 100644 index d53d2db31538..000000000000 --- a/sql/core/src/test/resources/test-data/xml-resources/values-array.xml +++ /dev/null @@ -1,18 +0,0 @@ - - - - value1 - - value2 - 1 - value3 - - - value4 - 2 - value5 - 3 - value6 - - - diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml b/sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml deleted file mode 100644 index 950a1106234d..000000000000 --- a/sql/core/src/test/resources/test-data/xml-resources/values-deeply-nested.xml +++ /dev/null @@ -1,40 +0,0 @@ - - - - value1 - - value2 - - value3 - - value4 - - value5 - 1 - value6 - 2 - value7 - - value8 - string - value9 - - value10 - - - 3 - value11 - 4 - - string - value12 - - value13 - 3 - value14 - - value15 - - value16 - - diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-nested.xml b/sql/core/src/test/resources/test-data/xml-resources/values-nested.xml deleted file mode 100644 index f5d70daa76d9..000000000000 --- a/sql/core/src/test/resources/test-data/xml-resources/values-nested.xml +++ /dev/null @@ -1,15 +0,0 @@ - - - - - - 1 - value1 - 2 - value2 - 3 - - value4 - - - diff --git a/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml b/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml deleted file mode 100644 index a44a9b61e784..000000000000 --- a/sql/core/src/test/resources/test-data/xml-resources/values-simple.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - value1 - - value2 - 1 - value3 - - value4 - - diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 1bdfdd5beabe..eabf9cea6bf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -21,13 +21,16 @@ import java.nio.file.{Files, Path, Paths} import java.sql.{Date, Timestamp} import java.time.Instant import java.util.TimeZone + import scala.collection.mutable import scala.io.Source import scala.jdk.CollectionConverters._ + import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.io.compress.GzipCodec + import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoders, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.util._ @@ -2167,15 +2170,47 @@ class XmlSuite extends QueryTest with SharedSparkSession { } test("capture values interspersed between elements - simple") { - val df = spark.read.format("xml") + val xmlString = + s""" + | + | value1 + | + | value2 + | 1 + | value3 + | + | value4 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read .option("rowTag", "ROW") .option("multiLine", "true") - .load(getTestResourcePath(resDir + "values-simple.xml")) + .xml(input) checkAnswer(df, Seq(Row(Array("value1", "value4"), Row(Array("value2", "value3"), 1)))) } test("capture values interspersed between elements - array") { + val xmlString = + s""" + | + | value1 + | + | value2 + | 1 + | value3 + | + | + | value4 + | 2 + | value5 + | 3 + | value6 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) val expectedAns = Seq( Row( "value1", @@ -2183,21 +2218,35 @@ class XmlSuite extends QueryTest with SharedSparkSession { Row(List("value2", "value3"), 1, null), Row(List("value4", "value5", "value6"), 2, 3)))) val df = spark.read - .format("xml") .option("rowTag", "ROW") .option("multiLine", "true") - .load(getTestResourcePath(resDir + "values-array.xml")) + .xml(input) checkAnswer(df, expectedAns) } test("capture values interspersed between elements - nested struct") { + val xmlString = + s""" + | + | + | + | 1 + | value1 + | 2 + | value2 + | 3 + | + | value4 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) val df = spark.read - .format("xml") .option("rowTag", "ROW") .option("multiLine", "true") - .load(getTestResourcePath(resDir + "values-nested.xml")) + .xml(input) checkAnswer( df, @@ -2212,12 +2261,52 @@ class XmlSuite extends QueryTest with SharedSparkSession { } test("capture values interspersed between elements - deeply nested") { + val xmlString = + s""" + | + | value1 + | + | value2 + | + | value3 + | + | value4 + | + | value5 + | 1 + | value6 + | 2 + | value7 + | + | value8 + | string + | value9 + | + | value10 + | + | + | 3 + | value11 + | 4 + | + | string + | value12 + | + | value13 + | 3 + | value14 + | + | value15 + | + | value16 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) val df = spark.read - .format("xml") .option("ignoreSurroundingSpaces", true) .option("rowTag", "ROW") .option("multiLine", "true") - .load(getTestResourcePath(resDir + "values-deeply-nested.xml")) + .xml(input) val expectedAns = Seq(Row( ArraySeq("value1", "value16"), @@ -2233,12 +2322,8 @@ class XmlSuite extends QueryTest with SharedSparkSession { Row( ArraySeq("value12"), "string", - Row(ArraySeq("value11"), ArraySeq(3, 4))), - ), - 3 - ) - ) - )) + Row(ArraySeq("value11"), ArraySeq(3, 4)))), + 3)))) checkAnswer(df, expectedAns) } From bc89b57f4374193069070ddec87732aa74178572 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Mon, 18 Dec 2023 23:22:44 -0800 Subject: [PATCH 13/19] tailrec --- .../org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala index 110514718187..07011c0c253f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala @@ -20,7 +20,7 @@ import java.io.StringReader import javax.xml.namespace.QName import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants} import javax.xml.stream.events._ - +import scala.annotation.tailrec import scala.jdk.CollectionConverters._ object StaxXmlParserUtils { @@ -69,6 +69,7 @@ object StaxXmlParserUtils { /** * Checks if current event points the EndElement. */ + @tailrec def checkEndElement(parser: XMLEventReader): Boolean = { parser.peek match { case _: EndElement | _: EndDocument => true From 6599147d4e41a323308c3dd1f6eecd611de7b8b3 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Tue, 19 Dec 2023 10:15:52 -0800 Subject: [PATCH 14/19] nit --- .../org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala | 1 + .../apache/spark/sql/execution/datasources/xml/XmlSuite.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala index 07011c0c253f..0471cb310d89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParserUtils.scala @@ -20,6 +20,7 @@ import java.io.StringReader import javax.xml.namespace.QName import javax.xml.stream.{EventFilter, XMLEventReader, XMLInputFactory, XMLStreamConstants} import javax.xml.stream.events._ + import scala.annotation.tailrec import scala.jdk.CollectionConverters._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index e5fc3b0a7a18..2118ad10f8c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -22,8 +22,8 @@ import java.sql.{Date, Timestamp} import java.time.{Instant, LocalDateTime} import java.util.TimeZone -import scala.collection.mutable import scala.collection.immutable.ArraySeq +import scala.collection.mutable import scala.io.Source import scala.jdk.CollectionConverters._ From 0fa042ddb403a004241f7048d6c181ebc6f4fa53 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Tue, 19 Dec 2023 13:24:50 -0800 Subject: [PATCH 15/19] ignoreSurroundingSpaces --- .../apache/spark/sql/execution/datasources/xml/XmlSuite.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 2118ad10f8c9..9e3e4bc3e6c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -2377,6 +2377,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { val input = spark.createDataset(Seq(xmlString)) val df = spark.read .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) .option("multiLine", "true") .xml(input) @@ -2411,6 +2412,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { Row(List("value4", "value5", "value6"), 2, 3)))) val df = spark.read .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) .option("multiLine", "true") .xml(input) @@ -2437,6 +2439,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { val input = spark.createDataset(Seq(xmlString)) val df = spark.read .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) .option("multiLine", "true") .xml(input) From dcae96277cb149ef41f86bc7043f53946b44e520 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 27 Dec 2023 11:13:08 -0800 Subject: [PATCH 16/19] test --- .../sql/catalyst/xml/XmlInferSchema.scala | 1 - .../execution/datasources/xml/XmlSuite.scala | 23 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index a633f42eeadf..d47fd90c62e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -560,7 +560,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) valueTagType: DataType): DataType = { (objectType, valueTagType) match { case (st: StructType, _) => - // TODO(shujing): case sensitive? val valueTagIndexOpt = st.getFieldIndex(options.valueTag) valueTagIndexOpt match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 9e3e4bc3e6c8..5fce4532ea52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -2420,6 +2420,29 @@ class XmlSuite extends QueryTest with SharedSparkSession { } + test("capture values interspersed between elements - long and double") { + val xmlString = + s""" + | + | + | 1 + | 2 + | 3 + | 4 + | 5.0 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Row(Array(1.0, 3.0, 5.0), Array(2, 4))))) + } + test("capture values interspersed between elements - nested struct") { val xmlString = s""" From 0eb8aeb24b3bea6db826e8e0d4a762683eb32760 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 27 Dec 2023 17:35:32 -0800 Subject: [PATCH 17/19] comments --- .../sql/catalyst/xml/XmlInferSchema.scala | 58 +++++++++++-------- .../execution/datasources/xml/XmlSuite.scala | 39 +++++++++++++ 2 files changed, 74 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala index d47fd90c62e7..9d0c16d95e46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala @@ -164,7 +164,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } } - @tailrec private def inferField(parser: XMLEventReader): DataType = { parser.peek match { case _: EndElement => NullType @@ -182,8 +181,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) case _ => inferField(parser) } case c: Characters if !c.isWhiteSpace => - // This could be the characters of a character-only element, or could have mixed - // characters and other complex structure val characterType = inferFrom(c.getData) parser.nextEvent() parser.peek match { @@ -193,8 +190,16 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) val innerType = inferObject(parser).asInstanceOf[StructType] addOrUpdateValueTagType(innerType, characterType) case _ => - // That's all, just the character-only body; use that as the type - characterType + val fieldType = inferField(parser) + fieldType match { + case st: StructType => addOrUpdateValueTagType(st, characterType) + case _: NullType => characterType + case _: DataType => + // The field type couldn't be an array type + new StructType() + .add(options.valueTag, addOrUpdateType(Some(characterType), fieldType)) + + } } case e: XMLEvent => throw new IllegalArgumentException(s"Failed to parse data with unexpected event $e") @@ -230,20 +235,6 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) val nameToDataType = collection.mutable.TreeMap.empty[String, DataType](caseSensitivityOrdering) - def addOrUpdateType(fieldName: String, newType: DataType): Unit = { - val oldTypeOpt = nameToDataType.get(fieldName) - oldTypeOpt match { - // If the field name exists in the map, - // merge the type and infer the combined field as an array type if necessary - case Some(oldType) if !oldType.isInstanceOf[ArrayType] => - nameToDataType.update(fieldName, ArrayType(compatibleType(oldType, newType))) - case Some(oldType) => - nameToDataType.update(fieldName, compatibleType(oldType, newType)) - case None => - nameToDataType.put(fieldName, newType) - } - } - @tailrec def inferAndCheckEndElement(parser: XMLEventReader): Boolean = { parser.peek match { @@ -252,7 +243,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) case c: Characters if !c.isWhiteSpace => val characterType = inferFrom(c.getData) parser.nextEvent() - addOrUpdateType(options.valueTag, characterType) + addOrUpdateType(nameToDataType, options.valueTag, characterType) inferAndCheckEndElement(parser) case _ => parser.nextEvent() @@ -265,7 +256,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) StaxXmlParserUtils.convertAttributesToValuesMap(rootAttributes, options) rootValuesMap.foreach { case (f, v) => - addOrUpdateType(f, inferFrom(v)) + addOrUpdateType(nameToDataType, f, inferFrom(v)) } var shouldStop = false while (!shouldStop) { @@ -298,12 +289,12 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) } // Add the field and datatypes so that we can check if this is ArrayType. val field = StaxXmlParserUtils.getName(e.asStartElement.getName, options) - addOrUpdateType(field, inferredType) + addOrUpdateType(nameToDataType, field, inferredType) case c: Characters if !c.isWhiteSpace => // This can be an attribute-only object val valueTagType = inferFrom(c.getData) - addOrUpdateType(options.valueTag, valueTagType) + addOrUpdateType(nameToDataType, options.valueTag, valueTagType) case _: EndElement => shouldStop = inferAndCheckEndElement(parser) @@ -591,4 +582,25 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean) StructType(newFields) } + private def addOrUpdateType( + nameToDataType: collection.mutable.TreeMap[String, DataType], + fieldName: String, + newType: DataType): Unit = { + val oldTypeOpt = nameToDataType.get(fieldName) + val mergedType = addOrUpdateType(oldTypeOpt, newType) + nameToDataType.put(fieldName, mergedType) + } + + private def addOrUpdateType(oldTypeOpt: Option[DataType], newType: DataType): DataType = { + oldTypeOpt match { + // If the field name already exists, + // merge the type and infer the combined field as an array type if necessary + case Some(oldType) if !oldType.isInstanceOf[ArrayType] && !newType.isInstanceOf[NullType] => + ArrayType(compatibleType(oldType, newType)) + case Some(oldType) => + compatibleType(oldType, newType) + case None => + newType + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 5fce4532ea52..afc5c49d7a8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -2443,6 +2443,45 @@ class XmlSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Seq(Row(Row(Array(1.0, 3.0, 5.0), Array(2, 4))))) } + test("capture values interspersed between elements - comments") { + val xmlString = + s""" + | + | 1 2 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Row(Array(1, 2))))) + } + + test("capture values interspersed between elements - nested comments") { + val xmlString = + s""" + | + | 1 + | 2 + | 1 + | 3 + | 2 + | + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", true) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq(Row(Row(Array(1, 2, 3), Array(1, 2))))) + } + test("capture values interspersed between elements - nested struct") { val xmlString = s""" From a5c3fbc5140a990bf307b963bbfa20bec741ab78 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 27 Dec 2023 18:08:40 -0800 Subject: [PATCH 18/19] whitespace with quotes --- .../execution/datasources/xml/XmlSuite.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index afc5c49d7a8a..fa4814836e47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -2460,6 +2460,28 @@ class XmlSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Seq(Row(Row(Array(1, 2))))) } + test("capture values interspersed between elements - whitespaces with quotes") { + val xmlString = + s""" + | + | " " + | " "1 + | + | + |" "1 + | + |""".stripMargin + val input = spark.createDataset(Seq(xmlString)) + val df = spark.read + .option("rowTag", "ROW") + .option("ignoreSurroundingSpaces", false) + .option("multiLine", "true") + .xml(input) + + checkAnswer(df, Seq( + Row(Row(" "), Row(Row(1), " "), Row(null, "")))) + } + test("capture values interspersed between elements - nested comments") { val xmlString = s""" From 32bd9feb892f7b2b0130e8cd9dacb82359251ab9 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Thu, 28 Dec 2023 10:19:27 -0800 Subject: [PATCH 19/19] fix whitespace --- .../apache/spark/sql/execution/datasources/xml/XmlSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index fa4814836e47..4b9a95856afb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -2468,8 +2468,6 @@ class XmlSuite extends QueryTest with SharedSparkSession { | " "1 | | - |" "1 - | |""".stripMargin val input = spark.createDataset(Seq(xmlString)) val df = spark.read @@ -2479,7 +2477,7 @@ class XmlSuite extends QueryTest with SharedSparkSession { .xml(input) checkAnswer(df, Seq( - Row(Row(" "), Row(Row(1), " "), Row(null, "")))) + Row("\" \"", Row(1, "\" \""), Row(Row(null, " "))))) } test("capture values interspersed between elements - nested comments") {