import re

# RFC 2045 "token": one or more printable ASCII characters other than the
# tspecials ( ) < > @ , ; : \ " / [ ] ? = -- optionally surrounded by
# whitespace, so that values such as b"text / html" are still accepted.
_TOKEN = br"^\s*[-!#$%&'*+.0-9A-Z^_`a-z{|}~]+\s*$"

# Compiled once at import time; the validator applies it to both the type
# and the subtype on every call.
_TOKEN_RE = re.compile(_TOKEN)


def _is_valid_mime_type(mime_type: bytes) -> bool:
    """Return True if *mime_type* is a valid MIME type as per RFC 2045.

    Only the type and the subtype are validated; anything after the first
    ``;`` that follows the ``/`` (i.e. the parameters) is ignored.

    A value is rejected when it has no ``/`` separator, or when either side
    of the separator is not an RFC 2045 token (whitespace around each token
    is tolerated).

    ``extract_mime`` uses this check on the last supplied content type and
    replaces the value with ``b""`` when it is invalid, so a malformed header
    such as ``b"javascript charset=UTF-8"`` is handled like a missing one.

    >>> _is_valid_mime_type(b"text/html; charset=UTF-8")
    True
    >>> _is_valid_mime_type(b"javascript charset=UTF-8")
    False
    """
    # partition() yields an empty separator when there is no "/" at all,
    # which is exactly the "fewer than two parts" case.
    main_type, slash, subtype_and_params = mime_type.partition(b"/")
    if not slash:
        return False
    # Drop the parameters: only the text before the first ";" is the subtype.
    subtype = subtype_and_params.partition(b";")[0]
    return bool(_TOKEN_RE.match(main_type) and _TOKEN_RE.match(subtype))