Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use defusedxml whenever we load XML to prevent XXE attacks #5

Merged
merged 4 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ keywords = [

[tool.poetry.dependencies]
python = "^3.7"
defusedxml = "^0.7.1"

[tool.poetry.dev-dependencies]
autopep8 = "^1.6.0"
Expand Down
47 changes: 25 additions & 22 deletions serializable/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
from json import JSONEncoder
from sys import version_info
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, TypeVar, Union, cast
from xml.etree import ElementTree
from xml.etree.ElementTree import Element, SubElement
madpah marked this conversation as resolved.
Show resolved Hide resolved
madpah marked this conversation as resolved.
Show resolved Hide resolved
madpah marked this conversation as resolved.
Show resolved Hide resolved

from defusedxml import ElementTree as SafeElementTree # type: ignore

if version_info >= (3, 8):
from typing import Protocol
Expand Down Expand Up @@ -314,7 +316,7 @@ def _from_json(cls: Type[_T], data: Dict[str, Any]) -> object:


def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True, element_name: Optional[str] = None,
xmlns: Optional[str] = None) -> Union[ElementTree.Element, str]:
xmlns: Optional[str] = None) -> Union[Element, str]:
logging.debug(f'Dumping {self} to XML with view {view_}...')

this_e_attributes = {}
Expand Down Expand Up @@ -352,7 +354,7 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
element_name = _namespace_element_name(tag_name=element_name,
xmlns=xmlns) if element_name else _namespace_element_name(
tag_name=CurrentFormatter.formatter.encode(self.__class__.__name__), xmlns=xmlns)
this_e = ElementTree.Element(element_name, this_e_attributes)
this_e = Element(element_name, this_e_attributes)

# Handle remaining Properties that will be sub elements
for k, prop_info in serializable_property_info.items():
Expand All @@ -375,7 +377,7 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
new_key = prop_info.custom_names.get(SerializationType.XML, new_key)

if v is None:
ElementTree.SubElement(this_e, _namespace_element_name(tag_name=new_key, xmlns=xmlns))
SubElement(this_e, _namespace_element_name(tag_name=new_key, xmlns=xmlns))
continue

