Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into feature/769-text-from-lists
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Sep 29, 2021
2 parents 5101270 + 0a28672 commit 4dd3e91
Show file tree
Hide file tree
Showing 14 changed files with 963 additions and 19 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
/.github/*.md @edenlightning @ethanwharris @ananyahjha93
/.github/ISSUE_TEMPLATE/*.md @edenlightning @ethanwharris @ananyahjha93
/docs/source/conf.py @borda @ethanwharris @ananyahjha93
/flash/core/integrations/labelstudio @KonstantinKorotaev @niklub
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `LabelStudio` integration ([#554](https://github.com/PyTorchLightning/lightning-flash/pull/554))

- Added support `learn2learn` training_strategy for `ImageClassifier` ([#737](https://github.com/PyTorchLightning/lightning-flash/pull/737))

- Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682))
Expand Down
139 changes: 138 additions & 1 deletion flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _train_dataloader(self) -> DataLoader:
else:
drop_last = len(train_ds) > self.batch_size
pin_memory = True
persistent_workers = self.num_workers > 0

if self.sampler is None:
sampler = None
Expand Down Expand Up @@ -317,12 +318,14 @@ def _train_dataloader(self) -> DataLoader:
pin_memory=pin_memory,
drop_last=drop_last,
collate_fn=collate_fn,
persistent_workers=persistent_workers,
)

def _val_dataloader(self) -> DataLoader:
val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
pin_memory = True
persistent_workers = self.num_workers > 0

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_val_dataset(
Expand All @@ -340,12 +343,14 @@ def _val_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn,
persistent_workers=persistent_workers,
)

def _test_dataloader(self) -> DataLoader:
test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds
collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING)
pin_memory = True
persistent_workers = False

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
Expand All @@ -363,6 +368,7 @@ def _test_dataloader(self) -> DataLoader:
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn,
persistent_workers=persistent_workers,
)

def _predict_dataloader(self) -> DataLoader:
Expand All @@ -375,6 +381,7 @@ def _predict_dataloader(self) -> DataLoader:

collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
pin_memory = True
persistent_workers = False

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_predict_dataset(
Expand All @@ -386,7 +393,12 @@ def _predict_dataloader(self) -> DataLoader:
)

return DataLoader(
predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=collate_fn
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=collate_fn,
persistent_workers=persistent_workers,
)

@property
Expand Down Expand Up @@ -1234,3 +1246,128 @@ def from_fiftyone(
num_workers=num_workers,
**preprocess_kwargs,
)

@classmethod
def from_labelstudio(
cls,
export_json: str = None,
train_export_json: str = None,
val_export_json: str = None,
test_export_json: str = None,
predict_export_json: str = None,
data_folder: str = None,
train_data_folder: str = None,
val_data_folder: str = None,
test_data_folder: str = None,
predict_data_folder: str = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object
from the given export file and data directory using the
:class:`~flash.core.data.data_source.DataSource` of name
:attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS`
from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
Args:
export_json: path to label studio export file
train_export_json: path to label studio export file for train set,
overrides export_json if specified
val_export_json: path to label studio export file for validation
test_export_json: path to label studio export file for test
predict_export_json: path to label studio export file for predict
data_folder: path to label studio data folder
train_data_folder: path to label studio data folder for train data set,
overrides data_folder if specified
val_data_folder: path to label studio data folder for validation data
test_data_folder: path to label studio data folder for test data
predict_data_folder: path to label studio data folder for predict data
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
Examples::
data_module = DataModule.from_labelstudio(
export_json='project.json',
data_folder='label-studio/media/upload',
val_split=0.8,
)
"""
data = {
"data_folder": data_folder,
"export_json": export_json,
"split": val_split,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
train_data = None
val_data = None
test_data = None
predict_data = None
if (train_data_folder or data_folder) and train_export_json:
train_data = {
"data_folder": train_data_folder or data_folder,
"export_json": train_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
if (val_data_folder or data_folder) and val_export_json:
val_data = {
"data_folder": val_data_folder or data_folder,
"export_json": val_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
if (test_data_folder or data_folder) and test_export_json:
test_data = {
"data_folder": test_data_folder or data_folder,
"export_json": test_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
if (predict_data_folder or data_folder) and predict_export_json:
predict_data = {
"data_folder": predict_data_folder or data_folder,
"export_json": predict_export_json,
"multi_label": preprocess_kwargs.get("multi_label", False),
}
return cls.from_data_source(
DefaultDataSources.LABELSTUDIO,
train_data=train_data if train_data else data,
val_data=val_data,
test_data=test_data,
predict_data=predict_data,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**preprocess_kwargs,
)
1 change: 1 addition & 0 deletions flash/core/data/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class DefaultDataSources(LightningEnum):
DATAFRAME = "data_frame"
LISTS = "lists"
SENTENCES = "sentences"
LABELSTUDIO = "labelstudio"

# TODO: Create a FlashEnum class???
def __hash__(self) -> int:
Expand Down
Loading

0 comments on commit 4dd3e91

Please sign in to comment.