Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix text classification data loading #888

Merged
merged 3 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))

+- Fixed a bug where loading text data with additional non-numeric columns (not input or target) would give an error ([#888](https://github.com/PyTorchLightning/lightning-flash/pull/888))
+

## [0.5.0] - 2021-09-07

Expand Down
7 changes: 5 additions & 2 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,14 @@ def load_data(
 # rename label column
 hf_dataset = hf_dataset.rename_column(target, DefaultDataKeys.TARGET)
 
+# remove extra columns
+extra_columns = set(hf_dataset.column_names) - {input, DefaultDataKeys.TARGET}
+hf_dataset = hf_dataset.remove_columns(extra_columns)
+
 # tokenize
-hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=input), batched=True)
+hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=input), batched=True, remove_columns=[input])
 
 # set format
-hf_dataset = hf_dataset.remove_columns([input]) # just leave the numerical columns
 hf_dataset.set_format("torch")
 
 return hf_dataset
Expand Down
3 changes: 1 addition & 2 deletions tests/text/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def test_load_from_checkpoint_dependency_error():
"cli_args",
(
["flash", "text_classification", "--trainer.fast_dev_run", "True"],
-# TODO: update this to work with Pietro's new text data loading (separate PR)
-# ["flash", "text_classification", "--trainer.fast_dev_run", "True", "from_toxic"],
+["flash", "text_classification", "--trainer.fast_dev_run", "True", "from_toxic"],
),
)
def test_cli(cli_args):
Expand Down