Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
w
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Feb 3, 2021
1 parent 41ac58c commit 513c951
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
9 changes: 5 additions & 4 deletions flash/core/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform
from typing import Any, Optional

import pytorch_lightning as pl
Expand Down Expand Up @@ -64,10 +65,10 @@ def __init__(

# TODO: figure out best solution for setting num_workers
if num_workers is None:
num_workers = os.cpu_count()
# if num_workers is None:
# # warnings.warn("Could not infer cpu count automatically, setting it to zero")
# num_workers = 0
if platform.system() == "Darwin":
num_workers = 0
else:
num_workers = os.cpu_count()
self.num_workers = num_workers

self._data_pipeline = None
Expand Down
7 changes: 6 additions & 1 deletion tests/core/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform

import torch

from flash import DataModule
Expand Down Expand Up @@ -55,7 +57,10 @@ def test_cpu_count_none():
train_ds = DummyDataset()
# with patch("os.cpu_count", return_value=None), pytest.warns(UserWarning, match="Could not infer"):
dm = DataModule(train_ds, num_workers=None)
assert dm.num_workers > 0
if platform.system() == "Darwin":
assert dm.num_workers == 0
else:
assert dm.num_workers > 0


def test_pipeline():
Expand Down

0 comments on commit 513c951

Please sign in to comment.