Skip to content

Commit

Permalink
fix: use defusedxml whenever we load XML to prevent XXE attacks
Browse files Browse the repository at this point in the history
  • Loading branch information
madpah authored Mar 3, 2023
2 parents 90de3b8 + 32fd5a6 commit ae3d76c
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 43 deletions.
30 changes: 21 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ keywords = [

[tool.poetry.dependencies]
python = "^3.7"
defusedxml = "^0.7.1"

[tool.poetry.dev-dependencies]
autopep8 = "^1.6.0"
Expand Down
47 changes: 25 additions & 22 deletions serializable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from json import JSONEncoder
from sys import version_info
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union, cast
from xml.etree import ElementTree
from xml.etree.ElementTree import Element, SubElement

from defusedxml import ElementTree as SafeElementTree # type: ignore

if version_info >= (3, 8):
from typing import Protocol
Expand Down Expand Up @@ -314,7 +316,7 @@ def _from_json(cls: Type[_T], data: Dict[str, Any]) -> object:


def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True, element_name: Optional[str] = None,
xmlns: Optional[str] = None) -> Union[ElementTree.Element, str]:
xmlns: Optional[str] = None) -> Union[Element, str]:
logging.debug(f'Dumping {self} to XML with view {view_}...')

this_e_attributes = {}
Expand Down Expand Up @@ -352,7 +354,7 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
element_name = _namespace_element_name(tag_name=element_name,
xmlns=xmlns) if element_name else _namespace_element_name(
tag_name=CurrentFormatter.formatter.encode(self.__class__.__name__), xmlns=xmlns)
this_e = ElementTree.Element(element_name, this_e_attributes)
this_e = Element(element_name, this_e_attributes)

# Handle remaining Properties that will be sub elements
for k, prop_info in serializable_property_info.items():
Expand All @@ -375,7 +377,7 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
new_key = prop_info.custom_names.get(SerializationType.XML, new_key)

if v is None:
ElementTree.SubElement(this_e, _namespace_element_name(tag_name=new_key, xmlns=xmlns))
SubElement(this_e, _namespace_element_name(tag_name=new_key, xmlns=xmlns))
continue

if new_key == '.':
Expand All @@ -390,28 +392,28 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
_array_type, nested_key = prop_info.xml_array_config
nested_key = _namespace_element_name(tag_name=nested_key, xmlns=xmlns)
if _array_type and _array_type == XmlArraySerializationType.NESTED:
nested_e = ElementTree.SubElement(this_e, new_key)
nested_e = SubElement(this_e, new_key)
else:
nested_e = this_e
for j in v:
if not prop_info.is_primitive_type() and not prop_info.is_enum:
nested_e.append(j.as_xml(view_=view_, as_string=False, element_name=nested_key, xmlns=xmlns))
elif prop_info.is_enum:
ElementTree.SubElement(nested_e, nested_key).text = str(j.value)
SubElement(nested_e, nested_key).text = str(j.value)
elif prop_info.concrete_type in (float, int):
ElementTree.SubElement(nested_e, nested_key).text = str(j)
SubElement(nested_e, nested_key).text = str(j)
elif prop_info.concrete_type is bool:
ElementTree.SubElement(nested_e, nested_key).text = str(j).lower()
SubElement(nested_e, nested_key).text = str(j).lower()
else:
# Assume type is str
ElementTree.SubElement(nested_e, nested_key).text = str(j)
SubElement(nested_e, nested_key).text = str(j)
elif prop_info.custom_type:
if prop_info.is_helper_type():
ElementTree.SubElement(this_e, new_key).text = str(prop_info.custom_type.serialize(v))
SubElement(this_e, new_key).text = str(prop_info.custom_type.serialize(v))
else:
ElementTree.SubElement(this_e, new_key).text = str(prop_info.custom_type(v))
SubElement(this_e, new_key).text = str(prop_info.custom_type(v))
elif prop_info.is_enum:
ElementTree.SubElement(this_e, new_key).text = str(v.value)
SubElement(this_e, new_key).text = str(v.value)
elif not prop_info.is_primitive_type():
global_klass_name = f'{prop_info.concrete_type.__module__}.{prop_info.concrete_type.__name__}'
if global_klass_name in ObjectMetadataLibrary.klass_mappings:
Expand All @@ -420,24 +422,24 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
else:
# Handle properties that have a type that is not a Python Primitive (e.g. int, float, str)
if prop_info.string_format:
ElementTree.SubElement(this_e, new_key).text = f'{v:{prop_info.string_format}}'
SubElement(this_e, new_key).text = f'{v:{prop_info.string_format}}'
else:
ElementTree.SubElement(this_e, new_key).text = str(v)
SubElement(this_e, new_key).text = str(v)
elif prop_info.concrete_type in (float, int):
ElementTree.SubElement(this_e, new_key).text = str(v)
SubElement(this_e, new_key).text = str(v)
elif prop_info.concrete_type is bool:
ElementTree.SubElement(this_e, new_key).text = str(v).lower()
SubElement(this_e, new_key).text = str(v).lower()
else:
# Assume type is str
ElementTree.SubElement(this_e, new_key).text = str(v)
SubElement(this_e, new_key).text = str(v)

