Skip to content

Commit ba0680e

Browse files
authored
Merge pull request #227 from swkeemink/joblib
ENH: use joblib instead of multiprocessing
2 parents bcae081 + ace5404 commit ba0680e

File tree

2 files changed

+54
-119
lines changed

2 files changed

+54
-119
lines changed

fissa/core.py

+53-119
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,19 @@
1313
import functools
1414
import glob
1515
import itertools
16-
import multiprocessing
1716
import os.path
1817
import sys
1918
import warnings
2019

21-
import tqdm
22-
from past.builtins import basestring
23-
2420
try:
2521
from collections import abc
2622
except ImportError:
2723
import collections as abc
2824

2925
import numpy as np
26+
import tqdm
27+
from joblib import Parallel, delayed
28+
from past.builtins import basestring
3029
from scipy.io import savemat
3130

3231
from . import deltaf, extraction
@@ -221,17 +220,6 @@ def separate_trials(
221220
return Xsep, Xmatch, Xmixmat, convergence
222221

223222

224-
if sys.version_info < (3, 0):
225-
# Define helper functions which are needed on Python 2.7, which does not
226-
# have multiprocessing.Pool.starmap.
227-
228-
def _extract_wrapper(args):
229-
return extract(*args)
230-
231-
def _separate_wrapper(args):
232-
return separate_trials(*args)
233-
234-
235223
class Experiment:
236224
r"""
237225
FISSA Experiment.
@@ -301,28 +289,30 @@ class Experiment:
301289
302290
.. versionadded:: 1.0.0
303291
304-
ncores_preparation : int or None, default=None
292+
ncores_preparation : int or None, default=-1
305293
The number of parallel subprocesses to use during the data
306294
preparation steps of :meth:`separation_prep`.
307295
These are ROI and neuropil subregion definitions, and extracting
308296
raw signals from TIFFs.
309297
310-
If set to ``None`` (default), the number of processes used will
311-
equal the number of threads on the machine. Note that this
312-
behaviour can, especially for the data preparation step,
298+
If set to ``None`` or ``-1`` (default), the number of processes used
299+
will equal the number of threads on the machine.
300+
If this is set to ``-2``, the number of processes used will be one less
301+
than the number of threads on the machine; etc.
302+
Note that this behaviour can, especially for the data preparation step,
313303
be very memory-intensive.
314304
315-
ncores_separation : int or None, default=None
305+
ncores_separation : int or None, default=-1
316306
The number of parallel subprocesses to use during the signal
317307
separation steps of :meth:`separate`.
318308
The separation steps requires less memory per subprocess than
319309
the preparation steps, and so can be often be set higher than
320310
`ncores_preparation`.
321311
322-
If set to ``None`` (default), the number of processes used will
323-
equal the number of threads on the machine. Note that this
324-
behaviour can, especially for the data preparation step,
325-
be very memory-intensive.
312+
If set to ``None`` or ``-1`` (default), the number of processes used
313+
will equal the number of threads on the machine.
314+
If this is set to ``-2``, the number of processes used will be one less
315+
than the number of threads on the machine; etc.
326316
327317
method : "nmf" or "ica", default="nmf"
328318
Which blind source-separation method to use. Either ``"nmf"``
@@ -492,8 +482,8 @@ def __init__(
492482
max_iter=20000,
493483
tol=1e-4,
494484
max_tries=1,
495-
ncores_preparation=None,
496-
ncores_separation=None,
485+
ncores_preparation=-1,
486+
ncores_separation=-1,
497487
method="nmf",
498488
lowmemory_mode=False,
499489
datahandler=None,
@@ -761,55 +751,17 @@ def separation_prep(self, redo=False):
761751
datahandler=self.datahandler,
762752
)
763753

764-
# Check whether we should use multiprocessing
765-
use_multiprocessing = (
766-
self.ncores_preparation is None or self.ncores_preparation > 1
767-
)
768-
769754
# check whether we should show progress bars
770755
disable_progressbars = self.verbosity != 1
771756

772-
# Do the extraction
773-
if use_multiprocessing and sys.version_info < (3, 0):
774-
# define pool
775-
pool = multiprocessing.Pool(self.ncores_preparation)
776-
# run extraction
777-
outputs = list(
778-
pool.map(
779-
_extract_wrapper,
780-
tqdm.tqdm(
781-
zip(
782-
self.images,
783-
self.rois,
784-
itertools.repeat(self.nRegions, len(self.images)),
785-
itertools.repeat(self.expansion, len(self.images)),
786-
itertools.repeat(self.datahandler, len(self.images)),
787-
),
788-
total=self.nTrials,
789-
desc="Extracting traces",
790-
disable=disable_progressbars,
791-
),
792-
)
793-
)
794-
pool.close()
795-
pool.join()
796-
797-
elif use_multiprocessing:
798-
with multiprocessing.Pool(self.ncores_preparation) as pool:
799-
# run extraction
800-
outputs = list(
801-
pool.starmap(
802-
_extract_cfg,
803-
tqdm.tqdm(
804-
zip(self.images, self.rois),
805-
total=self.nTrials,
806-
desc="Extracting traces",
807-
disable=disable_progressbars,
808-
),
809-
)
810-
)
757+
# Check how many workers to spawn.
758+
# Map the behaviour of ncores=None to one job per CPU core, like for
759+
# multiprocessing.Pool(processes=None). With joblib, this is
760+
# joblib.Parallel(n_jobs=-1) instead.
761+
n_jobs = -1 if self.ncores_preparation is None else self.ncores_preparation
811762

812-
else:
763+
if 0 <= n_jobs <= 1:
764+
# Don't use multiprocessing
813765
outputs = [
814766
_extract_cfg(*args)
815767
for args in tqdm.tqdm(
@@ -819,6 +771,17 @@ def separation_prep(self, redo=False):
819771
disable=disable_progressbars,
820772
)
821773
]
774+
else:
775+
# Use multiprocessing
776+
outputs = Parallel(n_jobs=n_jobs, backend="threading")(
777+
delayed(_extract_cfg)(image, roi_stack)
778+
for image, roi_stack in tqdm.tqdm(
779+
zip(self.images, self.rois),
780+
total=self.nTrials,
781+
desc="Extracting traces",
782+
disable=disable_progressbars,
783+
)
784+
)
822785

823786
# get number of cells
824787
nCell = len(outputs[0][1])
@@ -955,58 +918,18 @@ def separate(self, redo_prep=False, redo_sep=False):
955918
verbosity=self.verbosity - 2,
956919
)
957920

958-
# Check whether we should use multiprocessing
959-
use_multiprocessing = (
960-
self.ncores_separation is None or self.ncores_separation > 1
961-
)
962-
963921
# check whether we should show progress bars
964922
disable_progressbars = self.verbosity != 1
965923

924+
# Check how many workers to spawn.
925+
# Map the behaviour of ncores=None to one job per CPU core, like for
926+
# multiprocessing.Pool(processes=None). With joblib, this is
927+
# joblib.Parallel(n_jobs=-1) instead.
928+
n_jobs = -1 if self.ncores_separation is None else self.ncores_separation
929+
966930
# Do the extraction
967-
if use_multiprocessing and sys.version_info < (3, 0):
968-
# define pool
969-
pool = multiprocessing.Pool(self.ncores_separation)
970-
971-
# run separation
972-
outputs = list(
973-
pool.map(
974-
_separate_wrapper,
975-
tqdm.tqdm(
976-
zip(
977-
self.raw,
978-
range(n_roi),
979-
itertools.repeat(self.alpha, n_roi),
980-
itertools.repeat(self.max_iter, n_roi),
981-
itertools.repeat(self.tol, n_roi),
982-
itertools.repeat(self.max_tries, n_roi),
983-
itertools.repeat(self.method, n_roi),
984-
itertools.repeat(self.verbosity, n_roi),
985-
),
986-
total=self.nCell,
987-
desc="Separating data",
988-
disable=disable_progressbars,
989-
),
990-
)
991-
)
992-
pool.close()
993-
pool.join()
994-
995-
elif use_multiprocessing:
996-
with multiprocessing.Pool(self.ncores_separation) as pool:
997-
# run separation
998-
outputs = list(
999-
pool.starmap(
1000-
_separate_cfg,
1001-
tqdm.tqdm(
1002-
zip(self.raw, range(n_roi)),
1003-
total=self.nCell,
1004-
desc="Separating data",
1005-
disable=disable_progressbars,
1006-
),
1007-
)
1008-
)
1009-
else:
931+
if 0 <= n_jobs <= 1:
932+
# Don't use multiprocessing
1010933
outputs = [
1011934
_separate_cfg(X, roi_label=i)
1012935
for i, X in tqdm.tqdm(
@@ -1016,6 +939,17 @@ def separate(self, redo_prep=False, redo_sep=False):
1016939
disable=disable_progressbars,
1017940
)
1018941
]
942+
else:
943+
# Use multiprocessing
944+
outputs = Parallel(n_jobs=n_jobs, backend="threading")(
945+
delayed(_separate_cfg)(X, i)
946+
for i, X in tqdm.tqdm(
947+
enumerate(self.raw),
948+
total=self.nCell,
949+
desc="Separating data",
950+
disable=disable_progressbars,
951+
)
952+
)
1019953

1020954
# Define output shape as an array of objects shaped (n_roi, n_trial)
1021955
sep = np.empty((n_roi, n_trial), dtype=object)

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
future>=0.16.0
2+
joblib>=0.14.1
23
numpy>=1.13.0
34
Pillow>=4.3.0
45
read-roi>=1.5.0; python_version>='3.0'

0 commit comments

Comments
 (0)