Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add from_lists to TextClassificationData #805

Merged
merged 7 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

- Added support for `from_lists` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))

### Changed

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
Expand Down
3 changes: 3 additions & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ class DefaultDataSources(LightningEnum):
JSON = "json"
DATASETS = "datasets"
FIFTYONE = "fiftyone"
DATAFRAME = "data_frame"
LISTS = "lists"
SENTENCES = "sentences"
LABELSTUDIO = "labelstudio"

# TODO: Create a FlashEnum class???
Expand Down
138 changes: 133 additions & 5 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _multilabel_target(targets, element):

def load_data(
self,
data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]],
data: Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]],
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
Expand Down Expand Up @@ -279,6 +279,55 @@ def __setstate__(self, state):
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextListDataSource(TextDataSource):
    """Text data source backed by in-memory Python lists.

    ``load_data`` wraps the given sentences (and, when available, their targets) in a Hugging Face
    ``Dataset``, records label metadata on the ``dataset`` object / data-source state, and tokenizes
    the sentences with the ``_tokenize_fn`` helper (not defined in this class; presumably inherited
    from ``TextDataSource``).
    """

    def load_data(
        self,
        data: Union[Tuple[List[str], Union[List[Any], List[List[Any]]]], List[str]],
        dataset: Optional[Any] = None,
        columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
    ) -> Union[Sequence[Mapping[str, Any]]]:
        """Build a tokenized dataset from lists of sentences and targets.

        Args:
            data: Either an ``(inputs, targets)`` pair, where ``targets`` holds one value per
                sentence (single-label) or one list per sentence (multi-label), or a bare list of
                sentences with no targets (this is what ``predict_load_data`` receives).
            dataset: Object on which ``multi_label`` and ``num_classes`` are set when not predicting.
            columns: Names of the columns exposed through ``set_format("torch", ...)``.

        Returns:
            The tokenized ``Dataset`` formatted for torch.
        """
        # Train / val / test data arrives as an ``(inputs, targets)`` pair, but prediction data is a
        # bare list of sentences. Unpacking unconditionally would raise for any sentence list whose
        # length is not two (and silently mis-split one whose length is two).
        is_pair = isinstance(data, (list, tuple)) and len(data) == 2 and not isinstance(data[0], str)
        if is_pair:
            sentences, targets = data
        else:
            sentences, targets = data, None

        raw_columns = {"input": sentences}
        if targets is not None:
            raw_columns["labels"] = targets
        hf_dataset = Dataset.from_dict(raw_columns)

        if not self.predicting:
            if isinstance(targets[0], list):
                # multi-target: every sample carries one entry per class
                dataset.multi_label = True
                dataset.num_classes = len(targets[0])
                # NOTE(review): the raw target rows (not class names) are stored as the labels state.
                self.set_state(LabelsState(targets))
            else:
                dataset.multi_label = False
                if self.training:
                    # The sorted unique training targets define the class ordering.
                    labels = sorted(set(hf_dataset["labels"]))
                    dataset.num_classes = len(labels)
                    self.set_state(LabelsState(labels))

                labels = self.get_state(LabelsState)

                # convert labels to ids
                if labels is not None:
                    labels = labels.labels
                    label_to_class_mapping = {v: k for k, v in enumerate(labels)}
                    hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, "labels"))

        hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input="input"), batched=True)
        hf_dataset.set_format("torch", columns=columns)

        return hf_dataset

    def predict_load_data(self, data: Any, dataset: AutoDataset):
        """Load prediction data, exposing only the tokenizer outputs (no ``labels`` column)."""
        return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])

    def __getstate__(self):  # TODO: Find out why this is being pickled
        """Return the instance state without the tokenizer, which ``__setstate__`` rebuilds."""
        state = self.__dict__.copy()
        state.pop("tokenizer")
        return state

    def __setstate__(self, state):
        """Restore the instance state and re-create the tokenizer from ``self.backbone``."""
        self.__dict__.update(state)
        self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextSentencesDataSource(TextDataSource):
def __init__(self, backbone: str, max_length: int = 128):
super().__init__(backbone, max_length=max_length)
Expand Down Expand Up @@ -330,13 +379,14 @@ def __init__(
data_sources={
DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
"data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length),
"sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
DefaultDataSources.DATAFRAME: TextDataFrameDataSource(self.backbone, max_length=max_length),
DefaultDataSources.LISTS: TextListDataSource(self.backbone, max_length=max_length),
DefaultDataSources.SENTENCES: TextSentencesDataSource(self.backbone, max_length=max_length),
DefaultDataSources.LABELSTUDIO: LabelStudioTextClassificationDataSource(
backbone=self.backbone, max_length=max_length
),
},
default_data_source="sentences",
default_data_source=DefaultDataSources.SENTENCES,
deserializer=TextDeserializer(backbone, max_length),
)

