Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Feature/53x question answering task (#565)
Browse files Browse the repository at this point in the history
* Created QuestionAnsweringData and QuestionAnsweringPreprocess

* Added tests for the QuestionAnsweringData class

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
  • Loading branch information
karthikrangasai and ethanwharris authored Jul 12, 2021
1 parent 3071fea commit 48bdfd8
Showing 6 changed files with 158 additions and 0 deletions.
1 change: 1 addition & 0 deletions flash/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401
from flash.text.seq2seq import ( # noqa: F401
QuestionAnsweringData,
Seq2SeqData,
Seq2SeqTask,
SummarizationData,
1 change: 1 addition & 0 deletions flash/text/seq2seq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flash.text.seq2seq.core import Seq2SeqData, Seq2SeqFreezeEmbeddings, Seq2SeqTask # noqa: F401
from flash.text.seq2seq.question_answering import QuestionAnsweringData # noqa: F401
from flash.text.seq2seq.summarization import SummarizationData, SummarizationTask # noqa: F401
from flash.text.seq2seq.translation import TranslationData, TranslationTask # noqa: F401
1 change: 1 addition & 0 deletions flash/text/seq2seq/question_answering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.text.seq2seq.question_answering.data import QuestionAnsweringData # noqa: F401
47 changes: 47 additions & 0 deletions flash/text/seq2seq/question_answering/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Optional, Union

from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess


class QuestionAnsweringPreprocess(Seq2SeqPreprocess):
    """Preprocess for question answering, built on ``Seq2SeqPreprocess``.

    The constructor adds no logic of its own: every argument is handed to
    ``Seq2SeqPreprocess.__init__`` unchanged, by keyword. This class only
    declares the default values used when an argument is omitted.

    Args:
        train_transform: Optional mapping of names to callables, forwarded as ``train_transform``.
        val_transform: Optional mapping of names to callables, forwarded as ``val_transform``.
        test_transform: Optional mapping of names to callables, forwarded as ``test_transform``.
        predict_transform: Optional mapping of names to callables, forwarded as ``predict_transform``.
        backbone: Forwarded as ``backbone``; defaults to ``"t5-small"``.
        max_source_length: Forwarded as ``max_source_length``; defaults to 128.
        max_target_length: Forwarded as ``max_target_length``; defaults to 128.
        padding: Forwarded as ``padding``; defaults to ``'max_length'``.
    """

    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
        backbone: str = "t5-small",
        max_source_length: int = 128,
        max_target_length: int = 128,
        padding: Union[str, bool] = 'max_length',
    ):
        # Group the per-stage transforms so the parent call reads as
        # "transforms + tokenization settings".
        stage_transforms = dict(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
        )
        super().__init__(
            **stage_transforms,
            backbone=backbone,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            padding=padding,
        )


class QuestionAnsweringData(Seq2SeqData):
    """``Seq2SeqData`` subclass configured for question answering.

    It only overrides two class attributes: ``preprocess_cls`` is set to
    ``QuestionAnsweringPreprocess`` and ``postprocess_cls`` to
    ``Seq2SeqPostprocess``. Everything else is inherited from ``Seq2SeqData``.
    """

    preprocess_cls = QuestionAnsweringPreprocess
    postprocess_cls = Seq2SeqPostprocess
Empty file.
108 changes: 108 additions & 0 deletions tests/text/seq2seq/question_answering/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path

import pytest

from flash.text import QuestionAnsweringData
from tests.helpers.utils import _TEXT_TESTING

# Backbone name passed to ``from_csv`` / ``from_json`` by most tests in this module.
TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing

# CSV fixture: a header row naming the ``input`` and ``target`` columns,
# followed by three question/answer rows.
TEST_CSV_DATA = """input,target
this is a question one,this is an answer one
this is a question two,this is an answer two
this is a question three,this is an answer three
"""

# JSON fixture: one object per line, each with an ``input`` and a ``target`` key.
# The literal starts with a newline, so the written file begins with an empty line.
TEST_JSON_DATA = """
{"input": "this is a question one","target":"this is an answer one"}
{"input": "this is a question two","target":"this is an answer two"}
{"input": "this is a question three","target":"this is an answer three"}
"""


def csv_data(tmpdir):
    """Write ``TEST_CSV_DATA`` to ``data.csv`` inside ``tmpdir`` and return that file's ``Path``."""
    csv_file = Path(tmpdir).joinpath("data.csv")
    csv_file.write_text(TEST_CSV_DATA)
    return csv_file


def json_data(tmpdir):
    """Write ``TEST_JSON_DATA`` to ``data.json`` inside ``tmpdir`` and return that file's ``Path``."""
    json_file = Path(tmpdir).joinpath("data.json")
    json_file.write_text(TEST_JSON_DATA)
    return json_file


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
    """A datamodule built with ``from_csv`` yields train batches holding ``labels`` and ``input_ids``."""
    train_csv = csv_data(tmpdir)
    datamodule = QuestionAnsweringData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=train_csv,
        batch_size=1,
    )

    # Only the first batch is inspected.
    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))
    assert "labels" in first_batch
    assert "input_ids" in first_batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_files(tmpdir):
    """The validation and test dataloaders from ``from_csv`` yield batches holding ``labels`` and ``input_ids``.

    The same CSV file is supplied as the train, validation and test file.
    """
    shared_csv = csv_data(tmpdir)
    datamodule = QuestionAnsweringData.from_csv(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=shared_csv,
        val_file=shared_csv,
        test_file=shared_csv,
        batch_size=1,
    )

    # Validation split: inspect the first batch only.
    val_batch = next(iter(datamodule.val_dataloader()))
    assert "labels" in val_batch
    assert "input_ids" in val_batch

    # Test split: same checks on its first batch.
    test_batch = next(iter(datamodule.test_dataloader()))
    assert "labels" in test_batch
    assert "input_ids" in test_batch


@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_postprocess_tokenizer(tmpdir):
    """Tests that the tokenizer property of the ``Seq2SeqPostprocess`` used by ``QuestionAnsweringData`` resolves
    correctly when a different backbone is used.
    """
    # Deliberately not ``TEST_BACKBONE``: the check is that the backbone given
    # to ``from_csv`` is the one the postprocess pipeline reports.
    backbone = "sshleifer/bart-tiny-random"
    csv_path = csv_data(tmpdir)
    dm = QuestionAnsweringData.from_csv(
        "input",
        "target",
        backbone=backbone,
        train_file=csv_path,
        batch_size=1,
    )
    pipeline = dm.data_pipeline
    pipeline.initialize()
    assert pipeline._postprocess_pipeline.backbone == backbone
    assert pipeline._postprocess_pipeline.tokenizer is not None


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json(tmpdir):
    """A datamodule built with ``from_json`` yields train batches holding ``labels`` and ``input_ids``."""
    train_json = json_data(tmpdir)
    datamodule = QuestionAnsweringData.from_json(
        "input",
        "target",
        backbone=TEST_BACKBONE,
        train_file=train_json,
        batch_size=1,
    )

    # Only the first batch is inspected.
    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))
    assert "labels" in first_batch
    assert "input_ids" in first_batch

0 comments on commit 48bdfd8

Please sign in to comment.