Skip to content

Commit

Permalink
add checks (NVIDIA#7943)
Browse files Browse the repository at this point in the history
Signed-off-by: eharper <[email protected]>
Signed-off-by: Sasha Meister <[email protected]>
  • Loading branch information
ericharper authored and sashameister committed Feb 15, 2024
1 parent d365087 commit 3b54d8d
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions tests/collections/asr/decoding/rnnt_alignments_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# these tests outside of the CI machines environment, where test data is
# stored

import os
import pytest
from examples.asr.transcribe_speech import TranscriptionConfig
from omegaconf import OmegaConf
Expand Down Expand Up @@ -68,6 +69,7 @@ def cleanup_local_folder():


# TODO: add the same tests for multi-blank RNNT decoding
@pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine')
def test_rnnt_alignments():
# using greedy as baseline and comparing all other configurations to it
ref_transcriptions = get_rnnt_alignments("greedy")
Expand Down
1 change: 1 addition & 0 deletions tests/collections/nlp/test_chat_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def create_data_points(mask_user, turn_num, records, temp_file, t2v, label=True)
return data_points


@pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine')
class TestGPTSFTChatDataset:
@classmethod
def setup_class(cls):
Expand Down
2 changes: 1 addition & 1 deletion tests/collections/nlp/test_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_get_model(self):
out = model.forward(*inp)
typecheck.set_typecheck_enabled(enabled=True)

@pytest.mark.skipif(not os.path.exists('/home/TestData/nlp'), reason='Not a Jenkins machine')
@pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine')
@pytest.mark.with_downloads()
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
Expand Down
2 changes: 1 addition & 1 deletion tests/collections/nlp/test_nmt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_train_eval_loss(self):
eval_loss = model.eval_loss_fn(log_probs=log_probs, labels=tgt_ids)
assert torch.allclose(train_loss, eval_loss)

@pytest.mark.skipif(not os.path.exists('/home/TestData/nlp'), reason='Not a Jenkins machine')
@pytest.mark.skipif(not os.path.exists('/home/TestData'), reason='Not a Jenkins machine')
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
def test_gpu_export_ts(self):
Expand Down

0 comments on commit 3b54d8d

Please sign in to comment.