Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ngcodegen][model] use topological sort for order of model classes
Browse files Browse the repository at this point in the history
apalala committed Dec 10, 2023
1 parent 908549f commit ad1bc9c
Showing 3 changed files with 61 additions and 20 deletions.
46 changes: 29 additions & 17 deletions tatsu/ngcodegen/objectmodel.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from .. import grammars, objectmodel
from ..mixins.indent import IndentPrintMixin
from ..util import compress_seq, safe_name
from ..util.misc import topological_sort

HEADER = """\
#!/usr/bin/env python3
@@ -70,42 +71,53 @@ def generate_model(self, grammar: grammars.Grammar):
rule.name: self._base_class_specs(rule)
for rule in grammar.rules
}
rule_specs = {name: specs for name, specs in rule_specs.items() if specs}

all_base_spec = {
s.class_name: s
s.class_name: s.base
for specs in rule_specs.values()
for s in specs
}
base = self._model_base_class_name()
all_base_spec[base] = BaseClassSpec(base, base_type_name)
all_base_spec[base] = base_type_name

base_classes = []
for s in all_base_spec.values():
if s.base not in base_classes:
base_classes.append(s.base)
all_model_names = list(reversed(all_base_spec.keys()))
all_specs = {
(s.class_name, s.base)
for specs in rule_specs.values()
for s in specs
}

for base_name in base_classes[:-1]:
self._gen_base_class(all_base_spec[base_name])
self.print('#', all_specs)
self.print('#', all_model_names)
all_model_names = topological_sort(all_model_names, all_specs)
self.print('#', all_model_names)
model_to_rule = {
rule_specs[name][0].class_name: rule
for name, rule in rule_index.items()
if name in rule_specs
}

for model_name, rule in rule_index.items():
self._gen_rule_class(
rule,
rule_specs[model_name],
)
for model_name in all_model_names:
if rule := model_to_rule.get(model_name):
self._gen_rule_class(rule, rule_specs[rule.name])
else:
self._gen_base_class(model_name, all_base_spec.get(model_name))

return self.printed_text()

def _model_base_class_name(self):
return f'ModelBase'

def _gen_base_class(self, spec: BaseClassSpec):
def _gen_base_class(self, class_name: str, base: str | None):
self.print()
self.print()
self.print('@dataclass(eq=False)')
if spec.base:
self.print(f'class {spec.class_name}({spec.base}):')
if base:
self.print(f'class {class_name}({base}): # base')
else:
# FIXME: this cannot happen as base_type is the final base
self.print(f'class {spec.class_name}:')
self.print(f'class {class_name}:')
with self.indent():
self.print('pass')

34 changes: 31 additions & 3 deletions tatsu/util/misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations
from typing import TypeVar

import re

from ._common import RETYPE


_T = TypeVar('_T')

_undefined = object() # unique object for when None is not a good default


@@ -71,13 +75,37 @@ def findalliter(pattern, string, pos=None, endpos=None, flags=0):
yield match_to_find(m)


def findfirst(
pattern, string, pos=None, endpos=None, flags=0, default=_undefined,
):
def findfirst( pattern, string, pos=None, endpos=None, flags=0, default=_undefined):
"""
Avoids using the inefficient findall(...)[0], or first(findall(...))
"""
return first(
findalliter(pattern, string, pos=pos, endpos=endpos, flags=flags),
default=default,
)


def topological_sort(nodes: list[_T], order: set[tuple[_T, _T]]) -> list[_T]:
# https://en.wikipedia.org/wiki/Topological_sorting

order = set(order)
result = [] # Empty list that will contain the sorted elements

pending = [ # Set of all nodes with no incoming edge
n for n in nodes
if not any(x for (x, y) in order if y == n)
]
while pending:
n = pending.pop()
result.insert(0, n)
outgoing = {m for (x, m) in order if x == n}
for m in outgoing:
order.remove((n, m))
if not any(x for x, y in order if y == m):
# m has no other incoming edges then
pending.append(m)

if order:
raise ValueError('There are cycles in the graph')

return result # a topologically sorted list
1 change: 1 addition & 0 deletions test/grammar/semantics_test.py
Original file line number Diff line number Diff line change
@@ -81,6 +81,7 @@ def test_builder_basetype_codegen(self):
from tatsu.tool import to_python_model

src = to_python_model(grammar, base_type=MyNode)
print(src)

globals = {}
exec(src, globals) # pylint: disable=W0122

0 comments on commit ad1bc9c

Please sign in to comment.