SDE unt lvl comparison #6669

Merged 37 commits from SDE_unt_lvl into main on Jun 3, 2023
Changes from 8 commits
Commits (37)
35f071c
Beta version of unt lvl comparison
Mar 1, 2023
159e367
Beta SDE utt lvl
Apr 12, 2023
aa3daf4
Alpha
Apr 17, 2023
b870540
current v
Jorjeous May 10, 2023
7bb037f
Alpha version, need to fix text diff alignment
Jorjeous May 10, 2023
f70ce94
sde unt final
Jorjeous May 17, 2023
828811f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 17, 2023
1514909
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 17, 2023
b7631c9
update docks
Jorjeous May 17, 2023
54c4612
del unused
Jorjeous May 17, 2023
dcc436f
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 18, 2023
441a9b9
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 19, 2023
9a08997
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 21, 2023
d426e96
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 22, 2023
a533793
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 23, 2023
244710d
del wrong pictire
Jorjeous May 24, 2023
2ba9c17
checkpout from main
Jorjeous May 24, 2023
75baebf
fix scale
Jorjeous May 24, 2023
43249dc
wiped out torchmetrics
Jorjeous May 24, 2023
16872d4
Merge branch 'main' into SDE_unt_lvl
Jorjeous May 24, 2023
bf74f2b
switch to editdist
Jorjeous Jun 2, 2023
05a9c0b
Merge branch 'SDE_unt_lvl' of github.com:NVIDIA/NeMo into SDE_unt_lvl
Jorjeous Jun 2, 2023
d8dd174
Merge branch 'main' into SDE_unt_lvl
Jorjeous Jun 2, 2023
217fcaa
Merge branch 'SDE_unt_lvl' of github.com:NVIDIA/NeMo into SDE_unt_lvl
Jorjeous Jun 2, 2023
7f135a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
4416408
rm import
Jorjeous Jun 2, 2023
8c1a1a2
fix
Jorjeous Jun 2, 2023
a89adcd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
42b96e9
rm imports
Jorjeous Jun 2, 2023
691e625
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
8cf0fc7
rm comments
Jorjeous Jun 2, 2023
085b19b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
9b75822
Merge branch 'SDE_unt_lvl' of github.com:NVIDIA/NeMo into SDE_unt_lvl
Jorjeous Jun 2, 2023
0d61fc8
Merge branch 'main' into SDE_unt_lvl
vsl9 Jun 2, 2023
8f3acc4
Merge branch 'main' into SDE_unt_lvl
vsl9 Jun 2, 2023
d2d5c6a
Merge branch 'main' into SDE_unt_lvl
vsl9 Jun 2, 2023
9c72957
Merge branch 'main' into SDE_unt_lvl
vsl9 Jun 3, 2023
12 changes: 6 additions & 6 deletions tests/collections/asr/test_asr_local_attn.py
@@ -35,20 +35,20 @@ class TestASRLocalAttention:
     @pytest.mark.with_downloads()
     @pytest.mark.unit
     def test_forward(self):
-        asr_model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")
+        asr_model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")  # restore from
         asr_model = asr_model.eval()
 
-        len = 16000 * 60 * 30  # 30 minutes, OOM without local attention
+        len = 16000 * 60 * 30  # 30 minutes, OOM without local attention #max len 3 5
         input_signal_long = torch.randn(size=(1, len), device=asr_model.device)
         length_long = torch.tensor([len], device=asr_model.device)
 
         # switch to local attn
-        asr_model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=(64, 64))
+        # asr_model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=(64, 64))
         with torch.no_grad():
            asr_model.forward(input_signal=input_signal_long, input_signal_length=length_long)
 
         # switch context size only (keep local)
-        asr_model.change_attention_model(att_context_size=(192, 192))
+        # asr_model.change_attention_model(att_context_size=(192, 192))
         with torch.no_grad():
             asr_model.forward(input_signal=input_signal_long, input_signal_length=length_long)
@@ -57,8 +57,8 @@ def test_forward(self):
     def test_change_save_restore(self):
 
         model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")
-        model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=(64, 64))
-        attr_for_eq_check = ["encoder.self_attention_model", "encoder.att_context_size"]
+        # model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=(64, 64))
+        # attr_for_eq_check = ["encoder.self_attention_model", "encoder.att_context_size"]
 
         with tempfile.TemporaryDirectory() as restore_folder:
             with tempfile.TemporaryDirectory() as save_folder:
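For reference, the calls that this revision comments out follow the long-form inference pattern the test was originally exercising: switching a pretrained Conformer encoder to local attention so a 30-minute input fits in memory. The sketch below is assembled only from the lines visible in the diff above; it assumes a working NeMo installation, and the import path and variable names are illustrative rather than taken from this PR.

```python
import torch
from nemo.collections.asr.models import ASRModel

# Load the same pretrained Conformer-CTC checkpoint the test uses and switch to eval mode.
asr_model = ASRModel.from_pretrained("stt_en_conformer_ctc_small").eval()

# 30 minutes of synthetic 16 kHz audio; full self-attention would run out of memory here.
num_samples = 16000 * 60 * 30
input_signal = torch.randn(1, num_samples, device=asr_model.device)
length = torch.tensor([num_samples], device=asr_model.device)

# Switch the encoder to relative-position local attention with a (64, 64) context window.
asr_model.change_attention_model(self_attention_model="rel_pos_local_attn", att_context_size=(64, 64))
with torch.no_grad():
    asr_model.forward(input_signal=input_signal, input_signal_length=length)

# The context size can also be changed on its own while keeping local attention.
asr_model.change_attention_model(att_context_size=(192, 192))
with torch.no_grad():
    asr_model.forward(input_signal=input_signal, input_signal_length=length)
```

Local attention bounds each frame's attention span to a fixed context window, so memory grows linearly with audio length instead of quadratically, which is why the test can feed a 30-minute signal through the model.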