This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
QuestionAnsweringTask and QuestionAnsweringData with SQuADDatSource. (#…
…607) Co-authored-by: Ethan Harris <[email protected]> Co-authored-by: Ethan Harris <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
- Loading branch information
1 parent
c512c31
commit 5b94abf
Showing
22 changed files
with
1,847 additions
and
308 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,3 +167,4 @@ flash_examples/checkpoints | |
timit/ | ||
urban8k_images/ | ||
__MACOSX | ||
*-v2.0.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
.. _question_answering: | ||
|
||
################## | ||
Question Answering | ||
################## | ||
|
||
******** | ||
The Task | ||
******** | ||
|
||
Question Answering is the task of being able to answer questions pertaining to some known context. | ||
For example, given a context about some historical figure, any question pertaininig to the context should be answerable. | ||
In our case the article would be our input context and question, and the answer would be the output sequence from the model. | ||
|
||
.. note:: | ||
|
||
We currently only support Extractive Question Answering, like the task performed using the SQUAD like datasets. | ||
|
||
----- | ||
|
||
******* | ||
Example | ||
******* | ||
|
||
Let's look at an example. | ||
We'll use the SQUAD 2.0 dataset, which contains ``train-v2.0.json`` and ``dev-v2.0.json``. | ||
Each JSON file looks like this: | ||
|
||
.. code-block:: | ||
{ | ||
"answers": { | ||
"answer_start": [94, 87, 94, 94], | ||
"text": ["10th and 11th centuries", "in the 10th and 11th centuries", "10th and 11th centuries", "10th and 11th centuries"] | ||
}, | ||
"context": "\"The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave thei...", | ||
"id": "56ddde6b9a695914005b9629", | ||
"question": "When were the Normans in Normandy?", | ||
"title": "Normans" | ||
} | ||
... | ||
In the above, the ``context`` key represents the context used for the question and answer, the ``question`` key represents the question being asked with respect to the context, the ``answer`` key stores the answer(s) for the question. | ||
``id`` and ``title`` are used for unique identification and grouping concepts together respectively. | ||
Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.text.question_answering.data.QuestionAnsweringData`. | ||
We select a pre-trained backbone to use for our :class:`~flash.text.question_answering.model.QuestionAnsweringTask` and finetune on the SQUAD 2.0 data. | ||
The backbone can be any Question Answering model from `HuggingFace/transformers <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_. | ||
|
||
.. note:: | ||
|
||
When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.text.question_answering.data.QuestionAnsweringData` and the :class:`~flash.text.question_answering.model.QuestionAnsweringTask`! | ||
|
||
Next, we use the trained :class:`~flash.text.question_answering.model.QuestionAnsweringTask` for inference. | ||
Finally, we save the model. | ||
Here's the full example: | ||
|
||
.. literalinclude:: ../../../flash_examples/question_answering.py | ||
:language: python | ||
:lines: 14- | ||
|
||
------ | ||
|
||
********************************************** | ||
Accelerate Training & Inference with Torch ORT | ||
********************************************** | ||
|
||
`Torch ORT <https://cloudblogs.microsoft.com/opensource/2021/07/13/accelerate-pytorch-training-with-torch-ort/>`__ converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the ``QuestionAnsweringTask`` once installed. See installation instructions `here <https://github.com/pytorch/ort#install-in-a-local-python-environment>`__. | ||
|
||
.. note:: | ||
|
||
Not all Transformer models are supported. See `this table <https://github.com/microsoft/onnxruntime-training-examples#examples>`__ for supported models + branches containing fixes for certain models. | ||
|
||
.. code-block:: python | ||
... | ||
model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30, enable_ort=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from flash.text.question_answering.data import QuestionAnsweringData # noqa: F401 | ||
from flash.text.question_answering.model import QuestionAnsweringTask # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional | ||
|
||
from flash.core.data.utils import download_data | ||
from flash.core.utilities.flash_cli import FlashCLI | ||
from flash.text import QuestionAnsweringData, QuestionAnsweringTask | ||
|
||
__all__ = ["question_answering"] | ||
|
||
|
||
def from_squad( | ||
backbone: str = "distilbert-base-uncased", | ||
batch_size: int = 4, | ||
num_workers: Optional[int] = None, | ||
**preprocess_kwargs, | ||
) -> QuestionAnsweringData: | ||
"""Downloads and loads the XSum data set.""" | ||
download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/") | ||
download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/") | ||
return QuestionAnsweringData.from_squad_v2( | ||
train_file="./data/train-v2.0.json", | ||
val_file="./data/dev-v2.0.json", | ||
backbone=backbone, | ||
batch_size=batch_size, | ||
num_workers=num_workers, | ||
**preprocess_kwargs, | ||
) | ||
|
||
|
||
def question_answering(): | ||
"""Extractive Question Answering.""" | ||
cli = FlashCLI( | ||
QuestionAnsweringTask, | ||
QuestionAnsweringData, | ||
default_datamodule_builder=from_squad, | ||
default_arguments={ | ||
"trainer.max_epochs": 3, | ||
"model.backbone": "distilbert-base-uncased", | ||
}, | ||
) | ||
|
||
cli.trainer.save_checkpoint("question_answering_model.pt") | ||
|
||
|
||
if __name__ == "__main__": | ||
question_answering() |
Oops, something went wrong.