if as_string:
return ElementTree.tostring(this_e, 'unicode')
return cast(Element, SafeElementTree.tostring(this_e, 'unicode'))
else:
return this_e


def _from_xml(cls: Type[_T], data: Union[TextIOWrapper, ElementTree.Element],
def _from_xml(cls: Type[_T], data: Union[TextIOWrapper, Element],
default_namespace: Optional[str] = None) -> object:
logging.debug(f'Rendering XML from {type(data)} to {cls}...')
klass = ObjectMetadataLibrary.klass_mappings.get(f'{cls.__module__}.{cls.__qualname__}', None)
Expand All @@ -448,11 +450,12 @@ def _from_xml(cls: Type[_T], data: Union[TextIOWrapper, ElementTree.Element],
klass_properties = ObjectMetadataLibrary.klass_property_mappings.get(f'{cls.__module__}.{cls.__qualname__}', {})

if isinstance(data, TextIOWrapper):
data = ElementTree.fromstring(data.read())
data = cast(Element, SafeElementTree.fromstring(data.read()))

if default_namespace is None:
_namespaces = dict([node for _, node in ElementTree.iterparse(StringIO(ElementTree.tostring(data, 'unicode')),
events=['start-ns'])])
_namespaces = dict([node for _, node in
SafeElementTree.iterparse(StringIO(SafeElementTree.tostring(data, 'unicode')),
events=['start-ns'])])
if 'ns0' in _namespaces:
default_namespace = _namespaces['ns0']
else:
Expand Down
10 changes: 5 additions & 5 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

import json
import os
import xml
from typing import Any
from unittest import TestCase

import lxml # type: ignore
from defusedxml import ElementTree as SafeElementTree # type: ignore
from xmldiff import main # type: ignore
from xmldiff.actions import MoveNode # type: ignore

Expand All @@ -48,12 +48,12 @@ def assertEqualJson(self, a: str, b: str) -> None:
)

def assertEqualXml(self, a: str, b: str) -> None:
a = xml.etree.ElementTree.tostring(
xml.etree.ElementTree.fromstring(a, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
a = SafeElementTree.tostring(
SafeElementTree.fromstring(a, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
'unicode'
)
b = xml.etree.ElementTree.tostring(
xml.etree.ElementTree.fromstring(b, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
b = SafeElementTree.tostring(
SafeElementTree.fromstring(b, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
'unicode'
)
diff_results = main.diff_texts(a, b, diff_options={'F': 0.5})
Expand Down
15 changes: 8 additions & 7 deletions tests/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import logging
import os
from xml.etree import ElementTree

from defusedxml import ElementTree as SafeElementTree

from serializable.formatters import (
CamelCasePropertyNameFormatter,
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_serialize_tfp_sc1(self) -> None:
def test_deserialize_tfp_cc1(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v1.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand All @@ -87,7 +88,7 @@ def test_deserialize_tfp_cc1(self) -> None:
def test_deserialize_tfp_cc1_v2(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v2.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject.title, book.title)
self.assertEqual(ThePhoenixProject.isbn, book.isbn)
self.assertEqual(ThePhoenixProject.edition, book.edition)
Expand All @@ -101,7 +102,7 @@ def test_deserialize_tfp_cc1_v2(self) -> None:
def test_deserialize_tfp_cc1_v3(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v3.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand All @@ -115,7 +116,7 @@ def test_deserialize_tfp_cc1_v3(self) -> None:
def test_deserialize_tfp_cc1_v4(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v4.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject.title, book.title)
self.assertEqual(ThePhoenixProject.isbn, book.isbn)
self.assertEqual(ThePhoenixProject.edition, book.edition)
Expand All @@ -129,7 +130,7 @@ def test_deserialize_tfp_cc1_v4(self) -> None:
def test_deserialize_tfp_cc1_with_ignored(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-with-ignored.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand All @@ -153,7 +154,7 @@ def test_deserialize_tfp_kc1(self) -> None:
def test_deserialize_tfp_sc1(self) -> None:
CurrentFormatter.formatter = SnakeCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-snake-case-1.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand Down

0 comments on commit ae3d76c

Please sign in to comment.