This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Rename valid_ to val_ #197

Merged: 2 commits, Mar 30, 2021
12 changes: 6 additions & 6 deletions README.md
@@ -129,7 +129,7 @@ download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

@@ -205,11 +205,11 @@ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')

# 2. Load the data
datamodule = SummarizationData.from_files(
train_file="data/xsum/train.csv",
valid_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
train_file="data/xsum/train.csv",
val_file="data/xsum/valid.csv",
test_file="data/xsum/test.csv",
input="input",
target="target"
)

# 3. Build the model
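Because this is a pure keyword rename, callers that still pass the old `valid_*` keywords will no longer be recognized after upgrading. A minimal sketch of the updated image-classification call from the README above (the `flash.vision` import path is assumed for this Flash version):

```python
from flash.vision import ImageClassificationData  # import path assumed for this Flash version

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",    # renamed from valid_folder in this PR
    test_folder="data/hymenoptera_data/test/",
)
```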
46 changes: 23 additions & 23 deletions flash/data/data_module.py
@@ -32,7 +32,7 @@ class DataModule(pl.LightningDataModule):

Args:
train_dataset: Dataset for training. Defaults to None.
valid_dataset: Dataset for validating model performance during training. Defaults to None.
val_dataset: Dataset for validating model performance during training. Defaults to None.
test_dataset: Dataset to test model performance. Defaults to None.
predict_dataset: Dataset to predict model performance. Defaults to None.
num_workers: The number of workers to use for parallelized loading. Defaults to None.
@@ -49,7 +49,7 @@ class DataModule(pl.LightningDataModule):
def __init__(
self,
train_dataset: Optional[Dataset] = None,
valid_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
predict_dataset: Optional[Dataset] = None,
batch_size: int = 1,
@@ -58,14 +58,14 @@ def __init__(

super().__init__()
self._train_ds = train_dataset
self._valid_ds = valid_dataset
self._val_ds = val_dataset
self._test_ds = test_dataset
self._predict_ds = predict_dataset

if self._train_ds:
self.train_dataloader = self._train_dataloader

if self._valid_ds:
if self._val_ds:
self.val_dataloader = self._val_dataloader

if self._test_ds:
@@ -104,8 +104,8 @@ def set_running_stages(self):
if self._train_ds:
self.set_dataset_attribute(self._train_ds, 'running_stage', RunningStage.TRAINING)

if self._valid_ds:
self.set_dataset_attribute(self._valid_ds, 'running_stage', RunningStage.VALIDATING)
if self._val_ds:
self.set_dataset_attribute(self._val_ds, 'running_stage', RunningStage.VALIDATING)

if self._test_ds:
self.set_dataset_attribute(self._test_ds, 'running_stage', RunningStage.TESTING)
@@ -130,13 +130,13 @@ def _train_dataloader(self) -> DataLoader:
)

def _val_dataloader(self) -> DataLoader:
valid_ds: Dataset = self._valid_ds() if isinstance(self._valid_ds, Callable) else self._valid_ds
val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
return DataLoader(
valid_ds,
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self._resolve_collate_fn(valid_ds, RunningStage.VALIDATING)
collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
)

def _test_dataloader(self) -> DataLoader:
@@ -214,10 +214,10 @@ def autogenerate_dataset(
return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)

@staticmethod
def train_valid_test_split(
def train_val_test_split(
dataset: torch.utils.data.Dataset,
train_split: Optional[Union[float, int]] = None,
valid_split: Optional[Union[float, int]] = None,
val_split: Optional[Union[float, int]] = None,
test_split: Optional[Union[float, int]] = None,
seed: Optional[int] = 1234,
) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
@@ -227,11 +227,11 @@ def train_valid_test_split(
dataset: Dataset to be split.
train_split: If Float, ratio of data to be contained within the train dataset. If Int,
number of samples to be contained within train dataset.
valid_split: If Float, ratio of data to be contained within the validation dataset. If Int,
val_split: If Float, ratio of data to be contained within the validation dataset. If Int,
number of samples to be contained within test dataset.
test_split: If Float, ratio of data to be contained within the test dataset. If Int,
number of samples to be contained within test dataset.
seed: Used for the train/val splits when valid_split is not None.
seed: Used for the train/val splits when val_split is not None.

"""
n = len(dataset)
@@ -243,12 +243,12 @@ def train_valid_test_split(
else:
_test_length = test_split

if valid_split is None:
if val_split is None:
_val_length = 0
elif isinstance(valid_split, float):
_val_length = int(n * valid_split)
elif isinstance(val_split, float):
_val_length = int(n * val_split)
else:
_val_length = valid_split
_val_length = val_split

if train_split is None:
_train_length = n - _val_length - _test_length
@@ -265,7 +265,7 @@ def train_valid_test_split(
train_ds, val_ds, test_ds = torch.utils.data.random_split(
dataset, [_train_length, _val_length, _test_length], generator
)
if valid_split is None:
if val_split is None:
val_ds = None
if test_split is None:
test_ds = None
@@ -293,7 +293,7 @@ def _generate_dataset_if_possible(
def from_load_data_inputs(
cls,
train_load_data_input: Optional[Any] = None,
valid_load_data_input: Optional[Any] = None,
val_load_data_input: Optional[Any] = None,
test_load_data_input: Optional[Any] = None,
predict_load_data_input: Optional[Any] = None,
preprocess: Optional[Preprocess] = None,
@@ -306,7 +306,7 @@ def from_load_data_inputs(
Args:
cls: ``DataModule`` subclass
train_load_data_input: Data to be received by the ``train_load_data`` function from this ``Preprocess``
valid_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
val_load_data_input: Data to be received by the ``val_load_data`` function from this ``Preprocess``
test_load_data_input: Data to be received by the ``test_load_data`` function from this ``Preprocess``
predict_load_data_input: Data to be received by the ``predict_load_data`` function from this ``Preprocess``
kwargs: Any extra arguments to instantiate the provided ``DataModule``
@@ -322,8 +322,8 @@ def from_load_data_inputs(
train_dataset = cls._generate_dataset_if_possible(
train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
)
valid_dataset = cls._generate_dataset_if_possible(
valid_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
val_dataset = cls._generate_dataset_if_possible(
val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
)
test_dataset = cls._generate_dataset_if_possible(
test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline
Expand All @@ -333,7 +333,7 @@ def from_load_data_inputs(
)
datamodule = cls(
train_dataset=train_dataset,
valid_dataset=valid_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
predict_dataset=predict_dataset,
**kwargs
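The rename also covers the `DataModule` constructor and the static split helper. A minimal sketch of the renamed API, assuming the `flash.data.data_module` import path matches this file layout and leaving `train_split` unset so the remaining samples go to training, as in the code above:

```python
import torch
from torch.utils.data import TensorDataset

from flash.data.data_module import DataModule  # import path assumed from this file layout

# Toy dataset purely for illustration.
dataset = TensorDataset(torch.randn(100, 16), torch.randint(0, 2, (100,)))

# train_val_test_split / val_split were train_valid_test_split / valid_split before this PR.
train_ds, val_ds, test_ds = DataModule.train_val_test_split(
    dataset, val_split=0.1, test_split=0.1, seed=1234
)

datamodule = DataModule(
    train_dataset=train_ds,
    val_dataset=val_ds,    # renamed from valid_dataset
    test_dataset=test_ds,
    batch_size=4,
)
```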
4 changes: 2 additions & 2 deletions flash/data/process.py
@@ -87,13 +87,13 @@ class Preprocess(Properties, torch.nn.Module):
def __init__(
self,
train_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
valid_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
predict_transform: Optional[Union[Callable, Module, Dict[str, Callable]]] = None,
):
super().__init__()
self.train_transform = convert_to_modules(train_transform)
self.valid_transform = convert_to_modules(valid_transform)
self.val_transform = convert_to_modules(val_transform)
self.test_transform = convert_to_modules(test_transform)
self.predict_transform = convert_to_modules(predict_transform)

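`Preprocess` gets the same treatment for its per-stage transforms. A hedged sketch, assuming `Preprocess` can be instantiated directly here and using torchvision transforms only as placeholder callables:

```python
from torchvision import transforms as T

from flash.data.process import Preprocess  # import path assumed from this file layout

preprocess = Preprocess(
    train_transform=T.Compose([T.RandomHorizontalFlip(), T.ToTensor()]),
    val_transform=T.ToTensor(),    # renamed from valid_transform
    test_transform=T.ToTensor(),
)
```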
34 changes: 17 additions & 17 deletions flash/tabular/classification/data/data.py
@@ -83,7 +83,7 @@ def state(self) -> TabularState:
@staticmethod
def generate_state(
train_df: DataFrame,
valid_df: Optional[DataFrame],
val_df: Optional[DataFrame],
test_df: Optional[DataFrame],
predict_df: Optional[DataFrame],
target_col: str,
@@ -100,8 +100,8 @@ def generate_state(

dfs = [train_df]

if valid_df is not None:
dfs += [valid_df]
if val_df is not None:
dfs += [val_df]

if test_df is not None:
dfs += [test_df]
@@ -197,7 +197,7 @@ def from_csv(
train_csv: Optional[str] = None,
categorical_cols: Optional[List] = None,
numerical_cols: Optional[List] = None,
valid_csv: Optional[str] = None,
val_csv: Optional[str] = None,
test_csv: Optional[str] = None,
predict_csv: Optional[str] = None,
batch_size: int = 8,
@@ -215,7 +215,7 @@ def from_csv(
target_col: The column containing the class id.
categorical_cols: The list of categorical columns.
numerical_cols: The list of numerical columns.
valid_csv: Validation data csv file.
val_csv: Validation data csv file.
test_csv: Test data csv file.
batch_size: The batchsize to use for parallel loading. Defaults to 64.
num_workers: The number of workers to use for parallelized loading.
@@ -234,7 +234,7 @@ def from_csv(
text_data = TabularData.from_files("train.csv", label_field="class", text_field="sentence")
"""
train_df = pd.read_csv(train_csv, **pandas_kwargs)
valid_df = pd.read_csv(valid_csv, **pandas_kwargs) if valid_csv else None
val_df = pd.read_csv(val_csv, **pandas_kwargs) if val_csv else None
test_df = pd.read_csv(test_csv, **pandas_kwargs) if test_csv else None
predict_df = pd.read_csv(predict_csv, **pandas_kwargs) if predict_csv else None

@@ -243,7 +243,7 @@ def from_csv(
target_col,
categorical_cols,
numerical_cols,
valid_df,
val_df,
test_df,
predict_df,
batch_size,
@@ -268,21 +268,21 @@ def emb_sizes(self) -> list:
@staticmethod
def _split_dataframe(
train_df: DataFrame,
valid_df: Optional[DataFrame] = None,
val_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
val_size: float = None,
test_size: float = None,
):
if valid_df is None and isinstance(val_size, float) and isinstance(test_size, float):
if val_df is None and isinstance(val_size, float) and isinstance(test_size, float):
assert 0 < val_size < 1
assert 0 < test_size < 1
train_df, valid_df = train_test_split(train_df, test_size=(val_size + test_size))
train_df, val_df = train_test_split(train_df, test_size=(val_size + test_size))

if test_df is None and isinstance(test_size, float):
assert 0 < test_size < 1
valid_df, test_df = train_test_split(valid_df, test_size=test_size)
val_df, test_df = train_test_split(val_df, test_size=test_size)

return train_df, valid_df, test_df
return train_df, val_df, test_df

@staticmethod
def _sanetize_cols(cat_cols: Optional[List], num_cols: Optional[List]):
@@ -298,7 +298,7 @@ def from_df(
target_col: str,
categorical_cols: Optional[List] = None,
numerical_cols: Optional[List] = None,
valid_df: Optional[DataFrame] = None,
val_df: Optional[DataFrame] = None,
test_df: Optional[DataFrame] = None,
predict_df: Optional[DataFrame] = None,
batch_size: int = 8,
@@ -316,7 +316,7 @@ def from_df(
target_col: The column containing the class id.
categorical_cols: The list of categorical columns.
numerical_cols: The list of numerical columns.
valid_df: Validation data DataFrame.
val_df: Validation data DataFrame.
test_df: Test data DataFrame.
batch_size: The batchsize to use for parallel loading. Defaults to 64.
num_workers: The number of workers to use for parallelized loading.
@@ -334,13 +334,13 @@ def from_df(
"""
categorical_cols, numerical_cols = cls._sanetize_cols(categorical_cols, numerical_cols)

train_df, valid_df, test_df = cls._split_dataframe(train_df, valid_df, test_df, val_size, test_size)
train_df, val_df, test_df = cls._split_dataframe(train_df, val_df, test_df, val_size, test_size)

preprocess_cls = preprocess_cls or cls.preprocess_cls

preprocess_state = preprocess_cls.generate_state(
train_df,
valid_df,
val_df,
test_df,
predict_df,
target_col,
@@ -353,7 +353,7 @@ def from_df(

return cls.from_load_data_inputs(
train_load_data_input=train_df,
valid_load_data_input=valid_df,
val_load_data_input=val_df,
test_load_data_input=test_df,
predict_load_data_input=predict_df,
batch_size=batch_size,
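On the tabular side, both entry points rename their validation arguments (`valid_csv` to `val_csv`, `valid_df` to `val_df`). A sketch with placeholder file names and columns; the import path and argument order are assumptions, so everything is passed by keyword:

```python
from flash.tabular import TabularData  # import path assumed for this Flash version

datamodule = TabularData.from_csv(
    target_col="label",               # placeholder column names
    train_csv="train.csv",
    categorical_cols=["city"],
    numerical_cols=["age", "income"],
    val_csv="val.csv",                # renamed from valid_csv
    test_csv="test.csv",
    batch_size=8,
)
```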
8 changes: 4 additions & 4 deletions flash/text/classification/data.py
@@ -237,7 +237,7 @@ def from_files(
target: Optional[str] = 'labels',
filetype: str = "csv",
backbone: str = "prajjwal1/bert-tiny",
valid_file: Optional[str] = None,
val_file: Optional[str] = None,
test_file: Optional[str] = None,
predict_file: Optional[str] = None,
max_length: int = 128,
@@ -255,7 +255,7 @@ def from_files(
target: The field storing the class id of the associated text.
filetype: .csv or .json
backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer.
valid_file: Path to validation data.
val_file: Path to validation data.
test_file: Path to test data.
batch_size: the batchsize to use for parallel loading. Defaults to 64.
num_workers: The number of workers to use for parallelized loading.
@@ -287,7 +287,7 @@ def from_files(

return cls.from_load_data_inputs(
train_load_data_input=train_file,
valid_load_data_input=valid_file,
val_load_data_input=val_file,
test_load_data_input=test_file,
predict_load_data_input=predict_file,
batch_size=batch_size,
@@ -327,7 +327,7 @@ def from_file(
target=None,
filetype=filetype,
backbone=backbone,
valid_file=None,
val_file=None,
test_file=None,
predict_file=predict_file,
max_length=max_length,
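Text classification keeps the same `from_files` signature apart from the renamed keyword. A sketch with placeholder file names (import path assumed for this Flash version):

```python
from flash.text import TextClassificationData  # import path assumed for this Flash version

datamodule = TextClassificationData.from_files(
    train_file="train.csv",
    target="labels",
    backbone="prajjwal1/bert-tiny",
    val_file="val.csv",    # renamed from valid_file
    test_file="test.csv",
)
```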
6 changes: 3 additions & 3 deletions flash/text/seq2seq/core/data.py
@@ -168,7 +168,7 @@ def from_files(
target: Optional[str] = None,
filetype: str = "csv",
backbone: str = "sshleifer/tiny-mbart",
valid_file: Optional[str] = None,
val_file: Optional[str] = None,
test_file: Optional[str] = None,
predict_file: Optional[str] = None,
max_source_length: int = 128,
@@ -185,7 +185,7 @@ def from_files(
target: The field storing the target translation text.
filetype: ``csv`` or ``json`` File
backbone: Tokenizer backbone to use, can use any HuggingFace tokenizer.
valid_file: Path to validation data.
val_file: Path to validation data.
test_file: Path to test data.
max_source_length: Maximum length of the source text. Any text longer will be truncated.
max_target_length: Maximum length of the target text. Any text longer will be truncated.
@@ -217,7 +217,7 @@ def from_files(

return cls.from_load_data_inputs(
train_load_data_input=train_file,
valid_load_data_input=valid_file,
val_load_data_input=val_file,
test_load_data_input=test_file,
predict_load_data_input=predict_file,
batch_size=batch_size,
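The seq2seq core data module is shared by the seq2seq tasks (for example summarization), so after the rename the README's summarization example above becomes:

```python
from flash.text import SummarizationData  # import path assumed for this Flash version

datamodule = SummarizationData.from_files(
    train_file="data/xsum/train.csv",
    val_file="data/xsum/valid.csv",    # keyword renamed; the file on disk can keep any name
    test_file="data/xsum/test.csv",
    input="input",
    target="target",
)
```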