diff --git a/CHANGELOG.md b/CHANGELOG.md index effbc3ea36..5bb6e71f1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `in_chans` argument to the flash ResNet to control the expected number of input channels ([#673](https://github.com/PyTorchLightning/lightning-flash/pull/673)) +- Added a `QuestionAnswering` task for extractive question answering ([#607](https://github.com/PyTorchLightning/lightning-flash/pull/607)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst index af0dd87500..3403afd39b 100644 --- a/docs/source/api/text.rst +++ b/docs/source/api/text.rst @@ -41,14 +41,14 @@ __________________ ~question_answering.data.QuestionAnsweringData question_answering.data.QuestionAnsweringBackboneState - question_answering.data.QuestionAnsweringCSVDataSource - question_answering.data.QuestionAnsweringDataSource - question_answering.data.QuestionAnsweringDictionaryDataSource - question_answering.data.QuestionAnsweringFileDataSource - question_answering.data.QuestionAnsweringJSONDataSource - question_answering.data.QuestionAnsweringPostprocess - question_answering.data.QuestionAnsweringPreprocess - question_answering.data.SQuADDataSource + question_answering.data.QuestionAnsweringCSVDataSource + question_answering.data.QuestionAnsweringDataSource + question_answering.data.QuestionAnsweringDictionaryDataSource + question_answering.data.QuestionAnsweringFileDataSource + question_answering.data.QuestionAnsweringJSONDataSource + question_answering.data.QuestionAnsweringPostprocess + question_answering.data.QuestionAnsweringPreprocess + question_answering.data.SQuADDataSource Summarization diff --git a/docs/source/reference/question_answering.rst b/docs/source/reference/question_answering.rst index 5569f76894..3030840b82 100644 --- a/docs/source/reference/question_answering.rst +++ b/docs/source/reference/question_answering.rst @@ -40,8 +40,8 @@ Each JSON file looks like this: } ... -In the above, the context key represents the context used for the question and answer, the question key represents the question being asked with respect to the context, the answer key stores the answer(s) for the question. -id and title are used for unique identification and grouping concepts together respectively. +In the above, the ``context`` key represents the context used for the question and answer, the ``question`` key represents the question being asked with respect to the context, the ``answer`` key stores the answer(s) for the question. +``id`` and ``title`` are used for unique identification and grouping concepts together respectively. Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.text.question_answering.data.QuestionAnsweringData`. We select a pre-trained backbone to use for our :class:`~flash.text.question_answering.model.QuestionAnsweringTask` and finetune on the SQUAD 2.0 data. The backbone can be any Question Answering model from `HuggingFace/transformers `_. diff --git a/flash/text/question_answering/cli.py b/flash/text/question_answering/cli.py index e89769d6a8..12932ae930 100644 --- a/flash/text/question_answering/cli.py +++ b/flash/text/question_answering/cli.py @@ -51,7 +51,7 @@ def question_answering(): }, ) - cli.trainer.save_checkpoint("question_answering_on_sqaud_v2.pt") + cli.trainer.save_checkpoint("question_answering_model.pt") if __name__ == "__main__": diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 41e8fc85a5..58b8889350 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -31,7 +31,7 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.process import Postprocess, Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras @@ -41,109 +41,6 @@ from transformers import AutoTokenizer, DataCollatorWithPadding, default_data_collator -class QuestionAnsweringDeserializer(Deserializer): - @requires_extras("text") - def __init__( - self, - backbone: str, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", - question_column_name: str = "question", - context_column_name: str = "context", - answer_column_name: str = "answer", - doc_stride: int = 128, - ): - super().__init__() - self.backbone = backbone - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) - self.max_source_length = max_source_length - self.max_target_length = max_target_length - self.padding = padding - - # Setup global pre-processing requirements - self._question_column_name = question_column_name - self._context_column_name = context_column_name - self._answer_column_name = answer_column_name - self._doc_stride = doc_stride - - def _prepare_features(self, sample: Any, tokenized_sample: Any): - # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the - # corresponding example_id and we will store the offset mappings. - tokenized_sample["example_id"] = [] - tokenized_sample["context"] = [] - tokenized_sample["answer"] = [] - - for i in range(len(tokenized_sample["input_ids"])): - # Grab the sequence corresponding to that example (to know what is the context and what is the question). - sequence_ids = tokenized_sample.sequence_ids(i) - context_index = 1 if self.pad_on_right else 0 - - # One example can give several spans, this is the index of the example containing this span of text. - tokenized_sample["example_id"].append(sample["id"]) - tokenized_sample["context"].append(sample["context"]) - if self._running_stage.evaluating: - tokenized_sample["answer"].append(sample["answer"]) - - # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token - # position is part of the context or not. - tokenized_sample["offset_mapping"][i] = [ - (o if sequence_ids[k] == context_index else None) - for k, o in enumerate(tokenized_sample["offset_mapping"][i]) - ] - - return tokenized_sample - - def deserialize(self, sample: Dict[str, Any]) -> Tensor: - sample[self.question_column_name] = sample[self.question_column_name].lstrip() - self.pad_on_right = self.tokenizer.padding_side == "right" - tokenized_sample = self.tokenizer( - sample[self._question_column_name if self.pad_on_right else self._context_column_name], - sample[self._context_column_name if self.pad_on_right else self._question_column_name], - truncation="only_second" if self.pad_on_right else "only_first", - max_length=self.max_source_length, - stride=self.doc_stride, - return_overflowing_tokens=True, - return_offsets_mapping=True, - padding=self.padding, - ) - tokenized_sample = self._prepare_val_features(sample, tokenized_sample) - offset_mappings = tokenized_sample.pop("offset_mapping") - example_ids = tokenized_sample.pop("example_id") - contexts = tokenized_sample.pop("context") - answers = tokenized_sample.pop("answer") - - tokenized_sample[DefaultDataKeys.METADATA] = [] - for offset_mapping, example_id, context in zip(offset_mappings, example_ids, contexts): - tokenized_sample[DefaultDataKeys.METADATA].append( - {"context": context, "offset_mapping": offset_mapping, "example_id": example_id} - ) - - del offset_mappings - del example_ids - del contexts - del answers - return tokenized_sample - - @property - def example_input(self) -> str: - return { - "id": "1", - "context": "this is answer one. this is context one", - "question": "this is question one", - "answer": {"text": ["this is answer one"], "answer_start": [0]}, - } - - def __getstate__(self): # TODO: Find out why this is being pickled - state = self.__dict__.copy() - state.pop("tokenizer") - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) - - class QuestionAnsweringDataSource(DataSource): @requires_extras("text") def __init__( @@ -166,6 +63,7 @@ def __init__( self.padding = padding # Setup global pre-processing requirements + self.pad_on_right = self.tokenizer.padding_side == "right" self._question_column_name = question_column_name self._context_column_name = context_column_name self._answer_column_name = answer_column_name @@ -176,8 +74,6 @@ def _tokenize_fn(self, samples: Any) -> Callable: samples[self.question_column_name] = [q.lstrip() for q in samples[self.question_column_name]] - self.pad_on_right = self.tokenizer.padding_side == "right" - tokenized_samples = self.tokenizer( samples[self._question_column_name if self.pad_on_right else self._context_column_name], samples[self._context_column_name if self.pad_on_right else self._question_column_name], @@ -656,16 +552,6 @@ def __init__( ), }, default_data_source="dict", - deserializer=QuestionAnsweringDeserializer( - backbone=backbone, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - question_column_name=question_column_name, - context_column_name=context_column_name, - answer_column_name=answer_column_name, - doc_stride=doc_stride, - ), ) self.set_state(QuestionAnsweringBackboneState(self.backbone)) @@ -879,7 +765,7 @@ def from_json( Examples:: - data_module = DataModule.from_json( + data_module = QuestionAnsweringData.from_json( train_file="train_data.json", train_transform={ "to_tensor_transform": torch.as_tensor, @@ -980,7 +866,7 @@ def from_csv( Examples:: - data_module = DataModule.from_csv( + data_module = QuestionAnsweringData.from_csv( "input", "target", train_file="train_data.csv", diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index 24874ad684..22366e02f8 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -193,11 +193,9 @@ def _generate_answers(self, pred_start_logits, pred_end_logits, examples): if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) - # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, - # using the LogSumExp trick). + # Compute the softmax of all scores. scores: Tensor = torch.tensor([pred.pop("score") for pred in predictions]) - exp_scores: Tensor = torch.exp(scores - torch.max(scores)) - probs: Tensor = exp_scores / exp_scores.sum() + probs: Tensor = torch.softmax(scores, dim=0) # Include the probabilities in our predictions. for prob, pred in zip(probs, predictions): @@ -241,16 +239,17 @@ def training_step(self, batch: Any, batch_idx: int) -> Tensor: def common_step(self, prefix: str, batch: Any) -> torch.Tensor: generated_answers = self(batch) - self.compute_metrics(generated_answers, batch[DefaultDataKeys.METADATA], prefix) + result = self.compute_metrics(generated_answers, batch[DefaultDataKeys.METADATA]) + self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True) - def compute_metrics(self, generated_tokens, batch, prefix): + def compute_metrics(self, generated_tokens, batch): for example in batch: predicted_answer = generated_tokens[example["example_id"]] target_answer = example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else "" self.rouge.update(predicted_answer, target_answer) result = self.rouge.compute() - self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True) + return result def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): self.common_step("val", batch) diff --git a/tests/text/question_answering/test_model.py b/tests/text/question_answering/test_model.py index 01c378f335..2381917318 100644 --- a/tests/text/question_answering/test_model.py +++ b/tests/text/question_answering/test_model.py @@ -13,7 +13,6 @@ # limitations under the License. import os import re -from unittest import mock import pytest import torch @@ -21,8 +20,7 @@ from flash import Trainer from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import QuestionAnsweringTask -from flash.text.question_answering.data import QuestionAnsweringPostprocess, QuestionAnsweringPreprocess -from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING +from tests.helpers.utils import _TEXT_TESTING # ======== Mock functions ======== @@ -56,17 +54,6 @@ def test_init_train(tmpdir): trainer.fit(model, train_dl) -@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.") -@mock.patch("flash._IS_TESTING", True) -def test_serve(): - model = QuestionAnsweringTask(TEST_BACKBONE) - # TODO: Currently only servable once a preprocess and postprocess have been attached - model._preprocess = QuestionAnsweringPreprocess(backbone=TEST_BACKBONE) - model._postprocess = QuestionAnsweringPostprocess() - model.eval() - model.serve() - - @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[text]'")):