Skip to content

Commit

Permalink
Unit test (#40)
Browse files Browse the repository at this point in the history
Changed the unit tests to use the unittest library

---------

Co-authored-by: hgKang02 <[email protected]>
  • Loading branch information
shubhamugare and hgKang02 authored Feb 27, 2024
1 parent 9613274 commit e496555
Show file tree
Hide file tree
Showing 11 changed files with 1,669 additions and 857 deletions.
13 changes: 6 additions & 7 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ jobs:
key: files-${{ hashFiles('syncode/parsers/grammars/python_grammar.lark', 'syncode/dfa_mask_store.py') }}
- name: Run Tests
run: |
python tests/test_grammar_python.py
python tests/test_grammar_go.py
python tests/test_calc.py
python tests/test_lr_parser.py
python tests/test_language_model.py
python tests/test_dfa_mask_python.py --independent
python tests/test_syncode.py
python3 -m unittest tests.test_calc
python3 -m unittest tests.test_grammar_go
python3 -m unittest tests.test_grammar_python
python3 -m unittest tests.test_language_model
python3 -m unittest tests.test_lr_parser
python3 -m unittest tests.test_syncode
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ syncode/core/__pycache__
.vscode/
tmp*
cache/
.ipynb_checkpoints/
.ipynb_checkpoints/
824 changes: 824 additions & 0 deletions syncode/tfgeneration/tfgeneration.py

Large diffs are not rendered by default.

39 changes: 21 additions & 18 deletions tests/test_calc.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import sys, os
import unittest
import os
import sys

# Adjusting the path so the modules can be imported correctly
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')

from syncode.parsers import create_parser
from syncode.common import run_tests
from syncode.parse_result import AcceptSequence
from syncode.parsers.grammars.grammar import Grammar

class TestParser(unittest.TestCase):
    """Checks on the parser that create_parser builds for the 'calc' grammar."""

    def test_parser(self):
        """Query "113 + 235 + 17" and check the remainder and accept sequences."""
        parser = create_parser(Grammar('calc'))
        result = parser.get_acceptable_next_terminals("113 + 235 + 17")

        self.assertEqual(result.remainder, '17')
        # Each of these terminal sequences must appear among the accept sequences.
        for terminals in (['NUMBER', 'PLUS'], ['NUMBER', 'STAR'], ['LPAR']):
            self.assertIn(AcceptSequence(terminals), result.accept_sequences)

def test_parser():
    """Query "113 + 235 + 17" on the 'calc' parser; check remainder and accept sequences."""
    parser = create_parser(Grammar('calc'))
    result = parser.get_acceptable_next_terminals("113 + 235 + 17")

    assert result.remainder == '17'
    # Each of these terminal sequences must appear among the accept sequences.
    for terminals in (['NUMBER', 'PLUS'], ['NUMBER', 'STAR'], ['LPAR']):
        assert AcceptSequence(terminals) in result.accept_sequences

def test_parser2():
    """Query the bare number "11333"; the whole input is expected back as remainder."""
    parser = create_parser(Grammar('calc'))
    result = parser.get_acceptable_next_terminals("11333")

    assert result.remainder == '11333'
    number_then_plus = AcceptSequence(['NUMBER', 'PLUS'])
    assert number_then_plus in result.accept_sequences
def test_parser2(self):
    """Query the bare number "11333"; the whole input is expected back as remainder."""
    parser = create_parser(Grammar('calc'))
    result = parser.get_acceptable_next_terminals("11333")

    self.assertEqual(result.remainder, '11333')
    number_then_plus = AcceptSequence(['NUMBER', 'PLUS'])
    self.assertIn(number_then_plus, result.accept_sequences)

# Pass both test functions to the run_tests helper imported from syncode.common.
run_tests([test_parser, test_parser2])
91 changes: 45 additions & 46 deletions tests/test_dfa_mask_go.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,55 @@
import sys, os
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
import sys
import os
import time
import unittest
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
import syncode.common as common
from syncode.parsers.incremental_parser import ParseResult
from syncode.parse_result import AcceptSequence, RemainderState
from syncode.dfa_mask_store import DFAMaskStore
from syncode.parsers.grammars import Grammar
from syncode.parsers.grammars.grammar import Grammar

# model = 'Salesforce/codegen-350M-multi'
# model = 'WizardLM/WizardCoder-1B-V1.0'
model = 'Llama-7b'
# Initialize these outside the test class if they're shared across tests
# (these statements run once, when the module is imported).
model = 'deepseek-ai/deepseek-coder-6.7b-instruct'
# model = 'Llama-7b'
# Tokenizer for the model named above, obtained through syncode.common.
tokenizer = common.load_tokenizer(model)
# Mask store for the 'go' grammar, loaded with use_cache=True and an EmptyLogger;
# every test below queries this single instance.
dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=True, logger=common.EmptyLogger())

def test_dfa_mask():
    """Time a mask query and a list query, then check the '+'-family tokens are allowed."""
    # Mask form of the query.
    mask_started = time.time()
    parse_result = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
    dfa_mask.get_overapprox_tokens_mask(parse_result)  # 0.02 seconds for mask
    print(f'Time taken for mask query:', time.time() - mask_started, flush=True)

    # List form of the same query.
    list_started = time.time()
    parse_result = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
    dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)  # 10^-4 seconds for list
    print(f'Time taken for list query:', time.time() - list_started, flush=True)

    print(dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True))
    # The list is re-queried for every expected token, stopping at the first miss.
    for expected in [' +', ' +=', ' ++']:
        assert expected in dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)

