diff --git a/docs/sql-data-sources-xml.md b/docs/sql-data-sources-xml.md
index b10e054634ed..3b735191fc42 100644
--- a/docs/sql-data-sources-xml.md
+++ b/docs/sql-data-sources-xml.md
@@ -94,7 +94,7 @@ Data source options of XML can be set via:
   <tr>
     <td><code>inferSchema</code></td>
     <td>true</td>
-    <td>If true, attempts to infer an appropriate type for each resulting DataFrame column. If false, all resulting columns are of string type. Default is true. XML built-in functions ignore this option.</td>
+    <td>If true, attempts to infer an appropriate type for each resulting DataFrame column. If false, all resulting columns are of string type.</td>
     <td>read</td>
   </tr>
   <tr>
@@ -108,7 +108,7 @@ Data source options of XML can be set via:
   <tr>
     <td><code>attributePrefix</code></td>
     <td><code>_</code></td>
-    <td>The prefix for attributes to differentiate attributes from elements. This will be the prefix for field names. Default is _. Can be empty for reading XML, but not for writing.</td>
+    <td>The prefix for attributes to differentiate attributes from elements. This will be the prefix for field names. Can be empty for reading XML, but not for writing.</td>
     <td>read/write</td>
   </tr>
   <tr>
@@ -235,5 +235,12 @@ Data source options of XML can be set via:
     <td>write</td>
   </tr>
+  <tr>
+    <td><code>validateName</code></td>
+    <td>true</td>
+    <td>If true, throws an error on XML element name validation failure. For example,
+    SQL field names can have spaces, but XML element names cannot.</td>
+    <td>write</td>
+  </tr>
 </table>
 
 Other generic options can be found in Generic File Source Options.
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 52975917ea02..51698f262fc5 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -792,6 +792,7 @@ def xml(
         timestampFormat: Optional[str] = None,
         compression: Optional[str] = None,
         encoding: Optional[str] = None,
+        validateName: Optional[bool] = None,
     ) -> None:
         self.mode(mode)
         self._set_opts(
@@ -806,6 +807,7 @@ def xml(
             timestampFormat=timestampFormat,
             compression=compression,
             encoding=encoding,
+            validateName=validateName,
         )
         self.format("xml").save(path)
 
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b61284247b0e..db9220fc48bb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -2096,6 +2096,7 @@ def xml(
         timestampFormat: Optional[str] = None,
         compression: Optional[str] = None,
         encoding: Optional[str] = None,
+        validateName: Optional[bool] = None,
     ) -> None:
         r"""Saves the content of the :class:`DataFrame` in XML format at the specified path.
@@ -2155,6 +2156,7 @@ def xml(
             timestampFormat=timestampFormat,
             compression=compression,
             encoding=encoding,
+            validateName=validateName,
         )
         self._jwrite.xml(path)
 
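For reference, a minimal sketch of how the new writer option behaves end to end, assuming an active `spark` session; the column name and output path here are illustrative, not taken from the patch:

```scala
// A column name with a space is valid in SQL but not as an XML element name.
val df = spark.sql("SELECT 'Random String' AS `field name with space`")

// With the default validateName=true, this write fails: Woodstox rejects the
// element name and the XMLStreamException surfaces wrapped in a SparkException.
// df.write.option("rowTag", "ROW").xml("/tmp/xml-out")

// Opting out writes the name through unvalidated, producing non-well-formed XML.
df.write
  .option("rowTag", "ROW")
  .option("validateName", false)
  .mode("overwrite")
  .xml("/tmp/xml-out")
```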
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
index 43e89c49a89e..53c8b4cf3422 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
@@ -65,6 +65,7 @@
     val factory = XMLOutputFactory.newInstance()
     // to_xml disables structure validation to allow multiple root tags
     factory.setProperty(WstxOutputProperties.P_OUTPUT_VALIDATE_STRUCTURE, validateStructure)
+    factory.setProperty(WstxOutputProperties.P_OUTPUT_VALIDATE_NAMES, options.validateName)
     val xmlWriter = factory.createXMLStreamWriter(writer)
     if (!indentDisabled) {
       val indentingXmlWriter = new IndentingXMLStreamWriter(xmlWriter)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index 218d56c0f203..336c54e164e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -107,6 +107,7 @@
   // setting indent to "" disables indentation in the generated XML.
   // Each row will be written in a new line.
   val indent = parameters.getOrElse(INDENT, DEFAULT_INDENT)
+  val validateName = getBool(VALIDATE_NAME, true)
 
   /**
    * Infer columns with all valid date entries as date type (otherwise inferred as string or
@@ -210,6 +211,7 @@
   val TIME_ZONE = newOption("timeZone")
   val INDENT = newOption("indent")
   val PREFERS_DECIMAL = newOption("prefersDecimal")
+  val VALIDATE_NAME = newOption("validateName")
   // Options with alternative
   val ENCODING = "encoding"
   val CHARSET = "charset"
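The generator change simply forwards the option to Woodstox. A self-contained sketch of what that property does, assuming `woodstox-core` is on the classpath; this snippet is illustrative, not Spark code:

```scala
import java.io.StringWriter
import javax.xml.stream.XMLStreamException

import com.ctc.wstx.api.WstxOutputProperties
import com.ctc.wstx.stax.WstxOutputFactory

object ValidateNamesDemo extends App {
  // Instantiate the Woodstox factory directly so the demo does not depend on
  // which StAX implementation XMLOutputFactory.newInstance() would resolve to.
  val factory = new WstxOutputFactory()
  // Woodstox leaves output name validation off by default; Spark now toggles
  // it according to the validateName option.
  factory.setProperty(WstxOutputProperties.P_OUTPUT_VALIDATE_NAMES, true)

  val xmlWriter = factory.createXMLStreamWriter(new StringWriter())
  try {
    // "1field" starts with a digit, which is illegal as an XML element name.
    xmlWriter.writeStartElement("1field")
  } catch {
    case e: XMLStreamException => println(s"rejected: ${e.getMessage}")
  }
}
```

With the property set to false, the same writeStartElement call succeeds and the malformed name is written through, which is exactly the validateName=false behavior the test below checks.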
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 398706dba3d9..b3d10d2115f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -16,12 +16,13 @@
  */
 package org.apache.spark.sql.execution.datasources.xml
 
-import java.io.EOFException
+import java.io.{EOFException, File}
 import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
 import java.nio.file.{Files, Path, Paths}
 import java.sql.{Date, Timestamp}
 import java.time.{Instant, LocalDateTime}
 import java.util.TimeZone
+import javax.xml.stream.XMLStreamException
 
 import scala.collection.immutable.ArraySeq
 import scala.collection.mutable
@@ -2828,4 +2829,55 @@ class XmlSuite
       }
     }
   }
+
+  test("XML Validate Name") {
+    val data = Seq(Row("Random String"))
+
+    def checkValidation(fieldName: String,
+        errorMsg: String,
+        validateName: Boolean = true): Unit = {
+      val schema = StructType(Seq(StructField(fieldName, StringType)))
+      val df = spark.createDataFrame(data.asJava, schema)
+
+      withTempDir { dir =>
+        val path = dir.getCanonicalPath
+        validateName match {
+          case false =>
+            df.write
+              .option("rowTag", "ROW")
+              .option("validateName", false)
+              .option("declaration", "")
+              .option("indent", "")
+              .mode(SaveMode.Overwrite)
+              .xml(path)
+            // read file back and check its content
+            val xmlFile = new File(path).listFiles()
+              .filter(_.isFile)
+              .filter(_.getName.endsWith("xml")).head
+            val actualContent = Files.readString(xmlFile.toPath).replaceAll("\\n", "")
+            assert(actualContent ===
+              s"<${XmlOptions.DEFAULT_ROOT_TAG}><ROW>" +
+                s"<$fieldName>${data.head.getString(0)}</$fieldName>" +
+                s"</ROW></${XmlOptions.DEFAULT_ROOT_TAG}>")
+
+          case true =>
+            val e = intercept[SparkException] {
+              df.write
+                .option("rowTag", "ROW")
+                .mode(SaveMode.Overwrite)
+                .xml(path)
+            }
+
+            assert(e.getCause.getCause.isInstanceOf[XMLStreamException])
+            assert(e.getMessage.contains(errorMsg))
+        }
+      }
+    }
+
+    checkValidation("", "Illegal to pass empty name")
+    checkValidation(" ", "Illegal first name character ' '")
+    checkValidation("1field", "Illegal first name character '1'")
+    checkValidation("field name with space", "Illegal name character ' '")
+    checkValidation("field", "", false)
+  }
 }
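Since validateName=false only disables the check (the emitted file may then not be well-formed XML), callers hitting these validation errors may prefer to sanitize column names before writing. A hypothetical helper, not part of the patch, and intentionally approximate: real XML name rules cover more cases than this replacement rule does:

```scala
import org.apache.spark.sql.DataFrame

// Rename columns so the default validateName=true write path succeeds.
// Hypothetical example code; adjust the replacement rule to your data.
def sanitizeForXml(df: DataFrame): DataFrame =
  df.columns.foldLeft(df) { (acc, name) =>
    val cleaned = name.trim.replaceAll("[^A-Za-z0-9_]", "_")
    // XML element names cannot be empty or start with a digit.
    val valid = if (cleaned.isEmpty || cleaned.head.isDigit) s"_$cleaned" else cleaned
    acc.withColumnRenamed(name, valid)
  }

// Usage sketch:
// sanitizeForXml(df).write.option("rowTag", "ROW").xml("/tmp/xml-out")
```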