This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

QuestionAnsweringTask and QuestionAnsweringData with SQuADDataSource (#607)

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people authored Aug 31, 2021
1 parent c512c31 commit 5b94abf
Showing 22 changed files with 1,847 additions and 308 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -167,3 +167,4 @@ flash_examples/checkpoints
timit/
urban8k_images/
__MACOSX
*-v2.0.json
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `in_chans` argument to the flash ResNet to control the expected number of input channels ([#673](https://github.com/PyTorchLightning/lightning-flash/pull/673))

- Added a `QuestionAnswering` task for extractive question answering ([#607](https://github.com/PyTorchLightning/lightning-flash/pull/607))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
15 changes: 12 additions & 3 deletions docs/source/api/text.rst
@@ -37,10 +37,19 @@ __________________
    :nosignatures:
    :template: classtemplate.rst

    ~seq2seq.question_answering.model.QuestionAnsweringTask
    ~seq2seq.question_answering.data.QuestionAnsweringData
    ~question_answering.model.QuestionAnsweringTask
    ~question_answering.data.QuestionAnsweringData

    question_answering.data.QuestionAnsweringBackboneState
    question_answering.data.QuestionAnsweringCSVDataSource
    question_answering.data.QuestionAnsweringDataSource
    question_answering.data.QuestionAnsweringDictionaryDataSource
    question_answering.data.QuestionAnsweringFileDataSource
    question_answering.data.QuestionAnsweringJSONDataSource
    question_answering.data.QuestionAnsweringPostprocess
    question_answering.data.QuestionAnsweringPreprocess
    question_answering.data.SQuADDataSource

    seq2seq.question_answering.data.QuestionAnsweringPreprocess

Summarization
_____________
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -68,6 +68,7 @@ Lightning Flash

    reference/text_classification
    reference/text_classification_multi_label
    reference/question_answering
    reference/summarization
    reference/translation

77 changes: 77 additions & 0 deletions docs/source/reference/question_answering.rst
@@ -0,0 +1,77 @@
.. _question_answering:

##################
Question Answering
##################

********
The Task
********

Question Answering is the task of answering questions about some known context.
For example, given a context about some historical figure, any question pertaining to that context should be answerable.
In our case, the context and the question form the input to the model, and the answer is the output sequence produced by the model.

.. note::

    We currently only support Extractive Question Answering, where the answer is a span extracted from the given context, as in SQuAD-like datasets.

-----

*******
Example
*******

Let's look at an example.
We'll use the SQuAD 2.0 dataset, which contains ``train-v2.0.json`` and ``dev-v2.0.json``.
Each example in the JSON files looks like this:

.. code-block::

    {
        "answers": {
            "answer_start": [94, 87, 94, 94],
            "text": ["10th and 11th centuries", "in the 10th and 11th centuries", "10th and 11th centuries", "10th and 11th centuries"]
        },
        "context": "\"The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave thei...",
        "id": "56ddde6b9a695914005b9629",
        "question": "When were the Normans in Normandy?",
        "title": "Normans"
    }
    ...

In the above, the ``context`` key holds the passage that the question is about, the ``question`` key holds the question being asked with respect to the context, and the ``answers`` key stores the answer(s) to that question.
The ``id`` and ``title`` keys are used for unique identification and for grouping concepts together, respectively.
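
To make the format concrete, here is a small self-contained check (using a shortened, hypothetical context purely for illustration) showing that each ``answer_start`` is the character offset within ``context`` at which the corresponding answer text begins:

.. code-block:: python

    # Hypothetical, shortened sample in the SQuAD style shown above.
    context = "The Normans were the people who in the 10th and 11th centuries gave their name to Normandy."
    answer_text = "10th and 11th centuries"

    # "answer_start" records the character offset of the answer within the context.
    answer_start = context.find(answer_text)
    assert context[answer_start : answer_start + len(answer_text)] == answer_text
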
Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.text.question_answering.data.QuestionAnsweringData`.
We select a pre-trained backbone to use for our :class:`~flash.text.question_answering.model.QuestionAnsweringTask` and finetune on the SQuAD 2.0 data.
The backbone can be any Question Answering model from `HuggingFace/transformers <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_.

.. note::

    When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.text.question_answering.data.QuestionAnsweringData` and the :class:`~flash.text.question_answering.model.QuestionAnsweringTask`!

Next, we use the trained :class:`~flash.text.question_answering.model.QuestionAnsweringTask` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/question_answering.py
    :language: python
    :lines: 14-
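
Since the example file itself is not reproduced in this diff, the following is a minimal sketch of the workflow described above. It assumes the ``from_squad_v2`` constructor used by ``flash/text/question_answering/cli.py`` in this commit, and the ``flash.Trainer`` finetuning and checkpointing calls used by other Flash examples; treat it as illustrative rather than a copy of the shipped example:

.. code-block:: python

    import flash
    from flash.core.data.utils import download_data
    from flash.text import QuestionAnsweringData, QuestionAnsweringTask

    # 1. Download the SQuAD 2.0 files and create the DataModule.
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/")
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/")
    datamodule = QuestionAnsweringData.from_squad_v2(
        train_file="./data/train-v2.0.json",
        val_file="./data/dev-v2.0.json",
        backbone="distilbert-base-uncased",  # must match the backbone passed to the task below
    )

    # 2. Build the task with the same backbone and finetune.
    model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30)
    trainer = flash.Trainer(max_epochs=3)
    trainer.finetune(model, datamodule=datamodule)

    # 3. Save the finetuned model.
    trainer.save_checkpoint("question_answering_model.pt")
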

------

**********************************************
Accelerate Training & Inference with Torch ORT
**********************************************

`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training and inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed (see the installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__), enabling it requires only a single flag passed to the ``QuestionAnsweringTask``.

.. note::

    Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models and branches containing fixes for certain models.

.. code-block:: python

    ...
    model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30, enable_ort=True)

1 change: 1 addition & 0 deletions flash/__main__.py
@@ -52,6 +52,7 @@ def wrapper(cli_args):
"flash.pointcloud.segmentation",
"flash.tabular.classification",
"flash.text.classification",
"flash.text.question_answering",
"flash.text.seq2seq.summarization",
"flash.text.seq2seq.translation",
"flash.video.classification",
3 changes: 1 addition & 2 deletions flash/text/__init__.py
@@ -1,7 +1,6 @@
from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401
from flash.text.question_answering import QuestionAnsweringData, QuestionAnsweringTask # noqa: F401
from flash.text.seq2seq import (  # noqa: F401
    QuestionAnsweringData,
    QuestionAnsweringTask,
    Seq2SeqData,
    Seq2SeqTask,
    SummarizationData,
2 changes: 2 additions & 0 deletions flash/text/question_answering/__init__.py
@@ -0,0 +1,2 @@
from flash.text.question_answering.data import QuestionAnsweringData # noqa: F401
from flash.text.question_answering.model import QuestionAnsweringTask # noqa: F401
58 changes: 58 additions & 0 deletions flash/text/question_answering/cli.py
@@ -0,0 +1,58 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

__all__ = ["question_answering"]


def from_squad(
    backbone: str = "distilbert-base-uncased",
    batch_size: int = 4,
    num_workers: Optional[int] = None,
    **preprocess_kwargs,
) -> QuestionAnsweringData:
    """Downloads and loads the SQuAD 2.0 data set."""
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/")
    download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/")
    return QuestionAnsweringData.from_squad_v2(
        train_file="./data/train-v2.0.json",
        val_file="./data/dev-v2.0.json",
        backbone=backbone,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )


def question_answering():
    """Extractive Question Answering."""
    cli = FlashCLI(
        QuestionAnsweringTask,
        QuestionAnsweringData,
        default_datamodule_builder=from_squad,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "distilbert-base-uncased",
        },
    )

    cli.trainer.save_checkpoint("question_answering_model.pt")


if __name__ == "__main__":
    question_answering()