if new_key == '.':
Expand All @@ -390,28 +392,28 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
_array_type, nested_key = prop_info.xml_array_config
nested_key = _namespace_element_name(tag_name=nested_key, xmlns=xmlns)
if _array_type and _array_type == XmlArraySerializationType.NESTED:
nested_e = ElementTree.SubElement(this_e, new_key)
nested_e = SubElement(this_e, new_key)
else:
nested_e = this_e
for j in v:
if not prop_info.is_primitive_type() and not prop_info.is_enum:
nested_e.append(j.as_xml(view_=view_, as_string=False, element_name=nested_key, xmlns=xmlns))
elif prop_info.is_enum:
ElementTree.SubElement(nested_e, nested_key).text = str(j.value)
SubElement(nested_e, nested_key).text = str(j.value)
elif prop_info.concrete_type in (float, int):
ElementTree.SubElement(nested_e, nested_key).text = str(j)
SubElement(nested_e, nested_key).text = str(j)
elif prop_info.concrete_type is bool:
ElementTree.SubElement(nested_e, nested_key).text = str(j).lower()
SubElement(nested_e, nested_key).text = str(j).lower()
else:
# Assume type is str
ElementTree.SubElement(nested_e, nested_key).text = str(j)
SubElement(nested_e, nested_key).text = str(j)
elif prop_info.custom_type:
if prop_info.is_helper_type():
ElementTree.SubElement(this_e, new_key).text = str(prop_info.custom_type.serialize(v))
SubElement(this_e, new_key).text = str(prop_info.custom_type.serialize(v))
else:
ElementTree.SubElement(this_e, new_key).text = str(prop_info.custom_type(v))
SubElement(this_e, new_key).text = str(prop_info.custom_type(v))
elif prop_info.is_enum:
ElementTree.SubElement(this_e, new_key).text = str(v.value)
SubElement(this_e, new_key).text = str(v.value)
elif not prop_info.is_primitive_type():
global_klass_name = f'{prop_info.concrete_type.__module__}.{prop_info.concrete_type.__name__}'
if global_klass_name in ObjectMetadataLibrary.klass_mappings:
Expand All @@ -420,24 +422,24 @@ def _as_xml(self: _T, view_: Optional[Type[_T]] = None, as_string: bool = True,
else:
# Handle properties that have a type that is not a Python Primitive (e.g. int, float, str)
if prop_info.string_format:
ElementTree.SubElement(this_e, new_key).text = f'{v:{prop_info.string_format}}'
SubElement(this_e, new_key).text = f'{v:{prop_info.string_format}}'
else:
ElementTree.SubElement(this_e, new_key).text = str(v)
SubElement(this_e, new_key).text = str(v)
elif prop_info.concrete_type in (float, int):
ElementTree.SubElement(this_e, new_key).text = str(v)
SubElement(this_e, new_key).text = str(v)
elif prop_info.concrete_type is bool:
ElementTree.SubElement(this_e, new_key).text = str(v).lower()
SubElement(this_e, new_key).text = str(v).lower()
else:
# Assume type is str
ElementTree.SubElement(this_e, new_key).text = str(v)
SubElement(this_e, new_key).text = str(v)

if as_string:
return ElementTree.tostring(this_e, 'unicode')
return cast(Element, SafeElementTree.tostring(this_e, 'unicode'))
else:
return this_e


def _from_xml(cls: Type[_T], data: Union[TextIOWrapper, ElementTree.Element],
def _from_xml(cls: Type[_T], data: Union[TextIOWrapper, Element],
default_namespace: Optional[str] = None) -> object:
logging.debug(f'Rendering XML from {type(data)} to {cls}...')
klass = ObjectMetadataLibrary.klass_mappings.get(f'{cls.__module__}.{cls.__qualname__}', None)
Expand All @@ -448,11 +450,12 @@ def _from_xml(cls: Type[_T], data: Union[TextIOWrapper, ElementTree.Element],
klass_properties = ObjectMetadataLibrary.klass_property_mappings.get(f'{cls.__module__}.{cls.__qualname__}', {})

if isinstance(data, TextIOWrapper):
data = ElementTree.fromstring(data.read())
data = cast(Element, SafeElementTree.fromstring(data.read()))

if default_namespace is None:
_namespaces = dict([node for _, node in ElementTree.iterparse(StringIO(ElementTree.tostring(data, 'unicode')),
events=['start-ns'])])
_namespaces = dict([node for _, node in
SafeElementTree.iterparse(StringIO(SafeElementTree.tostring(data, 'unicode')),
events=['start-ns'])])
if 'ns0' in _namespaces:
default_namespace = _namespaces['ns0']
else:
Expand Down
10 changes: 5 additions & 5 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

import json
import os
import xml
from typing import Any
from unittest import TestCase

import lxml # type: ignore
from defusedxml import ElementTree as SafeElementTree # type: ignore
from xmldiff import main # type: ignore
from xmldiff.actions import MoveNode # type: ignore

Expand All @@ -48,12 +48,12 @@ def assertEqualJson(self, a: str, b: str) -> None:
)

def assertEqualXml(self, a: str, b: str) -> None:
a = xml.etree.ElementTree.tostring(
xml.etree.ElementTree.fromstring(a, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
a = SafeElementTree.tostring(
SafeElementTree.fromstring(a, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
'unicode'
)
b = xml.etree.ElementTree.tostring(
xml.etree.ElementTree.fromstring(b, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
b = SafeElementTree.tostring(
SafeElementTree.fromstring(b, lxml.etree.XMLParser(remove_blank_text=True, remove_comments=True)),
'unicode'
)
diff_results = main.diff_texts(a, b, diff_options={'F': 0.5})
Expand Down
15 changes: 8 additions & 7 deletions tests/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

import logging
import os
from xml.etree import ElementTree

from defusedxml import ElementTree as SafeElementTree

from serializable.formatters import (
CamelCasePropertyNameFormatter,
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_serialize_tfp_sc1(self) -> None:
def test_deserialize_tfp_cc1(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v1.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand All @@ -87,7 +88,7 @@ def test_deserialize_tfp_cc1(self) -> None:
def test_deserialize_tfp_cc1_v2(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v2.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject.title, book.title)
self.assertEqual(ThePhoenixProject.isbn, book.isbn)
self.assertEqual(ThePhoenixProject.edition, book.edition)
Expand All @@ -101,7 +102,7 @@ def test_deserialize_tfp_cc1_v2(self) -> None:
def test_deserialize_tfp_cc1_v3(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v3.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand All @@ -115,7 +116,7 @@ def test_deserialize_tfp_cc1_v3(self) -> None:
def test_deserialize_tfp_cc1_v4(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-1-v4.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject.title, book.title)
self.assertEqual(ThePhoenixProject.isbn, book.isbn)
self.assertEqual(ThePhoenixProject.edition, book.edition)
Expand All @@ -129,7 +130,7 @@ def test_deserialize_tfp_cc1_v4(self) -> None:
def test_deserialize_tfp_cc1_with_ignored(self) -> None:
CurrentFormatter.formatter = CamelCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-camel-case-with-ignored.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand All @@ -153,7 +154,7 @@ def test_deserialize_tfp_kc1(self) -> None:
def test_deserialize_tfp_sc1(self) -> None:
CurrentFormatter.formatter = SnakeCasePropertyNameFormatter
with open(os.path.join(FIXTURES_DIRECTORY, 'the-phoenix-project-snake-case-1.xml')) as input_xml:
book: Book = Book.from_xml(data=ElementTree.fromstring(input_xml.read()))
book: Book = Book.from_xml(data=SafeElementTree.fromstring(input_xml.read()))
self.assertEqual(ThePhoenixProject_v1.title, book.title)
self.assertEqual(ThePhoenixProject_v1.isbn, book.isbn)
self.assertEqual(ThePhoenixProject_v1.edition, book.edition)
Expand Down