diff --git a/docs/sql-data-sources-xml.md b/docs/sql-data-sources-xml.md
index b10e054634ed..3b735191fc42 100644
--- a/docs/sql-data-sources-xml.md
+++ b/docs/sql-data-sources-xml.md
@@ -94,7 +94,7 @@ Data source options of XML can be set via:
   <tr>
     <td><code>inferSchema</code></td>
     <td>true</td>
-    <td>If true, attempts to infer an appropriate type for each resulting DataFrame column. If false, all resulting columns are of string type. Default is true. XML built-in functions ignore this option.</td>
+    <td>If true, attempts to infer an appropriate type for each resulting DataFrame column. If false, all resulting columns are of string type.</td>
     <td>read</td>
   </tr>
   <tr>
@@ -108,7 +108,7 @@ Data source options of XML can be set via:
   <tr>
     <td><code>attributePrefix</code></td>
     <td><code>_</code></td>
-    <td>The prefix for attributes to differentiate attributes from elements. This will be the prefix for field names. Default is _. Can be empty for reading XML, but not for writing.</td>
+    <td>The prefix for attributes to differentiate attributes from elements. This will be the prefix for field names. Can be empty for reading XML, but not for writing.</td>
     <td>read/write</td>
   </tr>
   <tr>
@@ -235,5 +235,12 @@ Data source options of XML can be set via:
     <td>write</td>
   </tr>
+  <tr>
+    <td><code>validateName</code></td>
+    <td>true</td>
+    <td>If true, throws an error on XML element name validation failure. For example,
+    SQL field names can have spaces, but XML element names cannot.</td>
+    <td>write</td>
+  </tr>
 </table>

 Other generic options can be found in <a href="sql-data-sources-generic-options.html">Generic File Source Options</a>.
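
For reference, a minimal usage sketch of the new writer option (the DataFrame, column name, and output path below are illustrative, not part of this patch):

```scala
// A column name with a space is legal in SQL but not as an XML element name.
// With the default validateName=true the write fails with an XMLStreamException;
// validateName=false tells the generator to emit the name verbatim.
val df = spark.range(1).toDF("field name with space")
df.write
  .option("rowTag", "ROW")
  .option("validateName", false) // relax XML element name validation
  .mode("overwrite")
  .xml("/tmp/xml-validate-name-demo") // hypothetical output path
```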
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 52975917ea02..51698f262fc5 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -792,6 +792,7 @@ def xml(
timestampFormat: Optional[str] = None,
compression: Optional[str] = None,
encoding: Optional[str] = None,
+ validateName: Optional[bool] = None,
) -> None:
self.mode(mode)
self._set_opts(
@@ -806,6 +807,7 @@ def xml(
timestampFormat=timestampFormat,
compression=compression,
encoding=encoding,
+ validateName=validateName,
)
self.format("xml").save(path)
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b61284247b0e..db9220fc48bb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -2096,6 +2096,7 @@ def xml(
timestampFormat: Optional[str] = None,
compression: Optional[str] = None,
encoding: Optional[str] = None,
+ validateName: Optional[bool] = None,
) -> None:
r"""Saves the content of the :class:`DataFrame` in XML format at the specified path.
@@ -2155,6 +2156,7 @@ def xml(
timestampFormat=timestampFormat,
compression=compression,
encoding=encoding,
+ validateName=validateName,
)
self._jwrite.xml(path)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
index 43e89c49a89e..53c8b4cf3422 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlGenerator.scala
@@ -65,6 +65,7 @@ class StaxXmlGenerator(
val factory = XMLOutputFactory.newInstance()
// to_xml disables structure validation to allow multiple root tags
factory.setProperty(WstxOutputProperties.P_OUTPUT_VALIDATE_STRUCTURE, validateStructure)
+ factory.setProperty(WstxOutputProperties.P_OUTPUT_VALIDATE_NAMES, options.validateName)
val xmlWriter = factory.createXMLStreamWriter(writer)
if (!indentDisabled) {
val indentingXmlWriter = new IndentingXMLStreamWriter(xmlWriter)
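
Outside of Spark, the Woodstox property toggled above controls exactly this name check; a standalone sketch, assuming Woodstox's `WstxOutputFactory` is instantiated directly:

```scala
import java.io.StringWriter
import com.ctc.wstx.api.WstxOutputProperties
import com.ctc.wstx.stax.WstxOutputFactory

val factory = new WstxOutputFactory()
factory.setProperty(WstxOutputProperties.P_OUTPUT_VALIDATE_NAMES, true)
val xmlWriter = factory.createXMLStreamWriter(new StringWriter())
// With validation on, Woodstox rejects names that are not well-formed XML,
// e.g. this throws javax.xml.stream.XMLStreamException
// ("Illegal first name character '1'", as asserted in the test below).
xmlWriter.writeStartElement("1field")
```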
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index 218d56c0f203..336c54e164e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -107,6 +107,7 @@ class XmlOptions(
// setting indent to "" disables indentation in the generated XML.
// Each row will be written in a new line.
val indent = parameters.getOrElse(INDENT, DEFAULT_INDENT)
+ val validateName = getBool(VALIDATE_NAME, true)
/**
* Infer columns with all valid date entries as date type (otherwise inferred as string or
@@ -210,6 +211,7 @@ object XmlOptions extends DataSourceOptions {
val TIME_ZONE = newOption("timeZone")
val INDENT = newOption("indent")
val PREFERS_DECIMAL = newOption("prefersDecimal")
+ val VALIDATE_NAME = newOption("validateName")
// Options with alternative
val ENCODING = "encoding"
val CHARSET = "charset"
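
A quick sketch of how the flag resolves on the options object, assuming the Map-based convenience constructor used by the test suites:

```scala
// Absent option: getBool(VALIDATE_NAME, true) falls back to the default.
val strict = new XmlOptions(Map.empty[String, String])
assert(strict.validateName)

// An explicit option value overrides the default.
val relaxed = new XmlOptions(Map("validateName" -> "false"))
assert(!relaxed.validateName)
```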
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index 398706dba3d9..b3d10d2115f0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -16,12 +16,13 @@
*/
package org.apache.spark.sql.execution.datasources.xml
-import java.io.EOFException
+import java.io.{EOFException, File}
import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
import java.nio.file.{Files, Path, Paths}
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDateTime}
import java.util.TimeZone
+import javax.xml.stream.XMLStreamException
import scala.collection.immutable.ArraySeq
import scala.collection.mutable
@@ -2828,4 +2829,55 @@ class XmlSuite
}
}
}
+
+ test("XML Validate Name") {
+ val data = Seq(Row("Random String"))
+
+    def checkValidation(
+        fieldName: String,
+        errorMsg: String,
+        validateName: Boolean = true): Unit = {
+ val schema = StructType(Seq(StructField(fieldName, StringType)))
+ val df = spark.createDataFrame(data.asJava, schema)
+
+ withTempDir { dir =>
+ val path = dir.getCanonicalPath
+ validateName match {
+ case false =>
+ df.write
+ .option("rowTag", "ROW")
+ .option("validateName", false)
+ .option("declaration", "")
+ .option("indent", "")
+ .mode(SaveMode.Overwrite)
+ .xml(path)
+ // read file back and check its content
+ val xmlFile = new File(path).listFiles()
+ .filter(_.isFile)
+ .filter(_.getName.endsWith("xml")).head
+ val actualContent = Files.readString(xmlFile.toPath).replaceAll("\\n", "")
+          assert(actualContent ===
+            s"<${XmlOptions.DEFAULT_ROOT_TAG}>" +
+              s"<ROW><$fieldName>${data.head.getString(0)}</$fieldName></ROW>" +
+              s"</${XmlOptions.DEFAULT_ROOT_TAG}>")
+
+ case true =>
+ val e = intercept[SparkException] {
+ df.write
+ .option("rowTag", "ROW")
+ .mode(SaveMode.Overwrite)
+ .xml(path)
+ }
+
+ assert(e.getCause.getCause.isInstanceOf[XMLStreamException])
+ assert(e.getMessage.contains(errorMsg))
+ }
+ }
+ }
+
+ checkValidation("", "Illegal to pass empty name")
+ checkValidation(" ", "Illegal first name character ' '")
+ checkValidation("1field", "Illegal first name character '1'")
+ checkValidation("field name with space", "Illegal name character ' '")
+ checkValidation("field", "", false)
+ }
}