Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 56 additions & 149 deletions docling_core/experimental/idoctags.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""Define classes for DocTags serialization."""

import copy
import html
import re
from enum import Enum
from html.parser import HTMLParser
from itertools import groupby
from typing import Any, ClassVar, Final, Optional, cast
from xml.dom.minidom import Element, Text, parseString
Expand Down Expand Up @@ -373,8 +371,9 @@ class IDocTagsToken(str, Enum):
INLINE = "inline"

# Formatting
BOLD = "bold" # instead of "strong"
ITALIC = "italic" # instead of "em"
BOLD = "bold"
ITALIC = "italic"
UNDERLINE = "underline"
STRIKETHROUGH = "strikethrough"
SUPERSCRIPT = "superscript"
SUBSCRIPT = "subscript"
Expand Down Expand Up @@ -981,6 +980,15 @@ class IDocTagsSerializationMode(str, Enum):
LLM_FRIENDLY = "llm_friendly"


class EscapeMode(str, Enum):
    """Select how text content is guarded with CDATA in IDocTags output."""

    # Wrap all text in a CDATA section, regardless of its content.
    CDATA_ALWAYS = "cdata_always"
    # Wrap text in a CDATA section only if it contains special characters.
    CDATA_WHEN_NEEDED = "cdata_when_needed"


class IDocTagsParams(CommonParams):
"""IDocTags-specific serialization parameters independent of DocTags."""

Expand All @@ -1002,7 +1010,7 @@ class IDocTagsParams(CommonParams):
# Expand self-closing forms of non-self-closing tokens after pretty-printing
preserve_empty_non_selfclosing: bool = True
# XML compliance: escape special characters in text content
xml_compliant: bool = False
escape_mode: EscapeMode = EscapeMode.CDATA_WHEN_NEEDED


def _get_delim(*, params: IDocTagsParams) -> str:
Expand All @@ -1014,84 +1022,13 @@ def _get_delim(*, params: IDocTagsParams) -> str:
raise RuntimeError(f"Unknown IDocTags mode: {params.mode}")


class _WhitelistHTMLParser(HTMLParser):
    """Sanitizer that keeps a fixed set of IDocTags tags and escapes the rest.

    Tags listed in ``_ALLOWED`` are re-emitted without their attributes.
    Every other tag, comment and text run is passed through ``html.escape``
    (quotes untouched) so it shows up literally in the output. Entity and
    character references are re-emitted in their original ``&...;`` form.
    The sanitized fragments accumulate in ``self.out``.
    """

    # Names of the formatting and content tags that survive sanitization.
    _ALLOWED = {
        token.value
        for token in (
            IDocTagsToken.BOLD,
            IDocTagsToken.ITALIC,
            IDocTagsToken.STRIKETHROUGH,
            IDocTagsToken.SUPERSCRIPT,
            IDocTagsToken.SUBSCRIPT,
            IDocTagsToken.INLINE,
            IDocTagsToken.TEXT,
            IDocTagsToken.CODE,
            IDocTagsToken.FORMULA,
            IDocTagsToken.FACETS,
        )
    }

    def __init__(self):
        # convert_charrefs=False routes references to handle_entityref /
        # handle_charref instead of decoding them into handle_data.
        super().__init__(convert_charrefs=False)
        self.out = []

    def handle_starttag(self, tag, attrs):
        if tag not in self._ALLOWED:
            # Disallowed tag: emit its raw source text, escaped.
            self.out.append(html.escape(self.get_starttag_text(), quote=False))
            return
        self.out.append(f"<{tag}>")

    def handle_endtag(self, tag):
        if tag not in self._ALLOWED:
            self.out.append(html.escape(f"</{tag}>", quote=False))
            return
        self.out.append(f"</{tag}>")

    def handle_startendtag(self, tag, attrs):
        if tag not in self._ALLOWED:
            self.out.append(html.escape(self.get_starttag_text(), quote=False))
            return
        # Allowed self-closing tags are expanded to an explicit open/close pair.
        self.out.append(f"<{tag}></{tag}>")

    def handle_data(self, data):
        escaped = html.escape(data, quote=False)
        self.out.append(escaped)

    def handle_entityref(self, name):
        self.out.append(f"&{name};")

    def handle_charref(self, name):
        self.out.append(f"&#{name};")

    def handle_comment(self, data):
        literal_comment = f"<!--{data}-->"
        self.out.append(html.escape(literal_comment, quote=False))


# def _escape_xml_text(text: str, xml_compliant: bool) -> str:
# """Escape XML special characters if xml_compliant is enabled."""
# if xml_compliant:
# return html.escape(text, quote=False)
# return text


def _escape_xml_text(text: str, xml_compliant: bool) -> str:
    """Sanitize text for XML output, keeping whitelisted IDocTags tags.

    When ``xml_compliant`` is true the text is run through
    ``_WhitelistHTMLParser`` and the joined parser output is returned.
    That keeps only these tags (attributes stripped):
    bold, italic, strikethrough, superscript, subscript, inline, text, code,
    formula, facets; all other tags are escaped.
    When ``xml_compliant`` is false the text is returned unchanged.
    """
    if xml_compliant:
        sanitizer = _WhitelistHTMLParser()
        sanitizer.feed(text)
        sanitizer.close()
        return "".join(sanitizer.out)
    return text
def _escape_text(text: str, escape_mode: EscapeMode) -> str:
    """Wrap *text* in a CDATA section as dictated by *escape_mode*.

    Args:
        text: Raw text content to embed in the serialized output.
        escape_mode: ``EscapeMode.CDATA_ALWAYS`` wraps every text;
            ``EscapeMode.CDATA_WHEN_NEEDED`` wraps only text containing one
            of the characters ``"``, ``'``, ``&``, ``<`` or ``>``.

    Returns:
        The text wrapped in ``<![CDATA[...]]>`` when wrapping applies,
        otherwise the text unchanged.
    """
    needs_cdata = escape_mode == EscapeMode.CDATA_ALWAYS or (
        escape_mode == EscapeMode.CDATA_WHEN_NEEDED
        and any(c in text for c in ('"', "'", "&", "<", ">"))
    )
    if not needs_cdata:
        return text
    # A literal "]]>" inside the payload would close the CDATA section early
    # and leave malformed XML behind. Split the terminator so that "]]" ends
    # one section and ">" begins the next; a parser reading the two adjacent
    # sections back sees the original characters.
    safe_text = text.replace("]]>", "]]]]><![CDATA[>")
    return f"<![CDATA[{safe_text}]]>"


class IDocTagsListSerializer(BaseModel, BaseListSerializer):
Expand Down Expand Up @@ -1333,6 +1270,12 @@ def _serialize_single_item(
elif isinstance(item, ListItem):
tok = IDocTagsToken.LIST_TEXT
wrap_open_token = f"<{tok.value}>"
elif isinstance(item, CodeItem):
tok = IDocTagsToken.CODE
if item.code_language != CodeLanguageLabel.UNKNOWN:
wrap_open_token = f'<{tok.value} {IDocTagsAttributeKey.CLASS.value}="{item.code_language.value}">'
else:
wrap_open_token = f"<{tok.value}>"
elif (
isinstance(item, TextItem) and item.label == DocItemLabel.CHECKBOX_SELECTED
):
Expand Down Expand Up @@ -1402,44 +1345,26 @@ def _serialize_single_item(
hyperlink=item.hyperlink,
)
else:
text_part = _escape_text(item.text, params.escape_mode)
text_part = doc_serializer.post_process(
text=item.text,
text=text_part,
formatting=item.formatting,
hyperlink=item.hyperlink,
)

# For code blocks, preserve language using a lightweight facets marker
# e.g., <facets>language=python</facets> before the code content.
if isinstance(item, CodeItem):
# lang = getattr(item.code_language, "value", str(item.code_language))
if item.code_language != CodeLanguageLabel.UNKNOWN:
parts.append(
_wrap(
# text=f"language={lang.lower()}",
text=item.code_language.value,
wrap_tag=IDocTagsToken.FACETS.value,
)
)
# Keep the textual code content as-is (no stripping)
else:
text_part = text_part.strip()

# Apply XML escaping if xml_compliant is enabled
text_part = _escape_xml_text(text_part, params.xml_compliant)

if text_part:
parts.append(text_part)

if params.add_caption and isinstance(item, FloatingItem):
cap_text = doc_serializer.serialize_captions(item=item, **kwargs).text
if cap_text:
cap_text = _escape_xml_text(cap_text, params.xml_compliant)
cap_text = _escape_text(cap_text, params.escape_mode)
parts.append(cap_text)

if params.add_footnote and isinstance(item, FloatingItem):
ftn_text = doc_serializer.serialize_footnotes(item=item, **kwargs).text
if ftn_text:
ftn_text = _escape_xml_text(ftn_text, params.xml_compliant)
ftn_text = _escape_text(ftn_text, params.escape_mode)
parts.append(ftn_text)

text_res = "".join(parts)
Expand Down Expand Up @@ -1496,25 +1421,25 @@ def _serialize_meta_field(
if name == MetaFieldName.SUMMARY and isinstance(
field_val, SummaryMetaField
):
escaped_text = _escape_xml_text(field_val.text, params.xml_compliant)
escaped_text = _escape_text(field_val.text, params.escape_mode)
txt = f"<summary>{escaped_text}</summary>"
elif name == MetaFieldName.DESCRIPTION and isinstance(
field_val, DescriptionMetaField
):
escaped_text = _escape_xml_text(field_val.text, params.xml_compliant)
escaped_text = _escape_text(field_val.text, params.escape_mode)
txt = f"<description>{escaped_text}</description>"
elif name == MetaFieldName.CLASSIFICATION and isinstance(
field_val, PictureClassificationMetaField
):
class_name = self._humanize_text(
field_val.get_main_prediction().class_name
)
escaped_class_name = _escape_xml_text(class_name, params.xml_compliant)
escaped_class_name = _escape_text(class_name, params.escape_mode)
txt = f"<classification>{escaped_class_name}</classification>"
elif name == MetaFieldName.MOLECULE and isinstance(
field_val, MoleculeMetaField
):
escaped_smi = _escape_xml_text(field_val.smi, params.xml_compliant)
escaped_smi = _escape_text(field_val.smi, params.escape_mode)
txt = f"<molecule>{escaped_smi}</molecule>"
elif name == MetaFieldName.TABULAR_CHART and isinstance(
field_val, TabularChartMetaField
Expand All @@ -1524,9 +1449,7 @@ def _serialize_meta_field(
# elif tmp := str(field_val or ""):
# txt = tmp
elif name not in {v.value for v in MetaFieldName}:
escaped_text = _escape_xml_text(
str(field_val or ""), params.xml_compliant
)
escaped_text = _escape_text(str(field_val or ""), params.escape_mode)
txt = _wrap(text=escaped_text, wrap_tag=name)
return txt
return None
Expand Down Expand Up @@ -1744,9 +1667,7 @@ def _emit_otsl(
parts.append(cell_loc)
if params.add_content:
# Apply XML escaping to table cell content
escaped_content = _escape_xml_text(
content, params.xml_compliant
)
escaped_content = _escape_text(content, params.escape_mode)
parts.append(escaped_content)
else:
parts.append(
Expand Down Expand Up @@ -2139,6 +2060,11 @@ def serialize_italic(self, text: str, **kwargs: Any) -> str:
"""Apply IDocTags-specific italic serialization."""
return _wrap(text=text, wrap_tag=IDocTagsToken.ITALIC.value)

@override
def serialize_underline(self, text: str, **kwargs: Any) -> str:
    """Wrap the given text in the IDocTags underline tag."""
    underline_tag = IDocTagsToken.UNDERLINE.value
    return _wrap(text=text, wrap_tag=underline_tag)

@override
def serialize_strikethrough(self, text: str, **kwargs: Any) -> str:
"""Apply IDocTags-specific strikethrough serialization."""
Expand Down Expand Up @@ -2341,44 +2267,22 @@ def _extract_code_content_and_language(
self, el: Element
) -> tuple[str, CodeLanguageLabel]:
"""Extract code content and language from a <code> element."""
lang_label = CodeLanguageLabel.UNKNOWN
try:
lang_label = CodeLanguageLabel(
el.getAttribute(IDocTagsAttributeKey.CLASS.value)
)
except ValueError:
lang_label = CodeLanguageLabel.UNKNOWN
parts: list[str] = []
for node in el.childNodes:
if isinstance(node, Text):
if node.data.strip():
parts.append(node.data)
elif isinstance(node, Element):
nm_child = node.tagName
if nm_child == IDocTagsToken.FACETS.value:
language_text = self._get_text(node).strip()
try:
lang_label = next(
lbl
for lbl in CodeLanguageLabel
if lbl.value == language_text
)
except StopIteration:
lang_label = CodeLanguageLabel.UNKNOWN

"""
facets_text = self._get_text(node).strip()
if "=" in facets_text:
key, val = facets_text.split("=", 1)
if key.strip().lower() == "language":
val_norm = val.strip().lower()
try:
lang_label = next(
lbl
for lbl in CodeLanguageLabel
if lbl.value.lower() == val_norm
)
except StopIteration:
lang_label = CodeLanguageLabel.UNKNOWN
"""
continue
if nm_child == IDocTagsToken.LOCATION.value:
continue
if nm_child == IDocTagsToken.BR.value:
elif nm_child == IDocTagsToken.BR.value:
parts.append("\n")
else:
parts.append(self._get_text(node))
Expand Down Expand Up @@ -2793,11 +2697,12 @@ def _extract_text_with_formatting(

# Mapping of format tags to Formatting attributes
format_tags = {
IDocTagsToken.BOLD.value: "bold",
IDocTagsToken.ITALIC.value: "italic",
IDocTagsToken.STRIKETHROUGH.value: "strikethrough",
IDocTagsToken.SUPERSCRIPT.value: "superscript",
IDocTagsToken.SUBSCRIPT.value: "subscript",
IDocTagsToken.BOLD,
IDocTagsToken.ITALIC,
IDocTagsToken.STRIKETHROUGH,
IDocTagsToken.UNDERLINE,
IDocTagsToken.SUPERSCRIPT,
IDocTagsToken.SUBSCRIPT,
}

if tag_name in format_tags:
Expand All @@ -2815,6 +2720,8 @@ def _extract_text_with_formatting(
child_formatting.italic = True
elif tag_name == IDocTagsToken.STRIKETHROUGH.value:
child_formatting.strikethrough = True
elif tag_name == IDocTagsToken.UNDERLINE.value:
child_formatting.underline = True
elif tag_name == IDocTagsToken.SUPERSCRIPT.value:
child_formatting.script = Script.SUPER
elif tag_name == IDocTagsToken.SUBSCRIPT.value:
Expand Down
Loading
Loading