diff --git a/CHANGELOG.md b/CHANGELOG.md index be555b48c9..97085839cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587)) +- Added support for `field` parameter for loading JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 0cdfc99ed3..5831c84a68 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -889,6 +889,7 @@ def from_json( batch_size: int = 4, num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, + field: Optional[str] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the @@ -920,6 +921,7 @@ def from_json( batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + field: To specify the field that holds the data in the JSON file. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. 
@@ -936,13 +938,35 @@ def from_json( "to_tensor_transform": torch.as_tensor, }, ) + + # In the case where the data is of the form: + # { + # "version": 0.0.x, + # "data": [ + # { + # "input_field" : "input_data", + # "target_field" : "target_output" + # }, + # ... + # ] + # } + + data_module = DataModule.from_json( + "input", + "target", + train_file="train_data.json", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + field="data" + ) """ return cls.from_data_source( DefaultDataSources.JSON, - (train_file, input_fields, target_fields), - (val_file, input_fields, target_fields), - (test_file, input_fields, target_fields), - (predict_file, input_fields, target_fields), + (train_file, input_fields, target_fields, field), + (val_file, input_fields, target_fields, field), + (test_file, input_fields, target_fields, field), + (predict_file, input_fields, target_fields, field), train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index b6cb4672f1..d8039dcbc4 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -110,7 +110,10 @@ def load_data( dataset: Optional[Any] = None, columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), ) -> Union[Sequence[Mapping[str, Any]]]: - file, input, target = data + if self.filetype == 'json': + file, input, target, field = data + else: + file, input, target = data data_files = {} @@ -120,13 +123,25 @@ def load_data( # FLASH_TESTING is set in the CI to run faster. 
if flash._IS_TESTING and not torch.cuda.is_available(): try: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + if self.filetype == 'json' and field is not None: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'], + field=field)[0] + }) + else: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] + }) except Exception: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) if not self.predicting: if isinstance(target, List): diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 4ebb537dbe..decb43fc53 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -98,7 +98,10 @@ def __init__( def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': if columns is None: columns = ["input_ids", "attention_mask", "labels"] - file, input, target = data + if self.filetype == 'json': + file, input, target, field = data + else: + file, input, target = data data_files = {} stage = self._running_stage.value data_files[stage] = str(file) @@ -106,13 +109,25 @@ def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': # FLASH_TESTING is set in the CI to run faster. 
if flash._IS_TESTING: try: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + if self.filetype == 'json' and field is not None: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'], + field=field)[0] + }) + else: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] + }) except Exception: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True) dataset_dict.set_format(columns=columns) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index d5a3b680f9..b92c3757cc 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -44,6 +44,12 @@ {"sentence": "this is a sentence three","lab":0} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"sentence": "this is a sentence one","lab":0}, +{"sentence": "this is a sentence two","lab":1}, +{"sentence": "this is a sentence three","lab":0}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -57,6 +63,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") 
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -99,6 +111,18 @@ def test_from_json(tmpdir): assert "input_ids" in batch +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = TextClassificationData.from_json( + "sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert batch["labels"].item() in [0, 1] + assert "input_ids" in batch + + @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") def test_text_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[text]"): diff --git a/tests/text/seq2seq/question_answering/test_data.py b/tests/text/seq2seq/question_answering/test_data.py index 2db170464e..83f7824e57 100644 --- a/tests/text/seq2seq/question_answering/test_data.py +++ b/tests/text/seq2seq/question_answering/test_data.py @@ -33,6 +33,12 @@ {"input": "this is a question three","target":"this is an answer three"} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a question one","target":"this is an answer one"}, +{"input": "this is a question two","target":"this is an answer two"}, +{"input": "this is a question three","target":"this is an answer three"}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -106,3 +118,15 @@ def test_from_json(tmpdir): batch = 
next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = QuestionAnsweringData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 2ab09f3636..a1120854ea 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -22,15 +22,21 @@ TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing TEST_CSV_DATA = """input,target -this is a sentence one,this is a translated sentence one -this is a sentence two,this is a translated sentence two -this is a sentence three,this is a translated sentence three +this is a sentence one,this is a summarized sentence one +this is a sentence two,this is a summarized sentence two +this is a sentence three,this is a summarized sentence three """ TEST_JSON_DATA = """ -{"input": "this is a sentence one","target":"this is a translated sentence one"} -{"input": "this is a sentence two","target":"this is a translated sentence two"} -{"input": "this is a sentence three","target":"this is a translated sentence three"} +{"input": "this is a sentence one","target":"this is a summarized sentence one"} +{"input": "this is a sentence two","target":"this is a summarized sentence two"} +{"input": "this is a sentence three","target":"this is a summarized sentence three"} +""" + +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a sentence one","target":"this is a summarized sentence one"}, +{"input": "this is a sentence 
two","target":"this is a summarized sentence two"}, +{"input": "this is a sentence three","target":"this is a summarized sentence three"}]} """ @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -106,3 +118,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = SummarizationData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py index 244cb27d4a..27162491a0 100644 --- a/tests/text/seq2seq/translation/test_data.py +++ b/tests/text/seq2seq/translation/test_data.py @@ -33,6 +33,12 @@ {"input": "this is a sentence three","target":"this is a translated sentence three"} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a sentence one","target":"this is a translated sentence one"}, +{"input": "this is a sentence two","target":"this is a translated sentence two"}, +{"input": "this is a sentence three","target":"this is a translated sentence three"}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + 
path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -86,3 +98,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = TranslationData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch