Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Added field parameter to the from_json method with other required cha… (
Browse files Browse the repository at this point in the history
#585)

* Added field parameter to the from_json method with other required changes.

* Updating field parameter type and CHANGELOG

* Added docs for the new parameter

* Add some tests

* Update flash/core/data/data_module.py

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
3 people authored Jul 14, 2021
1 parent f6e0d20 commit 87df19a
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 22 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587))

- Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
32 changes: 28 additions & 4 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ def from_json(
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
field: Optional[str] = None,
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the
Expand Down Expand Up @@ -920,6 +921,7 @@ def from_json(
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
field: To specify the field that holds the data in the JSON file.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Expand All @@ -936,13 +938,35 @@ def from_json(
"to_tensor_transform": torch.as_tensor,
},
)
# In the case where the data is of the form:
# {
# "version": 0.0.x,
# "data": [
# {
# "input_field" : "input_data",
# "target_field" : "target_output"
# },
# ...
# ]
# }
data_module = DataModule.from_json(
"input",
"target",
train_file="train_data.json",
train_transform={
"to_tensor_transform": torch.as_tensor,
},
feild="data"
)
"""
return cls.from_data_source(
DefaultDataSources.JSON,
(train_file, input_fields, target_fields),
(val_file, input_fields, target_fields),
(test_file, input_fields, target_fields),
(predict_file, input_fields, target_fields),
(train_file, input_fields, target_fields, field),
(val_file, input_fields, target_fields, field),
(test_file, input_fields, target_fields, field),
(predict_file, input_fields, target_fields, field),
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
Expand Down
27 changes: 21 additions & 6 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def load_data(
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
file, input, target = data
if self.filetype == 'json':
file, input, target, field = data
else:
file, input, target = data

data_files = {}

Expand All @@ -120,13 +123,25 @@ def load_data(
# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING and not torch.cuda.is_available():
try:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
if self.filetype == 'json' and field is not None:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'],
field=field)[0]
})
else:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
except Exception:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)

if not self.predicting:
if isinstance(target, List):
Expand Down
27 changes: 21 additions & 6 deletions flash/text/seq2seq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,36 @@ def __init__(
def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset':
if columns is None:
columns = ["input_ids", "attention_mask", "labels"]
file, input, target = data
if self.filetype == 'json':
file, input, target, field = data
else:
file, input, target = data
data_files = {}
stage = self._running_stage.value
data_files[stage] = str(file)

# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING:
try:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
if self.filetype == 'json' and field is not None:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'],
field=field)[0]
})
else:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
except Exception:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)

dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True)
dataset_dict.set_format(columns=columns)
Expand Down
24 changes: 24 additions & 0 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
{"sentence": "this is a sentence three","lab":0}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"sentence": "this is a sentence one","lab":0},
{"sentence": "this is a sentence two","lab":1},
{"sentence": "this is a sentence three","lab":0}]}
"""


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
Expand All @@ -57,6 +63,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -99,6 +111,18 @@ def test_from_json(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = TextClassificationData.from_json(
"sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch


@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
Expand Down
24 changes: 24 additions & 0 deletions tests/text/seq2seq/question_answering/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
{"input": "this is a question three","target":"this is an answer three"}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"input": "this is a question one","target":"this is an answer one"},
{"input": "this is a question two","target":"this is an answer two"},
{"input": "this is a question three","target":"this is an answer three"}]}
"""


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
Expand All @@ -46,6 +52,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -106,3 +118,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = QuestionAnsweringData.from_json(
"input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch
36 changes: 30 additions & 6 deletions tests/text/seq2seq/summarization/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,21 @@
TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing

TEST_CSV_DATA = """input,target
this is a sentence one,this is a translated sentence one
this is a sentence two,this is a translated sentence two
this is a sentence three,this is a translated sentence three
this is a sentence one,this is a summarized sentence one
this is a sentence two,this is a summarized sentence two
this is a sentence three,this is a summarized sentence three
"""

TEST_JSON_DATA = """
{"input": "this is a sentence one","target":"this is a translated sentence one"}
{"input": "this is a sentence two","target":"this is a translated sentence two"}
{"input": "this is a sentence three","target":"this is a translated sentence three"}
{"input": "this is a sentence one","target":"this is a summarized sentence one"}
{"input": "this is a sentence two","target":"this is a summarized sentence two"}
{"input": "this is a sentence three","target":"this is a summarized sentence three"}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"input": "this is a sentence one","target":"this is a summarized sentence one"},
{"input": "this is a sentence two","target":"this is a summarized sentence two"},
{"input": "this is a sentence three","target":"this is a summarized sentence three"}]}
"""


Expand All @@ -46,6 +52,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -106,3 +118,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = SummarizationData.from_json(
"input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch
24 changes: 24 additions & 0 deletions tests/text/seq2seq/translation/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
{"input": "this is a sentence three","target":"this is a translated sentence three"}
"""

TEST_JSON_DATA_FIELD = """{"data": [
{"input": "this is a sentence one","target":"this is a translated sentence one"},
{"input": "this is a sentence two","target":"this is a translated sentence two"},
{"input": "this is a sentence three","target":"this is a translated sentence three"}]}
"""


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
Expand All @@ -46,6 +52,12 @@ def json_data(tmpdir):
return path


def json_data_with_field(tmpdir):
path = Path(tmpdir) / "data.json"
path.write_text(TEST_JSON_DATA_FIELD)
return path


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
Expand Down Expand Up @@ -86,3 +98,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_json_with_field(tmpdir):
json_path = json_data_with_field(tmpdir)
dm = TranslationData.from_json(
"input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
)
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch

0 comments on commit 87df19a

Please sign in to comment.