diff --git a/flash/text/__init__.py b/flash/text/__init__.py index 8ac71bdfb5..5a25ab337e 100644 --- a/flash/text/__init__.py +++ b/flash/text/__init__.py @@ -1,5 +1,6 @@ from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401 from flash.text.seq2seq import ( # noqa: F401 + QuestionAnsweringData, Seq2SeqData, Seq2SeqTask, SummarizationData, diff --git a/flash/text/seq2seq/__init__.py b/flash/text/seq2seq/__init__.py index 1c30bc9d85..8dd7ad1ebb 100644 --- a/flash/text/seq2seq/__init__.py +++ b/flash/text/seq2seq/__init__.py @@ -1,3 +1,4 @@ from flash.text.seq2seq.core import Seq2SeqData, Seq2SeqFreezeEmbeddings, Seq2SeqTask # noqa: F401 +from flash.text.seq2seq.question_answering import QuestionAnsweringData # noqa: F401 from flash.text.seq2seq.summarization import SummarizationData, SummarizationTask # noqa: F401 from flash.text.seq2seq.translation import TranslationData, TranslationTask # noqa: F401 diff --git a/flash/text/seq2seq/question_answering/__init__.py b/flash/text/seq2seq/question_answering/__init__.py new file mode 100644 index 0000000000..7892b34432 --- /dev/null +++ b/flash/text/seq2seq/question_answering/__init__.py @@ -0,0 +1 @@ +from flash.text.seq2seq.question_answering.data import QuestionAnsweringData # noqa: F401 diff --git a/flash/text/seq2seq/question_answering/data.py b/flash/text/seq2seq/question_answering/data.py new file mode 100644 index 0000000000..b3d42662a5 --- /dev/null +++ b/flash/text/seq2seq/question_answering/data.py @@ -0,0 +1,47 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Dict, Optional, Union + +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess + + +class QuestionAnsweringPreprocess(Seq2SeqPreprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + backbone: str = "t5-small", + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = 'max_length' + ): + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + backbone=backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ) + + +class QuestionAnsweringData(Seq2SeqData): + + preprocess_cls = QuestionAnsweringPreprocess + postprocess_cls = Seq2SeqPostprocess diff --git a/tests/text/seq2seq/question_answering/__init__.py b/tests/text/seq2seq/question_answering/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/seq2seq/question_answering/test_data.py b/tests/text/seq2seq/question_answering/test_data.py new file mode 100644 index 0000000000..2db170464e --- /dev/null +++ b/tests/text/seq2seq/question_answering/test_data.py @@ -0,0 +1,108 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import pytest + +from flash.text import QuestionAnsweringData +from tests.helpers.utils import _TEXT_TESTING + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + +TEST_CSV_DATA = """input,target +this is a question one,this is an answer one +this is a question two,this is an answer two +this is a question three,this is an answer three +""" + +TEST_JSON_DATA = """ +{"input": "this is a question one","target":"this is an answer one"} +{"input": "this is a question two","target":"this is an answer two"} +{"input": "this is a question three","target":"this is an answer three"} +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA) + return path + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_csv(tmpdir): + csv_path = csv_data(tmpdir) + dm = QuestionAnsweringData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_files(tmpdir): + csv_path = csv_data(tmpdir) + dm = QuestionAnsweringData.from_csv( + "input", + "target", + backbone=TEST_BACKBONE, + train_file=csv_path, + val_file=csv_path, + test_file=csv_path, + batch_size=1, + ) + batch = next(iter(dm.val_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + batch = next(iter(dm.test_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_postprocess_tokenizer(tmpdir): + """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different backbone is + used. + """ + backbone = "sshleifer/bart-tiny-random" + csv_path = csv_data(tmpdir) + dm = QuestionAnsweringData.from_csv( + "input", + "target", + backbone=backbone, + train_file=csv_path, + batch_size=1, + ) + pipeline = dm.data_pipeline + pipeline.initialize() + assert pipeline._postprocess_pipeline.backbone == backbone + assert pipeline._postprocess_pipeline.tokenizer is not None + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json(tmpdir): + json_path = json_data(tmpdir) + dm = QuestionAnsweringData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch