Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Added field parameter to the from_json method with other required changes #585

9 changes: 5 additions & 4 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ def from_json(
batch_size: int = 4,
num_workers: Optional[int] = None,
sampler: Optional[Sampler] = None,
field: str = None,
karthikrangasai marked this conversation as resolved.
Show resolved Hide resolved
**preprocess_kwargs: Any,
) -> 'DataModule':
"""Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the
Expand Down Expand Up @@ -889,10 +890,10 @@ def from_json(
"""
return cls.from_data_source(
DefaultDataSources.JSON,
(train_file, input_fields, target_fields),
(val_file, input_fields, target_fields),
(test_file, input_fields, target_fields),
(predict_file, input_fields, target_fields),
(train_file, input_fields, target_fields, field),
(val_file, input_fields, target_fields, field),
(test_file, input_fields, target_fields, field),
(predict_file, input_fields, target_fields, field),
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
Expand Down
27 changes: 21 additions & 6 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def load_data(
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
file, input, target = data
if self.filetype == 'json':
file, input, target, field = data
else:
file, input, target = data

data_files = {}

Expand All @@ -120,13 +123,25 @@ def load_data(
# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING and not torch.cuda.is_available():
try:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
if self.filetype == 'json' and field is not None:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'],
field=field)[0]
})
else:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
except Exception:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)

if not self.predicting:
if isinstance(target, List):
Expand Down
27 changes: 21 additions & 6 deletions flash/text/seq2seq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,21 +98,36 @@ def __init__(
def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset':
if columns is None:
columns = ["input_ids", "attention_mask", "labels"]
file, input, target = data
if self.filetype == 'json':
file, input, target, field = data
else:
file, input, target = data
data_files = {}
stage = self._running_stage.value
data_files[stage] = str(file)

# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING:
try:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
if self.filetype == 'json' and field is not None:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'],
field=field)[0]
})
else:
dataset_dict = DatasetDict({
stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
})
except Exception:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)
if self.filetype == 'json' and field is not None:
dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)

dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True)
dataset_dict.set_format(columns=columns)
Expand Down