Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jul 14, 2021
1 parent 2abcf14 commit d548fb3
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 6 deletions.
24 changes: 24 additions & 0 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
{"sentence": "this is a sentence three","lab":0}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"sentence": "this is a sentence one","lab":0},
{"sentence": "this is a sentence two","lab":1},
{"sentence": "this is a sentence three","lab":0}]}
"""


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
Expand All @@ -57,6 +63,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -99,6 +111,18 @@ def test_from_json(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = TextClassificationData.from_json(
"sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch


@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
Expand Down
24 changes: 24 additions & 0 deletions tests/text/seq2seq/question_answering/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
{"input": "this is a question three","target":"this is an answer three"}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"input": "this is a question one","target":"this is an answer one"},
{"input": "this is a question two","target":"this is an answer two"},
{"input": "this is a question three","target":"this is an answer three"}]}
"""


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
Expand All @@ -46,6 +52,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -106,3 +118,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = QuestionAnsweringData.from_json(
"input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch
36 changes: 30 additions & 6 deletions tests/text/seq2seq/summarization/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,21 @@
TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing

TEST_CSV_DATA = """input,target
this is a sentence one,this is a translated sentence one
this is a sentence two,this is a translated sentence two
this is a sentence three,this is a translated sentence three
this is a sentence one,this is a summarized sentence one
this is a sentence two,this is a summarized sentence two
this is a sentence three,this is a summarized sentence three
"""

TEST_JSON_DATA = """
{"input": "this is a sentence one","target":"this is a translated sentence one"}
{"input": "this is a sentence two","target":"this is a translated sentence two"}
{"input": "this is a sentence three","target":"this is a translated sentence three"}
{"input": "this is a sentence one","target":"this is a summarized sentence one"}
{"input": "this is a sentence two","target":"this is a summarized sentence two"}
{"input": "this is a sentence three","target":"this is a summarized sentence three"}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"input": "this is a sentence one","target":"this is a summarized sentence one"},
{"input": "this is a sentence two","target":"this is a summarized sentence two"},
{"input": "this is a sentence three","target":"this is a summarized sentence three"}]}
"""


Expand All @@ -46,6 +52,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -106,3 +118,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = SummarizationData.from_json(
"input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch
24 changes: 24 additions & 0 deletions tests/text/seq2seq/translation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
{"input": "this is a sentence three","target":"this is a translated sentence three"}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"input": "this is a sentence one","target":"this is a translated sentence one"},
{"input": "this is a sentence two","target":"this is a translated sentence two"},
{"input": "this is a sentence three","target":"this is a translated sentence three"}]}
"""


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
Expand All @@ -46,6 +52,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -86,3 +98,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = TranslationData.from_json(
"input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch

0 comments on commit d548fb3

Please sign in to comment.