Expand Down Expand Up @@ -437,7 +487,7 @@ def from_data_frame(
The constructed data module.
"""
return cls.from_data_source(
"data_frame",
DefaultDataSources.DATAFRAME,
(train_data_frame, input_field, target_fields),
(val_data_frame, input_field, target_fields),
(test_data_frame, input_field, target_fields),
Expand All @@ -454,3 +504,81 @@ def from_data_frame(
sampler=sampler,
**preprocess_kwargs,
)

@classmethod
def from_lists(
    cls,
    train_data: Optional[List[str]] = None,
    train_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
    val_data: Optional[List[str]] = None,
    val_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
    test_data: Optional[List[str]] = None,
    test_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
    predict_data: Optional[List[str]] = None,
    train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
    val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
    test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    sampler: Optional[Type[Sampler]] = None,
    **preprocess_kwargs: Any,
) -> "DataModule":
    """Build a :class:`~flash.text.classification.data.TextClassificationData` from plain Python lists.

    Each split is given as a list of sentences plus a matching list of targets. For multi-label
    classification the targets are a list of lists, with one inner list per sample.

    Args:
        train_data: Sentences used as the training inputs.
        train_targets: Targets for ``train_data`` (a list of lists for multi-label classification).
        val_data: Sentences used as the validation inputs.
        val_targets: Targets for ``val_data`` (a list of lists for multi-label classification).
        test_data: Sentences used as the test inputs.
        test_targets: Targets for ``test_data`` (a list of lists for multi-label classification).
        predict_data: Sentences to run prediction on.
        train_transform: Transforms for training, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        val_transform: Transforms for validation, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        test_transform: Transforms for testing, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        predict_transform: Transforms for predicting, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` handed to the
            :class:`~flash.core.data.data_module.DataModule`.
        preprocess: The :class:`~flash.core.data.data.Preprocess` handed to the
            :class:`~flash.core.data.data_module.DataModule`. When ``None``,
            ``cls.preprocess_cls`` is constructed and used instead.
        val_split: Forwarded as ``val_split`` to the :class:`~flash.core.data.data_module.DataModule`.
        batch_size: Forwarded as ``batch_size`` to the :class:`~flash.core.data.data_module.DataModule`.
        num_workers: Forwarded as ``num_workers`` to the :class:`~flash.core.data.data_module.DataModule`.
        sampler: The ``sampler`` used for the ``train_dataloader``.
        preprocess_kwargs: Extra keyword arguments for constructing the preprocess; only used
            when ``preprocess`` is ``None``.

    Returns:
        The constructed data module.
    """
    # Inputs and targets travel together as one pair per split; predict data has no targets.
    train_pair = (train_data, train_targets)
    val_pair = (val_data, val_targets)
    test_pair = (test_data, test_targets)

    # Everything else is forwarded untouched, by keyword.
    forwarded = dict(
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        preprocess=preprocess,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
    )

    return cls.from_data_source(
        DefaultDataSources.LISTS,
        train_pair,
        val_pair,
        test_pair,
        predict_data,
        **forwarded,
        **preprocess_kwargs,
    )
48 changes: 46 additions & 2 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TextDataSource,
TextFileDataSource,
TextJSONDataSource,
TextListDataSource,
TextSentencesDataSource,
)
from tests.helpers.utils import _TEXT_TESTING
Expand Down Expand Up @@ -54,10 +55,19 @@


TEST_DATA_FRAME_DATA = pd.DataFrame(
{"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"], "lab": [0, 1, 0]},
{
"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"],
"lab1": [0, 1, 0],
"lab2": [1, 0, 1],
},
)


# In-memory fixtures for the ``from_lists`` tests: three sentences with matching targets.
TEST_LIST_DATA = [
    "this is a sentence one",
    "this is a sentence two",
    "this is a sentence three",
]
# Single-label case: one integer target per sentence.
TEST_LIST_TARGETS = [0, 1, 0]
# Multi-label case: one list of 0/1 entries per sentence.
TEST_LIST_TARGETS_MULTILABEL = [
    [0, 1],
    [1, 0],
    [0, 1],
]


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
path.write_text(TEST_CSV_DATA)
Expand Down Expand Up @@ -134,13 +144,46 @@ def test_from_json_with_field(tmpdir):
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_data_frame():
dm = TextClassificationData.from_data_frame(
"sentence", "lab", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
"sentence", "lab1", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
)
batch = next(iter(dm.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_data_frame_multilabel():
    """A data frame with two target columns yields batches whose labels are all 0 or 1."""
    target_columns = ["lab1", "lab2"]
    datamodule = TextClassificationData.from_data_frame(
        "sentence",
        target_columns,
        backbone=TEST_BACKBONE,
        train_data_frame=TEST_DATA_FRAME_DATA,
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    # batch_size=1, so index 0 is the only sample in the batch.
    sample_labels = first_batch["labels"][0]
    assert all(label in [0, 1] for label in sample_labels)
    assert "input_ids" in first_batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_lists():
    """Single-label lists yield batches with a scalar 0/1 label and tokenized inputs."""
    datamodule = TextClassificationData.from_lists(
        backbone=TEST_BACKBONE,
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS,
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    label = first_batch["labels"].item()
    assert label in [0, 1]
    assert "input_ids" in first_batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_lists_multilabel():
    """List-of-lists targets yield batches whose per-sample labels are all 0 or 1."""
    datamodule = TextClassificationData.from_lists(
        backbone=TEST_BACKBONE,
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS_MULTILABEL,
        batch_size=1,
    )
    first_batch = next(iter(datamodule.train_dataloader()))
    # batch_size=1, so index 0 is the only sample in the batch.
    sample_labels = first_batch["labels"][0]
    assert all(label in [0, 1] for label in sample_labels)
    assert "input_ids" in first_batch


@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
Expand All @@ -157,6 +200,7 @@ def test_text_module_not_found_error():
(TextCSVDataSource, {}),
(TextJSONDataSource, {}),
(TextDataFrameDataSource, {}),
(TextListDataSource, {}),
(TextSentencesDataSource, {}),
],
)
Expand Down