Skip to content

Commit

Permalink
Merge pull request #1354 from lark-parser/typing_oct2023
Browse files Browse the repository at this point in the history
Typing fixes. Mypy now produces 0 type errors
  • Loading branch information
erezsh authored Oct 21, 2023
2 parents ba5ae31 + 46df29b commit 44483c9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
43 changes: 23 additions & 20 deletions lark/load_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import pkgutil
from ast import literal_eval
from contextlib import suppress
from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence
from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence, Generator

from .utils import bfs, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique, small_factors, OrderedSet
from .lexer import Token, TerminalDef, PatternStr, PatternRE
from .lexer import Token, TerminalDef, PatternStr, PatternRE, Pattern

from .parse_tree_builder import ParseTreeBuilder
from .parser_frontends import ParsingFrontend
Expand Down Expand Up @@ -195,10 +195,10 @@


class FindRuleSize(Transformer):
def __init__(self, keep_all_tokens):
def __init__(self, keep_all_tokens: bool):
self.keep_all_tokens = keep_all_tokens

def _will_not_get_removed(self, sym):
def _will_not_get_removed(self, sym: Symbol) -> bool:
if isinstance(sym, NonTerminal):
return not sym.name.startswith('_')
if isinstance(sym, Terminal):
Expand All @@ -207,7 +207,7 @@ def _will_not_get_removed(self, sym):
return False
assert False, sym

def _args_as_int(self, args):
def _args_as_int(self, args: List[Union[int, Symbol]]) -> Generator[int, None, None]:
for a in args:
if isinstance(a, int):
yield a
Expand All @@ -216,10 +216,10 @@ def _args_as_int(self, args):
else:
assert False

def expansion(self, args):
def expansion(self, args) -> int:
return sum(self._args_as_int(args))

def expansions(self, args):
def expansions(self, args) -> int:
return max(self._args_as_int(args))


Expand All @@ -232,7 +232,7 @@ def __init__(self):
self.i = 0
self.rule_options = None

def _name_rule(self, inner):
def _name_rule(self, inner: str):
new_name = '__%s_%s_%d' % (self.prefix, inner, self.i)
self.i += 1
return new_name
Expand All @@ -243,7 +243,7 @@ def _add_rule(self, key, name, expansions):
self.rules_cache[key] = t
return t

def _add_recurse_rule(self, type_, expr):
def _add_recurse_rule(self, type_: str, expr: Tree):
try:
return self.rules_cache[expr]
except KeyError:
Expand Down Expand Up @@ -312,7 +312,7 @@ def _add_repeat_opt_rule(self, a, b, target, target_opt, atom):
])
return self._add_rule(key, new_name, tree)

def _generate_repeats(self, rule, mn, mx):
def _generate_repeats(self, rule: Tree, mn: int, mx: int):
"""Generates a rule tree that repeats ``rule`` between ``mn`` and ``mx`` times (inclusive).
"""
# For a small number of repeats, we can take the naive approach
Expand Down Expand Up @@ -343,7 +343,7 @@ def _generate_repeats(self, rule, mn, mx):

return ST('expansions', [ST('expansion', [mn_target] + [diff_opt_target])])

def expr(self, rule, op, *args):
def expr(self, rule: Tree, op: Token, *args):
if op.value == '?':
empty = ST('expansion', [])
return ST('expansions', [rule, empty])
Expand Down Expand Up @@ -372,7 +372,7 @@ def expr(self, rule, op, *args):

assert False, op

def maybe(self, rule):
def maybe(self, rule: Tree):
keep_all_tokens = self.rule_options and self.rule_options.keep_all_tokens
rule_size = FindRuleSize(keep_all_tokens).transform(rule)
empty = ST('expansion', [_EMPTY] * rule_size)
Expand All @@ -382,11 +382,11 @@ def maybe(self, rule):
class SimplifyRule_Visitor(Visitor):

@staticmethod
def _flatten(tree):
def _flatten(tree: Tree):
while tree.expand_kids_by_data(tree.data):
pass

def expansion(self, tree):
def expansion(self, tree: Tree):
# rules_list unpacking
# a : b (c|d) e
# -->
Expand Down Expand Up @@ -417,7 +417,7 @@ def alias(self, tree):
tree.data = 'expansions'
tree.children = aliases

def expansions(self, tree):
def expansions(self, tree: Tree):
self._flatten(tree)
# Ensure all children are unique
if len(set(tree.children)) != len(tree.children):
Expand Down Expand Up @@ -610,23 +610,25 @@ def range(self, start, end):
return ST('pattern', [PatternRE(regexp)])


def _make_joined_pattern(regexp, flags_set):
def _make_joined_pattern(regexp, flags_set) -> PatternRE:
return PatternRE(regexp, ())

class TerminalTreeToPattern(Transformer_NonRecursive):
def pattern(self, ps):
p ,= ps
return p

def expansion(self, items):
assert items
def expansion(self, items: List[Pattern]) -> Pattern:
if not items:
return PatternStr('')

if len(items) == 1:
return items[0]

pattern = ''.join(i.to_regexp() for i in items)
return _make_joined_pattern(pattern, {i.flags for i in items})

def expansions(self, exps):
def expansions(self, exps: List[Pattern]) -> Pattern:
if len(exps) == 1:
return exps[0]

Expand All @@ -637,7 +639,8 @@ def expansions(self, exps):
pattern = '(?:%s)' % ('|'.join(i.to_regexp() for i in exps))
return _make_joined_pattern(pattern, {i.flags for i in exps})

def expr(self, args):
def expr(self, args) -> Pattern:
inner: Pattern
inner, op = args[:2]
if op == '~':
if len(args) == 3:
Expand Down
5 changes: 3 additions & 2 deletions lark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def is_id_start(s: str) -> bool:
return _test_unicode_category(s, _ID_START)


def dedup_list(l: List[T]) -> List[T]:
def dedup_list(l: Sequence[T]) -> List[T]:
"""Given a list (l), removes duplicates from the list,
preserving the original order of the list. Assumes that
the list entries are hashable."""
Expand Down Expand Up @@ -231,7 +231,8 @@ def combine_alternatives(lists):
return list(product(*lists))

try:
import atomicwrites
# atomicwrites doesn't have type bindings
import atomicwrites # type: ignore[import]
_has_atomicwrites = True
except ImportError:
_has_atomicwrites = False
Expand Down
4 changes: 2 additions & 2 deletions lark/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def __mul__(
return TransformerChain(*self.transformers + (other,))


class Transformer_InPlace(Transformer):
class Transformer_InPlace(Transformer[_Leaf_T, _Return_T]):
"""Same as Transformer, but non-recursive, and changes the tree in-place instead of returning new instances.
Useful for huge trees. Conservative in memory.
Expand All @@ -282,7 +282,7 @@ def transform(self, tree: Tree[_Leaf_T]) -> _Return_T:
return self._transform_tree(tree)


class Transformer_NonRecursive(Transformer):
class Transformer_NonRecursive(Transformer[_Leaf_T, _Return_T]):
"""Same as Transformer but non-recursive.
Like Transformer, it doesn't change the original tree.
Expand Down

0 comments on commit 44483c9

Please sign in to comment.