
QuestionAnsweringTask and QuestionAnsweringData with SQuADDataSource. #607

Merged
28 commits
9679a31
Added SQuADDatSource class and an example for the QuestionAnsweringTask
karthikrangasai Jul 20, 2021
29ad4e7
Fixed issues with the SQuADDataSource class, and updated the example
karthikrangasai Jul 21, 2021
0a76e30
Refactoring code and adding test for example
karthikrangasai Jul 21, 2021
cb3a2d5
Merge branch 'master' into feature/577_add_SQuADDataSource
ethanwharris Jul 21, 2021
e5b268a
Merge branch 'master' into feature/577_add_SQuADDataSource
ethanwharris Jul 21, 2021
2845cd2
Updates
ethanwharris Jul 21, 2021
860c12e
Extracting Question Answering Task out from Seq2Seq tasks
karthikrangasai Jul 21, 2021
fba4043
Merge branch 'master' into feature/577_add_SQuADDataSource
karthikrangasai Jul 28, 2021
664f3d7
Added code for pre-processing the data
karthikrangasai Jul 28, 2021
1a78e4e
Refactored tests and updated imports
karthikrangasai Jul 29, 2021
f250335
Updated Question Answering Data and Task to handle the AutoModelForQu…
karthikrangasai Aug 17, 2021
2fa8d5e
Updating question answering example
karthikrangasai Aug 17, 2021
d74aec2
Updating question answering example
karthikrangasai Aug 17, 2021
c05e54e
Update references, add documentation, and update documentation refere…
karthikrangasai Aug 17, 2021
f503d43
Added tests for data and model
karthikrangasai Aug 17, 2021
c9216c2
Update to QuestionAnsweringData - fixing errors and updating tests ex…
karthikrangasai Aug 22, 2021
2e0fc64
Add Deserializer to fix test_serve and remove test_jit as it not comp…
karthikrangasai Aug 23, 2021
c76323a
remove deleted class' import
karthikrangasai Aug 23, 2021
6a14cd6
Merge branch 'master' into feature/577_add_SQuADDataSource
karthikrangasai Aug 23, 2021
4e7d28c
Update files after master merge.
karthikrangasai Aug 23, 2021
01571c6
Add QnA CLI, update docs, add ORT callback for QnA task.
karthikrangasai Aug 23, 2021
bc2feff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
61501fa
Merge branch 'master' into feature/577_add_SQuADDataSource
karthikrangasai Aug 26, 2021
adbb4aa
Merge branch 'master' into feature/577_add_SQuADDataSource
ethanwharris Aug 26, 2021
c149c37
Update device type issues present in _generate_answers method.
karthikrangasai Aug 27, 2021
a8b24d6
Merge branch 'master' into feature/577_add_SQuADDataSource
mergify[bot] Aug 31, 2021
37c1030
Changes from code review.
karthikrangasai Aug 31, 2021
28f282e
Add to CLI and test
ethanwharris Aug 31, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -167,3 +167,4 @@ flash_examples/checkpoints
timit/
urban8k_images/
__MACOSX
*-v2.0.json
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `in_chans` argument to the flash ResNet to control the expected number of input channels ([#673](https://github.com/PyTorchLightning/lightning-flash/pull/673))

- Added a `QuestionAnswering` task for extractive question answering ([#607](https://github.com/PyTorchLightning/lightning-flash/pull/607))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
15 changes: 12 additions & 3 deletions docs/source/api/text.rst
@@ -37,10 +37,19 @@ __________________
:nosignatures:
:template: classtemplate.rst

~seq2seq.question_answering.model.QuestionAnsweringTask
~seq2seq.question_answering.data.QuestionAnsweringData
~question_answering.model.QuestionAnsweringTask
~question_answering.data.QuestionAnsweringData

question_answering.data.QuestionAnsweringBackboneState
question_answering.data.QuestionAnsweringCSVDataSource
question_answering.data.QuestionAnsweringDataSource
question_answering.data.QuestionAnsweringDictionaryDataSource
question_answering.data.QuestionAnsweringFileDataSource
question_answering.data.QuestionAnsweringJSONDataSource
question_answering.data.QuestionAnsweringPostprocess
question_answering.data.QuestionAnsweringPreprocess
question_answering.data.SQuADDataSource

seq2seq.question_answering.data.QuestionAnsweringPreprocess

Summarization
_____________
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -68,6 +68,7 @@ Lightning Flash

reference/text_classification
reference/text_classification_multi_label
reference/question_answering
reference/summarization
reference/translation

77 changes: 77 additions & 0 deletions docs/source/reference/question_answering.rst
@@ -0,0 +1,77 @@
.. _question_answering:

##################
Question Answering
##################

********
The Task
********

Question Answering is the task of answering questions that pertain to some known context.
For example, given a context about some historical figure, any question pertaining to that context should be answerable.
In our case, the context and the question together form the model's input, and the answer is the span of the context that the model predicts.

.. note::

We currently only support Extractive Question Answering, i.e. the task performed on SQuAD-like datasets.

-----

*******
Example
*******

Let's look at an example.
We'll use the SQuAD 2.0 dataset, which contains ``train-v2.0.json`` and ``dev-v2.0.json``.
Each example in the JSON files looks like this:

.. code-block::

{
"answers": {
"answer_start": [94, 87, 94, 94],
"text": ["10th and 11th centuries", "in the 10th and 11th centuries", "10th and 11th centuries", "10th and 11th centuries"]
},
"context": "\"The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave thei...",
"id": "56ddde6b9a695914005b9629",
"question": "When were the Normans in Normandy?",
"title": "Normans"
}
...

In the above, the ``context`` key holds the passage that the question and answer refer to, the ``question`` key holds the question being asked about that context, and the ``answers`` key stores the answer(s) to the question.
``id`` and ``title`` are used for unique identification and for grouping examples together, respectively.
Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.text.question_answering.data.QuestionAnsweringData`.
We select a pre-trained backbone to use for our :class:`~flash.text.question_answering.model.QuestionAnsweringTask` and finetune on the SQuAD 2.0 data, as sketched after the note below.
The backbone can be any Question Answering model from `HuggingFace/transformers <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_.

.. note::

When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.text.question_answering.data.QuestionAnsweringData` and the :class:`~flash.text.question_answering.model.QuestionAnsweringTask`!
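
For orientation, here is a minimal sketch of that workflow. It assumes the ``from_squad_v2`` constructor and the default ``distilbert-base-uncased`` backbone used by this PR's CLI module, together with the standard Flash finetuning loop; treat it as illustrative rather than as the exact example file included below.

.. code-block:: python

    import flash
    from flash.core.data.utils import download_data
    from flash.text import QuestionAnsweringData, QuestionAnsweringTask

    # 1. Download the SQuAD 2.0 files and create the DataModule.
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/")
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/")
    datamodule = QuestionAnsweringData.from_squad_v2(
        train_file="./data/train-v2.0.json",
        val_file="./data/dev-v2.0.json",
        backbone="distilbert-base-uncased",
        batch_size=4,
    )

    # 2. Build the task with the same backbone and finetune.
    model = QuestionAnsweringTask(backbone="distilbert-base-uncased")
    trainer = flash.Trainer(max_epochs=3)
    trainer.finetune(model, datamodule=datamodule)

    # 3. Save the checkpoint for later inference.
    trainer.save_checkpoint("question_answering_model.pt")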

Next, we use the trained :class:`~flash.text.question_answering.model.QuestionAnsweringTask` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/question_answering.py
:language: python
:lines: 14-

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed, it can be enabled by passing a single flag to the ``QuestionAnsweringTask``. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__.

.. note::

Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models.

.. code-block:: python

...

model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30, enable_ort=True)
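
As a rough sketch, reusing the ``datamodule`` and finetuning flow from the example above, training with ORT enabled might look like this (the ``gpus=1`` flag is an assumption here, since Torch ORT targets GPU training):

.. code-block:: python

    # Same task as before, with ORT enabled via a single flag.
    model = QuestionAnsweringTask(
        backbone="distilbert-base-uncased",
        max_answer_length=30,
        enable_ort=True,
    )
    trainer = flash.Trainer(max_epochs=3, gpus=1)
    trainer.finetune(model, datamodule=datamodule)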
1 change: 1 addition & 0 deletions flash/__main__.py
@@ -52,6 +52,7 @@ def wrapper(cli_args):
"flash.pointcloud.segmentation",
"flash.tabular.classification",
"flash.text.classification",
"flash.text.question_answering",
"flash.text.seq2seq.summarization",
"flash.text.seq2seq.translation",
"flash.video.classification",
3 changes: 1 addition & 2 deletions flash/text/__init__.py
@@ -1,7 +1,6 @@
from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401
from flash.text.question_answering import QuestionAnsweringData, QuestionAnsweringTask # noqa: F401
from flash.text.seq2seq import ( # noqa: F401
QuestionAnsweringData,
QuestionAnsweringTask,
Seq2SeqData,
Seq2SeqTask,
SummarizationData,
2 changes: 2 additions & 0 deletions flash/text/question_answering/__init__.py
@@ -0,0 +1,2 @@
from flash.text.question_answering.data import QuestionAnsweringData # noqa: F401
from flash.text.question_answering.model import QuestionAnsweringTask # noqa: F401
58 changes: 58 additions & 0 deletions flash/text/question_answering/cli.py
@@ -0,0 +1,58 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

__all__ = ["question_answering"]


def from_squad(
    backbone: str = "distilbert-base-uncased",
    batch_size: int = 4,
    num_workers: Optional[int] = None,
    **preprocess_kwargs,
) -> QuestionAnsweringData:
    """Downloads and loads the SQuAD 2.0 data set."""
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/")
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/")
    return QuestionAnsweringData.from_squad_v2(
        train_file="./data/train-v2.0.json",
        val_file="./data/dev-v2.0.json",
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )


def question_answering():
    """Extractive Question Answering."""
    cli = FlashCLI(
        QuestionAnsweringTask,
        QuestionAnsweringData,
        default_datamodule_builder=from_squad,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "distilbert-base-uncased",
        },
    )

    cli.trainer.save_checkpoint("question_answering_model.pt")


if __name__ == "__main__":
    question_answering()