Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add from_lists to TextClassificationData #805

Merged
merged 7 commits into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

- Added support for `from_list` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))

### Changed

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
Expand Down
121 changes: 120 additions & 1 deletion flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _multilabel_target(targets, element):

def load_data(
self,
data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]],
data: Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]],
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
Expand Down Expand Up @@ -278,6 +278,49 @@ def __setstate__(self, state):
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextListDataSource(TextDataSource):
    """Data source that builds a tokenized Hugging Face ``Dataset`` from in-memory Python lists.

    For training, validation and testing the data is an ``(inputs, targets)`` tuple. For
    prediction the data is a plain list of input strings (an ``(inputs, targets)`` tuple is
    still accepted, in which case the targets are ignored).
    """

    def load_data(
        self,
        data: Union[Tuple[List[str], List[Any]], List[str]],
        dataset: Optional[Any] = None,
        columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
    ) -> Union[Sequence[Mapping[str, Any]]]:
        """Build, label-encode and tokenize a dataset from Python lists.

        Args:
            data: ``(inputs, targets)`` when not predicting; a list of input strings (or an
                ``(inputs, targets)`` tuple whose targets are ignored) when predicting.
            dataset: The dataset object on which ``multi_label`` and ``num_classes`` are set.
            columns: The columns exposed by the returned dataset in ``torch`` format.

        Returns:
            The tokenized Hugging Face dataset restricted to ``columns``.
        """
        if self.predicting:
            # ``from_list`` forwards ``predict_data`` as a bare list of strings, so there is
            # nothing to unpack and no ``labels`` column to build or transform.
            inputs = data[0] if isinstance(data, tuple) else data
            hf_dataset = Dataset.from_dict({"input": inputs})
        else:
            inputs, targets = data
            hf_dataset = Dataset.from_dict({"input": inputs, "labels": targets})

            # TODO: add multi-label support (requested in review: "Would you mind adding
            # support for multi label?"). Targets are treated as single-label for now.
            dataset.multi_label = False

            if self.training:
                # The sorted unique targets define the class index of each label.
                labels = sorted(set(hf_dataset["labels"]))
                dataset.num_classes = len(labels)
                self.set_state(LabelsState(labels))

            labels_state = self.get_state(LabelsState)

            # convert labels to ids
            if labels_state is not None:
                label_to_class_mapping = {label: index for index, label in enumerate(labels_state.labels)}
                hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, "labels"))

        hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input="input"), batched=True)
        hf_dataset.set_format("torch", columns=columns)

        return hf_dataset

    def predict_load_data(self, data: Any, dataset: AutoDataset):
        """Load prediction data, exposing only the tokenizer outputs (no ``labels`` column)."""
        return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])

    def __getstate__(self):  # TODO: Find out why this is being pickled
        """Return the instance state without the tokenizer."""
        state = self.__dict__.copy()
        state.pop("tokenizer")
        return state

    def __setstate__(self, state):
        """Restore the instance state and re-create the tokenizer from ``self.backbone``."""
        self.__dict__.update(state)
        self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextSentencesDataSource(TextDataSource):
def __init__(self, backbone: str, max_length: int = 128):
super().__init__(backbone, max_length=max_length)
Expand Down Expand Up @@ -330,6 +373,7 @@ def __init__(
DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
"data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length),
"list": TextListDataSource(self.backbone, max_length=max_length),
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
},
default_data_source="sentences",
Expand Down Expand Up @@ -450,3 +494,78 @@ def from_data_frame(
sampler=sampler,
**preprocess_kwargs,
)

@classmethod
def from_list(
    cls,
    train_data: Optional[List[str]] = None,
    train_targets: Optional[List[Any]] = None,
    val_data: Optional[List[str]] = None,
    val_targets: Optional[List[Any]] = None,
    test_data: Optional[List[str]] = None,
    test_targets: Optional[List[Any]] = None,
    predict_data: Optional[List[str]] = None,
    train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
    val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
    test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    sampler: Optional[Type[Sampler]] = None,
    **preprocess_kwargs: Any,
) -> "DataModule":
    """Build a :class:`~flash.text.classification.data.TextClassificationData` from Python lists.

    Each of the train / validation / test splits is described by a list of input strings and a
    matching sequence of targets; prediction only needs the inputs.

    Args:
        train_data: The list of train inputs.
        train_targets: The train targets, one per train input.
        val_data: The list of validation inputs.
        val_targets: The validation targets, one per validation input.
        test_data: The list of test inputs.
        test_targets: The test targets, one per test input.
        predict_data: The list of inputs to predict on.
        train_transform: Transforms for training, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        val_transform: Transforms for validation, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        test_transform: Transforms for testing, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        predict_transform: Transforms for predicting, mapping
            :class:`~flash.core.data.process.Preprocess` hook names to callables.
        data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` handed to the
            :class:`~flash.core.data.data_module.DataModule`.
        preprocess: The :class:`~flash.core.data.data.Preprocess` handed to the
            :class:`~flash.core.data.data_module.DataModule`. When ``None``,
            ``cls.preprocess_cls`` is constructed and used instead.
        val_split: Forwarded as ``val_split`` to the
            :class:`~flash.core.data.data_module.DataModule`.
        batch_size: Forwarded as ``batch_size`` to the
            :class:`~flash.core.data.data_module.DataModule`.
        num_workers: Forwarded as ``num_workers`` to the
            :class:`~flash.core.data.data_module.DataModule`.
        sampler: The ``sampler`` to use for the ``train_dataloader``.
        preprocess_kwargs: Extra keyword arguments for constructing the preprocess; only used
            when ``preprocess = None``.

    Returns:
        The constructed data module.
    """
    # The "list" data source expects each labelled split as an (inputs, targets) pair.
    train_pair = (train_data, train_targets)
    val_pair = (val_data, val_targets)
    test_pair = (test_data, test_targets)

    return cls.from_data_source(
        "list",
        train_pair,
        val_pair,
        test_pair,
        predict_data,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        preprocess=preprocess,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler=sampler,
        **preprocess_kwargs,
    )
17 changes: 17 additions & 0 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TextDataSource,
TextFileDataSource,
TextJSONDataSource,
TextListDataSource,
TextSentencesDataSource,
)
from tests.helpers.utils import _TEXT_TESTING
Expand Down Expand Up @@ -58,6 +59,10 @@
)


# In-memory fixtures for the list-based tests: three sentences and one integer target per sentence.
TEST_LIST_DATA = ["this is a sentence one", "this is a sentence two", "this is a sentence three"]
TEST_LIST_TARGETS = [0, 1, 0]


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
path.write_text(TEST_CSV_DATA)
Expand Down Expand Up @@ -141,6 +146,17 @@ def test_from_data_frame():
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_list():
    """A data module built with ``from_list`` yields tokenized train batches with a known label."""
    datamodule = TextClassificationData.from_list(
        train_data=TEST_LIST_DATA,
        train_targets=TEST_LIST_TARGETS,
        backbone=TEST_BACKBONE,
        batch_size=1,
    )
    train_loader = datamodule.train_dataloader()
    first_batch = next(iter(train_loader))
    # With batch_size=1 the labels tensor holds a single class id.
    assert first_batch["labels"].item() in [0, 1]
    assert "input_ids" in first_batch


@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
Expand All @@ -157,6 +173,7 @@ def test_text_module_not_found_error():
(TextCSVDataSource, {}),
(TextJSONDataSource, {}),
(TextDataFrameDataSource, {}),
(TextListDataSource, {}),
(TextSentencesDataSource, {}),
],
)
Expand Down