Commit e496555 (parent: 9613274)

Changed unit tests using unittest library

Co-authored-by: hgKang02 <[email protected]>

Showing 11 changed files with 1,669 additions and 857 deletions.
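The change is the same in each converted test module: plain test functions that were collected by syncode's custom run_tests helper become methods on a unittest.TestCase, and bare assert statements become self.assert* calls. A minimal sketch of the new shape (the class and method names below are illustrative, not taken from the repository):

import unittest

class TestExample(unittest.TestCase):  # illustrative name only
    def test_addition(self):
        self.assertEqual(1 + 1, 2)  # self.assert* replaces bare assert statements

if __name__ == '__main__':
    unittest.main()  # plays the role of the old run_tests([...]) call

With this shape the suite can also be run through unittest's discovery (python -m unittest discover), though the exact test directory layout is not shown in this commit.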
@@ -6,4 +6,4 @@ syncode/core/__pycache__
 .vscode/
 tmp*
 cache/
-.ipynb_checkpoints/
+.ipynb_checkpoints/
Large diffs are not rendered by default.
@@ -1,25 +1,28 @@
-import sys, os
+import unittest
+import os
+import sys
+
+# Adjusting the path so the modules can be imported correctly
 sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
 
 from syncode.parsers import create_parser
 from syncode.common import run_tests
 from syncode.parse_result import AcceptSequence
 from syncode.parsers.grammars.grammar import Grammar
 
-def test_parser():
-    inc_parser = create_parser(Grammar('calc'))
-    partial_code = "113 + 235 + 17"
-    r = inc_parser.get_acceptable_next_terminals(partial_code)
-    assert r.remainder == '17'
-    assert AcceptSequence(['NUMBER', 'PLUS']) in r.accept_sequences
-    assert AcceptSequence(['NUMBER', 'STAR']) in r.accept_sequences
-    assert AcceptSequence(['LPAR']) in r.accept_sequences
-
-def test_parser2():
-    inc_parser = create_parser(Grammar('calc'))
-    partial_code = "11333"
-    r = inc_parser.get_acceptable_next_terminals(partial_code)
-    assert r.remainder == '11333'
-    assert AcceptSequence(['NUMBER', 'PLUS']) in r.accept_sequences
-
-run_tests([test_parser, test_parser2])
+class TestParser(unittest.TestCase):
+    def test_parser(self):
+        inc_parser = create_parser(Grammar('calc'))
+        partial_code = "113 + 235 + 17"
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        self.assertEqual(r.remainder, '17')
+        self.assertIn(AcceptSequence(['NUMBER', 'PLUS']), r.accept_sequences)
+        self.assertIn(AcceptSequence(['NUMBER', 'STAR']), r.accept_sequences)
+        self.assertIn(AcceptSequence(['LPAR']), r.accept_sequences)
+
+    def test_parser2(self):
+        inc_parser = create_parser(Grammar('calc'))
+        partial_code = "11333"
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        self.assertEqual(r.remainder, '11333')
+        self.assertIn(AcceptSequence(['NUMBER', 'PLUS']), r.accept_sequences)
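Both converted test methods build the same calc parser. A possible refinement, not part of this commit, is to move that shared construction into unittest's per-test setUp hook; the sketch below reuses only names that already appear in the diff above, and the class name is illustrative:

import os
import sys
import unittest

sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')

from syncode.parsers import create_parser
from syncode.parse_result import AcceptSequence
from syncode.parsers.grammars.grammar import Grammar

class TestParserWithSetUp(unittest.TestCase):  # illustrative class name
    def setUp(self):
        # Runs before every test method; builds the same parser as in the diff above.
        self.inc_parser = create_parser(Grammar('calc'))

    def test_parser2(self):
        r = self.inc_parser.get_acceptable_next_terminals("11333")
        self.assertEqual(r.remainder, '11333')
        self.assertIn(AcceptSequence(['NUMBER', 'PLUS']), r.accept_sequences)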
@@ -1,56 +1,55 @@
-import sys, os
-sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
+import sys
+import os
 import time
+import unittest
+sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')
 import syncode.common as common
 from syncode.parsers.incremental_parser import ParseResult
 from syncode.parse_result import AcceptSequence, RemainderState
 from syncode.dfa_mask_store import DFAMaskStore
-from syncode.parsers.grammars import Grammar
+from syncode.parsers.grammars.grammar import Grammar
 
 # model = 'Salesforce/codegen-350M-multi'
 # model = 'WizardLM/WizardCoder-1B-V1.0'
-model = 'Llama-7b'
+# Initialize these outside the test class if they're shared across tests
+model = 'deepseek-ai/deepseek-coder-6.7b-instruct'
+# model = 'Llama-7b'
 tokenizer = common.load_tokenizer(model)
 dfa_mask = DFAMaskStore.load_dfa_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=True, logger=common.EmptyLogger())
 
-def test_dfa_mask():
-    query_start_time = time.time()
-    r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
-    dfa_mask.get_overapprox_tokens_mask(r) # 0.02 seconds for mask
-    print(f'Time taken for mask query:', time.time() - query_start_time, flush=True)
-
-    query_start_time = time.time()
-    r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
-    dfa_mask.get_overapprox_tokens_mask(r, get_list=True) # 10^-4 seconds for list
-    print(f'Time taken for list query:', time.time() - query_start_time, flush=True)
-    print(dfa_mask.get_overapprox_tokens_mask(r, get_list=True))
-    assert all(t in dfa_mask.get_overapprox_tokens_mask(r, get_list=True) for t in [' +', ' +=', ' ++'])
-
-def test_dfa_mask2():
-    r = ParseResult({AcceptSequence(['EOS'])}, '\n // 1.', RemainderState.MAYBE_COMPLETE)
-    print(len(dfa_mask.get_overapprox_tokens_mask(r, get_list=True)))
-    assert len(dfa_mask.get_overapprox_tokens_mask(r, get_list=True)) > 32000
-
-# TODO: Fix
-def test_dfa_mask3():
-    r = ParseResult({AcceptSequence(['__ANON_14'])}, '', RemainderState.COMPLETE)
-    print(dfa_mask.get_overapprox_tokens_mask(r, get_list=True))
-    # assert ":=" in dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
-
-def test_dfa_mask4():
-    r = ParseResult({AcceptSequence(['__IGNORE_0'])}, '', RemainderState.COMPLETE)
-    assert "\t" in dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
-
-def test_dfa_mask5():
-    r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, '{', RemainderState.MAYBE_COMPLETE)
-    assert "\t" in dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
-
-def test_dfa_mask6():
-    # TODO: imprecision
-    r = ParseResult({AcceptSequence(['NAME'])}, 'for', RemainderState.MAYBE_COMPLETE)
-    assert " {" in dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
-
-
-tests = [test_dfa_mask, test_dfa_mask2, test_dfa_mask3, test_dfa_mask4, test_dfa_mask5, test_dfa_mask6]
-common.run_tests(tests)
+class TestDFAMask(unittest.TestCase):
+    def test_dfa_mask(self):
+        query_start_time = time.time()
+        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
+        dfa_mask.get_overapprox_tokens_mask(r) # This is just to run the function, assuming you're checking time
+        # self.assertLess(time.time() - query_start_time, 0.02, "Mask query took too long")
+
+        query_start_time = time.time()
+        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, '1', RemainderState.MAYBE_COMPLETE)
+        dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
+        # self.assertLess(time.time() - query_start_time, 10**-4, "List query took too long")
+        result_list = dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
+        for token in [' +', ' +=', ' ++']:
+            self.assertIn(token, result_list, f"{token} not found in result list")
+
+    def test_dfa_mask2(self):
+        r = ParseResult({AcceptSequence(['EOS'])}, '\n // 1.', RemainderState.MAYBE_COMPLETE)
+        result_list = dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
+        self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")
+
+    def test_dfa_mask3(self):
+        r = ParseResult({AcceptSequence(['__ANON_14'])}, '', RemainderState.COMPLETE)
+        result_list = dfa_mask.get_overapprox_tokens_mask(r, get_list=True)
+        # Uncomment the following line if you want to assert presence of specific tokens
+        # self.assertIn(":=", result_list, ":= not found in result list")
+
+    def test_dfa_mask4(self):
+        r = ParseResult({AcceptSequence(['__IGNORE_0'])}, '', RemainderState.COMPLETE)
+        self.assertIn("\t", dfa_mask.get_overapprox_tokens_mask(r, get_list=True), "Tab character not found in result list")
+
+    def test_dfa_mask5(self):
+        r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, '{', RemainderState.MAYBE_COMPLETE)
+        self.assertIn("\t", dfa_mask.get_overapprox_tokens_mask(r, get_list=True), "Tab character not found in result list")
+
+    def test_dfa_mask6(self):
+        r = ParseResult({AcceptSequence(['NAME'])}, 'for', RemainderState.MAYBE_COMPLETE)
+        self.assertIn(" {", dfa_mask.get_overapprox_tokens_mask(r, get_list=True), "Opening brace not found in result list")
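In the converted file the tokenizer and DFA mask store are still built at module import time. A sketch, not part of this commit, of how the same initialization could instead run once per test class via unittest's setUpClass; the class name is illustrative and all calls are the ones already used in the diff above:

import os
import sys
import unittest

sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../')

import syncode.common as common
from syncode.dfa_mask_store import DFAMaskStore
from syncode.parsers.grammars.grammar import Grammar

class TestDFAMaskWithFixture(unittest.TestCase):  # illustrative class name
    @classmethod
    def setUpClass(cls):
        # Same initialization as the module-level code above, run once for the class.
        cls.tokenizer = common.load_tokenizer('deepseek-ai/deepseek-coder-6.7b-instruct')
        cls.dfa_mask = DFAMaskStore.load_dfa_mask_store(
            grammar=Grammar('go'),
            tokenizer=cls.tokenizer,
            use_cache=True,
            logger=common.EmptyLogger(),
        )

Test methods in such a class would then read the store through self.dfa_mask rather than a module-level global.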