def test_dfa_mask2():
    """With only EOS accepted after a '//' comment remainder, expect over 32000 tokens."""
    parse_result = ParseResult({AcceptSequence(['EOS'])}, '\n // 1.', RemainderState.MAYBE_COMPLETE)
    print(len(dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)))
    token_count = len(dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True))
    assert token_count > 32000

# TODO: Fix
def test_dfa_mask3():
    """Print the token list for the '__ANON_14' terminal; nothing is asserted yet."""
    parse_result = ParseResult({AcceptSequence(['__ANON_14'])}, '', RemainderState.COMPLETE)
    allowed_tokens = dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)
    print(allowed_tokens)
    # Disabled check, kept for when the TODO above is resolved:
    # assert ":=" in dfa_mask.get_overapprox_tokens_mask(r, get_list=True)

def test_dfa_mask4():
    """A tab token must be allowed when '__IGNORE_0' is the accepted terminal."""
    parse_result = ParseResult({AcceptSequence(['__IGNORE_0'])}, '', RemainderState.COMPLETE)
    allowed_tokens = dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)
    assert "\t" in allowed_tokens

def test_dfa_mask5():
    """A tab token must be allowed after a '{' remainder with LBRACE then '__IGNORE_0'."""
    parse_result = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, '{', RemainderState.MAYBE_COMPLETE)
    allowed_tokens = dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)
    assert "\t" in allowed_tokens

def test_dfa_mask6():
    """The ' {' token must be allowed when the remainder 'for' is read as a NAME."""
    # TODO: imprecision
    parse_result = ParseResult({AcceptSequence(['NAME'])}, 'for', RemainderState.MAYBE_COMPLETE)
    allowed_tokens = dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)
    assert " {" in allowed_tokens


# Collect the six module-level test functions and pass them to
# syncode.common's run_tests helper.
tests = [test_dfa_mask, test_dfa_mask2, test_dfa_mask3, test_dfa_mask4, test_dfa_mask5, test_dfa_mask6]
common.run_tests(tests)
class TestDFAMask(unittest.TestCase):
    """Token-list checks against the module-level ``dfa_mask`` store.

    Every test builds a ``ParseResult`` holding a single accept sequence and
    asks the store for the over-approximated set of allowed tokens.
    """

    @staticmethod
    def _allowed_tokens(terminals, remainder, state):
        """Return the allowed-token list for one accept sequence.

        Args:
            terminals: Terminal names forming the accept sequence.
            remainder: Remainder string passed to ``ParseResult``.
            state: ``RemainderState`` member passed to ``ParseResult``.
        """
        parse_result = ParseResult({AcceptSequence(terminals)}, remainder, state)
        return dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)

    def test_dfa_mask(self):
        """The '+'-family tokens are allowed after the remainder '1'."""
        parse_result = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
        # Run the default (non-list) query too so that path is exercised;
        # only the list form is asserted on.
        dfa_mask.get_overapprox_tokens_mask(parse_result)

        result_list = dfa_mask.get_overapprox_tokens_mask(parse_result, get_list=True)
        for token in (' +', ' +=', ' ++'):
            # subTest reports every missing token instead of stopping at the first.
            with self.subTest(token=token):
                self.assertIn(token, result_list, f"{token} not found in result list")

    def test_dfa_mask2(self):
        """With only EOS accepted after a '//' comment remainder, expect over 32000 tokens."""
        result_list = self._allowed_tokens(['EOS'], '\n // 1.', RemainderState.MAYBE_COMPLETE)
        # assertGreater prints the actual length when the check fails.
        self.assertGreater(len(result_list), 32000, "Result list is smaller than expected")

    def test_dfa_mask3(self):
        """Smoke test: the '__ANON_14' query runs; no content check yet."""
        result_list = self._allowed_tokens(['__ANON_14'], '', RemainderState.COMPLETE)
        # TODO: the ":=" membership check is disabled in this version of the
        # test; re-enable it once it is expected to hold:
        #   self.assertIn(":=", result_list, ":= not found in result list")
        self.assertIsNotNone(result_list)

    def test_dfa_mask4(self):
        """A tab token is allowed when '__IGNORE_0' is the accepted terminal."""
        result_list = self._allowed_tokens(['__IGNORE_0'], '', RemainderState.COMPLETE)
        self.assertIn("\t", result_list, "Tab character not found in result list")

    def test_dfa_mask5(self):
        """A tab token is allowed after a '{' remainder with LBRACE then '__IGNORE_0'."""
        result_list = self._allowed_tokens(['LBRACE', '__IGNORE_0'], '{', RemainderState.MAYBE_COMPLETE)
        self.assertIn("\t", result_list, "Tab character not found in result list")

    def test_dfa_mask6(self):
        """The ' {' token is allowed when the remainder 'for' is read as a NAME."""
        result_list = self._allowed_tokens(['NAME'], 'for', RemainderState.MAYBE_COMPLETE)
        self.assertIn(" {", result_list, "Opening brace not found in result list")

Loading

0 comments on commit e496555

Please sign in to comment.