Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

QuestionAnsweringTask and QuestionAnsweringData with SQuADDatSource. #607

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9679a31
Added SQuADDatSource class and an example for the QuestionAnsweringTask
karthikrangasai Jul 20, 2021
29ad4e7
Fixed issues with the SQuADDataSource class, and updated the example
karthikrangasai Jul 21, 2021
0a76e30
Refactoring code and adding test for example
karthikrangasai Jul 21, 2021
cb3a2d5
Merge branch 'master' into feature/577_add_SQuADDataSource
ethanwharris Jul 21, 2021
e5b268a
Merge branch 'master' into feature/577_add_SQuADDataSource
ethanwharris Jul 21, 2021
2845cd2
Updates
ethanwharris Jul 21, 2021
860c12e
Extracting Question Answering Task out from Seq2Seq tasks
karthikrangasai Jul 21, 2021
fba4043
Merge branch 'master' into feature/577_add_SQuADDataSource
karthikrangasai Jul 28, 2021
664f3d7
Added code for pre-processing the data
karthikrangasai Jul 28, 2021
1a78e4e
Refactored tests and updated imports
karthikrangasai Jul 29, 2021
f250335
Updated Question Answering Data and Task to handle the AutoModelForQu…
karthikrangasai Aug 17, 2021
2fa8d5e
Updating question answering example
karthikrangasai Aug 17, 2021
d74aec2
Updating question answering example
karthikrangasai Aug 17, 2021
c05e54e
Update references, add documentation, and update documentation refere…
karthikrangasai Aug 17, 2021
f503d43
Added tests for data and model
karthikrangasai Aug 17, 2021
c9216c2
Update to QuestionAnsweringData - fixing errors and updating tests ex…
karthikrangasai Aug 22, 2021
2e0fc64
Add Deserializer to fix test_serve and remove test_jit as it not comp…
karthikrangasai Aug 23, 2021
c76323a
remove deleted class' import
karthikrangasai Aug 23, 2021
6a14cd6
Merge branch 'master' into feature/577_add_SQuADDataSource
karthikrangasai Aug 23, 2021
4e7d28c
Update files after master merge.
karthikrangasai Aug 23, 2021
01571c6
Add QnA CLI, update docs, add ORT callback for QnA task.
karthikrangasai Aug 23, 2021
bc2feff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2021
61501fa
Merge branch 'master' into feature/577_add_SQuADDataSource
karthikrangasai Aug 26, 2021
adbb4aa
Merge branch 'master' into feature/577_add_SQuADDataSource
ethanwharris Aug 26, 2021
c149c37
Update device type issues present in _generate_answers method.
karthikrangasai Aug 27, 2021
a8b24d6
Merge branch 'master' into feature/577_add_SQuADDataSource
mergify[bot] Aug 31, 2021
37c1030
Changes from code review.
karthikrangasai Aug 31, 2021
28f282e
Add to CLI and test
ethanwharris Aug 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion flash/text/seq2seq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor

import flash
Expand All @@ -28,7 +29,7 @@

if _TEXT_AVAILABLE:
import datasets
from datasets import DatasetDict, load_dataset
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoTokenizer, default_data_collator


Expand Down Expand Up @@ -222,6 +223,62 @@ def __setstate__(self, state):
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqDictionaryDataSource(Seq2SeqDataSource):

def _tokenize_fn(
self,
example: Dict[str, str],
input: Optional[str] = 'input',
input_pair: Optional[str] = None,
target: Optional[str] = None,
) -> Callable:

ex_input = example[input]
ex_input_pair = example[input_pair] if input_pair else None
ex_target = example[target] if target else None

model_inputs = self.tokenizer(
ex_input, ex_input_pair, max_length=self.max_source_length, padding=self.padding, truncation=True
)

# Setup the tokenizer for targets
if ex_target is not None:
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(
ex_target, max_length=self.max_target_length, padding=self.padding, truncation=True
)

model_inputs["labels"] = labels["input_ids"]

return model_inputs

def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset':
if columns is None:
columns = ["input_ids", "attention_mask", "labels"]
if self._running_stage.value == RunningStage.PREDICTING:
columns.remove("labels")

stage = self._running_stage.value
data, input, input_pair, target = data
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved

dataset_dict = DatasetDict({stage: Dataset.from_dict(data)})
dataset_dict = dataset_dict.map(
partial(self._tokenize_fn, input=input, input_pair=input_pair, target=target), batched=True
)

dataset_dict.set_format(columns=columns)
return dataset_dict[stage]

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


