Skip to content

Commit

Permalink
Compress mask store by only storing viable terminal sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhamugare committed Feb 17, 2025
1 parent e42a0f6 commit 23d608e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
71 changes: 65 additions & 6 deletions syncode/dfa_mask_store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import defaultdict
import copy, os, pickle
import time
import interegular
import torch
import regex
Expand Down Expand Up @@ -290,7 +291,8 @@ def __init__(self,
special_token_ids: Iterable=[],
indentation: bool=True,
mode='grammar_strict', # 'grammar_strict' or 'grammar_mask'
ignore_terminals: Iterable[str]=[]
ignore_terminals: Iterable[str]=[],
parse_table=None
):
self._vocab = vocab
self.special_token_ids = special_token_ids
Expand All @@ -304,7 +306,11 @@ def __init__(self,
# Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
self._lookup_table = LookupTable(vocab, special_token_ids, indentation=indentation, mode=mode)
terminal_names = [terminal.name for terminal in terminals]
self._store_overapproximate_tokens(terminal_names, vocab)

followings_terminas_map = None
if parse_table is not None:
followings_terminas_map = self._compute_following_terminals_map(terminal_names, parse_table)
self._store_overapproximate_tokens(terminal_names, vocab, followings_terminas_map)

self.indentation = indentation

Expand All @@ -321,7 +327,14 @@ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_termina
return ignore_whitespace

@staticmethod
def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None, mode='grammar_strict'):
def load_dfa_mask_store(
grammar: Grammar,
tokenizer,
use_cache=True,
logger=None,
mode='grammar_strict',
parse_table=None
):
'''
Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
'''
Expand All @@ -347,7 +360,17 @@ def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None
simplifications = grammar.simplifications()
os.makedirs(dfa_dir, exist_ok=True)

mask_store = DFAMaskStore(base_parser.terminals, vocab, simplifications=simplifications, special_token_ids=[tokenizer.eos_token_id], mode=mode, ignore_terminals=base_parser.ignore_tokens)
start_time = time.time()
mask_store = DFAMaskStore(
base_parser.terminals,
vocab,
simplifications=simplifications,
special_token_ids=[tokenizer.eos_token_id],
mode=mode,
ignore_terminals=base_parser.ignore_tokens,
parse_table=parse_table
)
print(f"Time taken to create DFA mask store: {time.time() - start_time} seconds", flush=True)

pickle.dump(mask_store, open(dfa_path, 'wb'))
return mask_store
Expand All @@ -356,12 +379,42 @@ def _get_default_mask(self) -> torch.Tensor:
mask = torch.zeros(len(self._vocab), dtype=torch.bool)
return mask

def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str]):
def _compute_following_terminals_map(self, terminals: Iterable[str], parse_table) -> defaultdict:
"""
From terminals, filter out terminals that cannot follow the current terminal
according to the grammar.
If in the parsing table Action[cur_terminal, parser_state] = 'shift, new_parser_state' then next terminals
are the terminals that are legal in new_parser_state.
"""
following_terminals_map = defaultdict(set)
terminals_set = set(terminals)

# We iterate through each cur_terminal:
for cur_terminal in terminals:
# We iterate through each parser_state:
for _, row in parse_table.states.items():
if cur_terminal in row:
action = row[cur_terminal]
# -> If we see a shift action to new_parser_state
if str(action[0]) == 'Shift':
new_parser_state = action[1]
for next_terminal in parse_table.states[new_parser_state]:
# Lark parse_table stores non-terminals and terminals together
if next_terminal in terminals_set:
# -> -> we add the terminals that are legal in new_parser_state
following_terminals_map[cur_terminal].add(next_terminal)

return following_terminals_map


def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterable[str], followings_terminas_map: dict=None):
"""
Stores the overapproximate tokens for each dfa state and next terminals
"""
all_dfa_states = self._dfas.states()
pbar = tqdm(total=len(all_dfa_states))

for dfa_state in all_dfa_states:
for token_idx, token in enumerate(vocab):
is_special_token = token_idx in self.special_token_ids
Expand All @@ -371,12 +424,18 @@ def _store_overapproximate_tokens(self, terminals: Iterable[str], vocab: Iterabl
self._lookup_table.dfa_state_and_next_terminal_to_tokens_add(
dfa_state, '$END', token_idx)
else:
self._process_regular_tokens(terminals, dfa_state, token_idx, token)
if followings_terminas_map is not None and dfa_state.terminal in followings_terminas_map:
following_terminals = followings_terminas_map[dfa_state.terminal]
else:
following_terminals = terminals

self._process_regular_tokens(following_terminals, dfa_state, token_idx, token)

pbar.update(1)

def _process_regular_tokens(self, terminals, dfa_state, token_idx, token):
remainder = token.replace('\t', ' ')

is_valid, remainder = self._dfas.consume_prefix(dfa_state, remainder)
if is_valid:
if remainder == '':
Expand Down
7 changes: 4 additions & 3 deletions syncode/grammar_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,19 @@ def __init__(self,
# Ignore whitespace tokens
self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)

# Create parser
self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)

# Load dfa mask store
self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
grammar=self.grammar,
tokenizer=self.tokenizer,
use_cache=use_cache,
logger=self.logger,
mode=mode,
parse_table=self.inc_parser.base_parser.parser.parser._parse_table,
)

# Create parser
self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)


def _log_current_status(self, partial_code, r: ParseResult):
self.logger.log_code('Partial code', partial_code)
Expand Down

0 comments on commit 23d608e

Please sign in to comment.