Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Speed up and fix graph tests (#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 14, 2021
1 parent d963de7 commit 3b88e74
Show file tree
Hide file tree
Showing 31 changed files with 61 additions and 76 deletions.
11 changes: 9 additions & 2 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ jobs:
run: |
pip install git+https://github.com/facebookresearch/vissl.git@master
- name: Install graph test dependencies
if: matrix.topic[0] == 'graph'
run: |
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install torch-sparse -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
- name: Install dependencies
run: |
python --version
Expand Down Expand Up @@ -166,8 +173,8 @@ jobs:
uses: actions/cache@v2
with:
path: data # This path is specific to Ubuntu
key: lightning-flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }}
restore-keys: lightning-flash-datasets-
key: flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }}
restore-keys: flash-datasets-

- name: Tests
env:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))

### Fixed

## [0.5.0] - 2021-09-07
Expand Down
3 changes: 1 addition & 2 deletions flash/audio/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
Expand All @@ -23,7 +22,7 @@

def from_urban8k(
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> AudioClassificationData:
"""Downloads and loads the Urban 8k sounds images data set."""
Expand Down
3 changes: 1 addition & 2 deletions flash/audio/speech_recognition/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data
Expand All @@ -23,7 +22,7 @@
def from_timit(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> SpeechRecognitionData:
"""Downloads and loads the timit data set."""
Expand Down
28 changes: 12 additions & 16 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -98,7 +97,7 @@ def __init__(
data_fetcher: Optional[BaseDataFetcher] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
) -> None:

Expand Down Expand Up @@ -138,13 +137,10 @@ def __init__(

self.batch_size = batch_size

# TODO: figure out best solution for setting num_workers
if num_workers is None:
if platform.system() in ("Darwin", "Windows"):
num_workers = 0
else:
num_workers = os.cpu_count()
num_workers = 0
self.num_workers = num_workers

self.sampler = sampler

self.set_running_stages()
Expand Down Expand Up @@ -468,7 +464,7 @@ def from_data_source(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -569,7 +565,7 @@ def from_folders(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -643,7 +639,7 @@ def from_files(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -720,7 +716,7 @@ def from_tensors(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -807,7 +803,7 @@ def from_numpy(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -893,7 +889,7 @@ def from_json(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
field: Optional[str] = None,
**preprocess_kwargs: Any,
Expand Down Expand Up @@ -1003,7 +999,7 @@ def from_csv(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -1087,7 +1083,7 @@ def from_datasets(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -1168,7 +1164,7 @@ def from_fiftyone(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.core.data.data_module.DataModule` object
Expand Down
2 changes: 1 addition & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _compare_version(package: str, op, version) -> bool:
_ALBUMENTATIONS_AVAILABLE = _module_available("albumentations")

if _PIL_AVAILABLE:
from PIL import Image
from PIL import Image # noqa: F401
else:

class MetaImage(type):
Expand Down
3 changes: 1 addition & 2 deletions flash/graph/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.core.utilities.flash_cli import FlashCLI
from flash.graph import GraphClassificationData, GraphClassifier
Expand All @@ -23,7 +22,7 @@ def from_tu_dataset(
name: str = "KKI",
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> GraphClassificationData:
"""Downloads and loads the TU Dataset."""
Expand Down
5 changes: 2 additions & 3 deletions flash/image/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
Expand All @@ -22,7 +21,7 @@

def from_hymenoptera(
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> ImageClassificationData:
"""Downloads and loads the Hymenoptera (Ants, Bees) data set."""
Expand All @@ -38,7 +37,7 @@ def from_hymenoptera(

def from_movie_posters(
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> ImageClassificationData:
"""Downloads and loads the movie posters genre classification data set."""
Expand Down
4 changes: 2 additions & 2 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def from_data_frame(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down Expand Up @@ -225,7 +225,7 @@ def from_csv(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
Expand Down
3 changes: 1 addition & 2 deletions flash/image/detection/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
Expand All @@ -23,7 +22,7 @@
def from_coco_128(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> ObjectDetectionData:
"""Downloads and loads the COCO 128 data set."""
Expand Down
6 changes: 3 additions & 3 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def from_coco(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
Expand Down Expand Up @@ -279,7 +279,7 @@ def from_voc(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
Expand Down Expand Up @@ -358,7 +358,7 @@ def from_via(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.detection.data.ObjectDetectionData` object from the given data folders
Expand Down
2 changes: 1 addition & 1 deletion flash/image/instance_segmentation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
def from_pets(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
parser: Optional[Callable] = None,
**preprocess_kwargs,
) -> InstanceSegmentationData:
Expand Down
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def from_coco(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the
Expand Down Expand Up @@ -171,7 +171,7 @@ def from_voc(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.instance_segmentation.data.InstanceSegmentationData` object from the
Expand Down
2 changes: 1 addition & 1 deletion flash/image/keypoint_detection/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def from_biwi(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
parser: Optional[Callable] = None,
**preprocess_kwargs,
) -> KeypointDetectionData:
Expand Down
2 changes: 1 addition & 1 deletion flash/image/keypoint_detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def from_coco(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
):
"""Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data
Expand Down
3 changes: 1 addition & 2 deletions flash/image/segmentation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
Expand All @@ -24,7 +23,7 @@ def from_carla(
num_classes: int = 21,
val_split: float = 0.1,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> SemanticSegmentationData:
"""Downloads and loads the CARLA capture data set."""
Expand Down
4 changes: 2 additions & 2 deletions flash/image/segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def from_data_source(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs: Any,
) -> "DataModule":

Expand Down Expand Up @@ -376,7 +376,7 @@ def from_folders(
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
num_classes: Optional[int] = None,
labels_map: Dict[int, Tuple[int, int, int]] = None,
**preprocess_kwargs,
Expand Down
3 changes: 1 addition & 2 deletions flash/image/style_transfer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional

import flash
from flash.core.data.utils import download_data
Expand All @@ -24,7 +23,7 @@

def from_coco_128(
batch_size: int = 4,
num_workers: Optional[int] = None,
num_workers: int = 0,
**preprocess_kwargs,
) -> StyleTransferData:
"""Downloads and loads the COCO 128 data set."""
Expand Down
Loading

0 comments on commit 3b88e74

Please sign in to comment.