Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed May 28, 2021
1 parent 5e6f7be commit b4bfa31
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@

from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text import TextClassificationData
from flash.text.classification.data import (
TextCSVDataSource,
TextDataSource,
TextFileDataSource,
TextJSONDataSource,
TextSentencesDataSource,
)

if _TEXT_AVAILABLE:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing

Expand Down Expand Up @@ -92,3 +102,32 @@ def test_from_json(tmpdir):
def test_text_module_not_found_error():
    """Check that ``TextClassificationData.from_json`` raises ``ModuleNotFoundError``.

    The error message is expected to mention the ``[text]`` extra.
    NOTE(review): presumably this test only runs when the text dependencies are
    missing; the skip condition is not visible in this hunk -- confirm.
    """
    # ``match`` is applied with ``re.search``. An unescaped "[text]" is a regex
    # character class (any one of "t", "e", "x") and would match nearly any
    # message, so the brackets are escaped to require the literal "[text]".
    with pytest.raises(ModuleNotFoundError, match=r"\[text\]"):
        TextClassificationData.from_json("sentence", "lab", backbone=TEST_BACKBONE, train_file="", batch_size=1)


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.")
@pytest.mark.parametrize(
    "cls, kwargs",
    [
        (TextDataSource, {}),
        (TextFileDataSource, {"filetype": "csv"}),
        (TextCSVDataSource, {}),
        (TextJSONDataSource, {}),
        (TextSentencesDataSource, {}),
    ],
)
def test_tokenizer_state(cls, kwargs):
    """Check that tokenizers are left out of ``__getstate__`` and come back after ``__setstate__``.

    Every instance attribute holding a ``PreTrainedTokenizerBase`` must be absent
    from the state returned by ``__getstate__``. After those attributes are
    cleared and the saved state is re-applied, each must be non-``None`` again.
    """
    data_source = cls(backbone="sshleifer/tiny-mbart", **kwargs)
    saved_state = data_source.__getstate__()

    # Collect the tokenizer-valued attribute names up front so the instance is
    # not mutated while its attribute dict is being scanned.
    tokenizer_names = [
        attr_name
        for attr_name, value in vars(data_source).items()
        if isinstance(value, PreTrainedTokenizerBase)
    ]

    for attr_name in tokenizer_names:
        assert attr_name not in saved_state
        # Drop the tokenizer so that only ``__setstate__`` can bring it back.
        setattr(data_source, attr_name, None)

    data_source.__setstate__(saved_state)

    for attr_name in tokenizer_names:
        restored = getattr(data_source, attr_name, None)
        assert restored is not None
74 changes: 74 additions & 0 deletions tests/text/seq2seq/core/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path

import pytest

from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text import TextClassificationData
from flash.text.seq2seq.core.data import (
Seq2SeqBackboneState,
Seq2SeqCSVDataSource,
Seq2SeqDataSource,
Seq2SeqFileDataSource,
Seq2SeqJSONDataSource,
Seq2SeqPostprocess,
Seq2SeqSentencesDataSource,
)

if _TEXT_AVAILABLE:
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.")
@pytest.mark.parametrize(
    "cls, kwargs",
    [
        (Seq2SeqDataSource, {"backbone": "sshleifer/tiny-mbart"}),
        (Seq2SeqFileDataSource, {"backbone": "sshleifer/tiny-mbart", "filetype": "csv"}),
        (Seq2SeqCSVDataSource, {"backbone": "sshleifer/tiny-mbart"}),
        (Seq2SeqJSONDataSource, {"backbone": "sshleifer/tiny-mbart"}),
        (Seq2SeqSentencesDataSource, {"backbone": "sshleifer/tiny-mbart"}),
        (Seq2SeqPostprocess, {}),
    ],
)
def test_tokenizer_state(cls, kwargs):
    """Check that tokenizers are left out of ``__getstate__`` and come back after ``__setstate__``.

    Each parametrized class is built, given a ``Seq2SeqBackboneState``, and then
    checked: attributes holding a ``PreTrainedTokenizerBase`` must be absent from
    the saved state, and must be non-``None`` again once the state is re-applied.
    """
    backbone_state = Seq2SeqBackboneState(backbone="sshleifer/tiny-mbart")
    target = cls(**kwargs)
    target.set_state(backbone_state)

    # The result is deliberately discarded.
    # NOTE(review): presumably this access forces lazy tokenizer creation on
    # classes that build it on demand -- confirm against the implementation.
    getattr(target, "tokenizer", None)

    saved_state = target.__getstate__()

    # Collect the tokenizer-valued attribute names up front so the instance is
    # not mutated while its attribute dict is being scanned.
    tokenizer_names = [
        attr_name
        for attr_name, value in vars(target).items()
        if isinstance(value, PreTrainedTokenizerBase)
    ]

    for attr_name in tokenizer_names:
        assert attr_name not in saved_state
        # Drop the tokenizer so that only ``__setstate__`` can bring it back.
        setattr(target, attr_name, None)

    target.__setstate__(saved_state)

    for attr_name in tokenizer_names:
        restored = getattr(target, attr_name, None)
        assert restored is not None

0 comments on commit b4bfa31

Please sign in to comment.