@dataclass(unsafe_hash=True, frozen=True)
class Seq2SeqBackboneState(ProcessState):
"""The ``Seq2SeqBackboneState`` stores the backbone in use by the
Expand Down Expand Up @@ -274,6 +331,12 @@ def __init__(
max_target_length=max_target_length,
padding=padding,
),
"dict": Seq2SeqDictionaryDataSource(
self.backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
),
},
default_data_source="sentences",
deserializer=TextDeserializer(backbone, max_source_length)
Expand Down
203 changes: 196 additions & 7 deletions flash/text/seq2seq/question_answering/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,78 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Optional, Union
import json
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess
from torch import Tensor

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras
from flash.text.classification.data import TextDeserializer
from flash.text.seq2seq.core.data import (
Seq2SeqBackboneState,
Seq2SeqCSVDataSource,
Seq2SeqData,
Seq2SeqDictionaryDataSource,
Seq2SeqJSONDataSource,
Seq2SeqPostprocess,
Seq2SeqSentencesDataSource,
)

class QuestionAnsweringPreprocess(Seq2SeqPreprocess):
if _TEXT_AVAILABLE:
import datasets
from datasets import Dataset, DatasetDict
from transformers import default_data_collator


class SQuADDataSource(Seq2SeqDictionaryDataSource):

def load_data(self, data: str, dataset: Optional[Any] = None) -> 'datasets.Dataset':
stage = self._running_stage.value

file_path = data

path = Path(file_path)
with open(path, 'rb') as f:
squad_v_2_dict = json.load(f)

contexts = []
questions = []
answers = []
for topic in squad_v_2_dict['data']:
for comprehension in topic['paragraphs']:
context = comprehension['context']
for q_a_pair in comprehension['qas']:
question = q_a_pair['question']
for answer in q_a_pair['answers']:
answer_text = answer['text']

contexts.append(context)
questions.append(question)
answers.append(answer_text)

dataset_dict = DatasetDict({
stage: Dataset.from_dict({
"context": contexts,
"question": questions,
"answer": answers
})
})

dataset_dict = dataset_dict.map(
partial(self._tokenize_fn, input="question", input_pair="context", target="answer"), batched=True
)

return dataset_dict[stage]


class QuestionAnsweringPreprocess(Preprocess):

@requires_extras("text")
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
Expand All @@ -29,19 +94,143 @@ def __init__(
max_target_length: int = 128,
padding: Union[str, bool] = 'max_length'
):
self.backbone = backbone
self.max_target_length = max_target_length
self.max_source_length = max_source_length
self.padding = padding

super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
backbone=backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
data_sources={
DefaultDataSources.CSV: Seq2SeqCSVDataSource(
self.backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
),
DefaultDataSources.JSON: Seq2SeqJSONDataSource(
self.backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
),
"sentences": Seq2SeqSentencesDataSource(
self.backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
),
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
"dict": Seq2SeqDictionaryDataSource(
self.backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
),
"squad_v2": SQuADDataSource(
self.backbone,
max_source_length=max_source_length,
max_target_length=max_target_length,
padding=padding,
)
},
# TODO: Change default here to Dictionary
default_data_source="dict",
deserializer=TextDeserializer(backbone, max_source_length)
)

self.set_state(Seq2SeqBackboneState(self.backbone))

def get_state_dict(self) -> Dict[str, Any]:
return {
**self.transforms,
"backbone": self.backbone,
"max_source_length": self.max_source_length,
"max_target_length": self.max_target_length,
"padding": self.padding,
}

@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):
return cls(**state_dict)

def collate(self, samples: Any) -> Tensor:
"""Override to convert a set of samples to a batch"""
return default_data_collator(samples)


class QuestionAnsweringData(Seq2SeqData):

preprocess_cls = QuestionAnsweringPreprocess
postprocess_cls = Seq2SeqPostprocess

@classmethod
def from_squad_v2(
cls,
train_file: Optional[str] = None,
val_file: Optional[str] = None,
test_file: Optional[str] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data
folders and corresponding target folders.

Args:
train_folder: The folder containing the train data.
train_ann_file: The COCO format annotation file.
val_folder: The folder containing the validation data.
val_ann_file: The COCO format annotation file.
test_folder: The folder containing the test data.
test_ann_file: The COCO format annotation file.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.

Returns:
The constructed data module.

Examples::

data_module = SemanticSegmentationData.from_coco(
train_folder="train_folder",
train_ann_file="annotations.json",
)
"""
return cls.from_data_source(
"squad_v2",
train_file,
val_file,
test_file,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**preprocess_kwargs,
)
53 changes: 53 additions & 0 deletions flash_examples/question_answering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

# 1. Create the DataModule
download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json", "./data/")
download_data("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json", "./data/")

datamodule = QuestionAnsweringData.from_squad_v2(
train_file="./data/train-v2.0.json",
val_file="./data/dev-v2.0.json",
)

# 2. Build the task
model = QuestionAnsweringTask()

# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule)

# 4. Summarize some text!
predictions = model.predict(({
"context": [
"""
The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
century, and it continued to evolve over the succeeding centuries.
"""
],
"question": ["When were the Normans in Normandy?"]
}, "question", "context", None))
print(predictions[0])

# 5. Save the model!
trainer.save_checkpoint("question_answering_on_sqaud_v2.pt")
4 changes: 4 additions & 0 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
),
# pytest.param("finetuning", "object_detection.py"), # TODO: takes too long.
pytest.param(
"question_answering.py",
marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
),
pytest.param(
"semantic_segmentation.py",
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
Expand Down
Loading