Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve LISI multiprocessing specification #301

Merged
merged 17 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 24 additions & 51 deletions scib/metrics/lisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def ilisi_graph(
type_=None,
subsample=None,
scale=True,
multiprocessing=None,
nodes=None,
n_cores=1,
verbose=False
):
"""Integration LISI (iLISI) score
Expand All @@ -75,10 +74,7 @@ def ilisi_graph(
:param subsample: Percentage of observations (integer between 0 and 100)
to which lisi scoring should be subsampled
:param scale: scale output values between 0 and 1 (True/False)
:param multiprocessing: parallel computation of LISI scores, if None, no parallelisation
via multiprocessing is performed
:param nodes: number of nodes (i.e. CPUs to use for multiprocessing); ignored, if
multiprocessing is set to None
:param n_cores: number of cores (i.e. CPUs or CPU cores to use for multiprocessing)
:return: Median of iLISI scores per batch labels
"""

Expand All @@ -92,8 +88,7 @@ def ilisi_graph(
n_neighbors=k0,
perplexity=None,
subsample=subsample,
multiprocessing=multiprocessing,
nodes=nodes,
n_cores=n_cores,
verbose=verbose
)

Expand All @@ -115,8 +110,7 @@ def clisi_graph(
type_=None,
subsample=None,
scale=True,
multiprocessing=None,
nodes=None,
n_cores=1,
verbose=False
):
"""Cell-type LISI (cLISI) score
Expand All @@ -137,10 +131,7 @@ def clisi_graph(
:param subsample: Percentage of observations (integer between 0 and 100)
to which lisi scoring should be subsampled
:param scale: scale output values between 0 and 1 (True/False)
:param multiprocessing: parallel computation of LISI scores, if None, no parallelisation
via multiprocessing is performed
:param nodes: number of nodes (i.e. CPUs to use for multiprocessing); ignored, if
multiprocessing is set to None
:param n_cores: number of cores (i.e. CPUs or CPU cores to use for multiprocessing)
:return: Median of cLISI scores per cell type labels
"""

Expand All @@ -156,8 +147,7 @@ def clisi_graph(
n_neighbors=k0,
perplexity=None,
subsample=subsample,
multiprocessing=multiprocessing,
nodes=nodes,
n_cores=n_cores,
verbose=verbose
)

Expand Down Expand Up @@ -191,8 +181,7 @@ def lisi_graph_py(
n_neighbors=90,
perplexity=None,
subsample=None,
multiprocessing=None,
nodes=None,
n_cores=1,
verbose=False
):
"""
Expand Down Expand Up @@ -239,20 +228,6 @@ def lisi_graph_py(
print(connectivities.data[large_enough == False])
connectivities.data[large_enough == False] = 3e-308

# define number of chunks
n_chunks = 1

if multiprocessing is not None:
# set up multiprocessing
if nodes is None:
# take all but one CPU and 1 CPU, if there's only 1 CPU.
n_cpu = mp.cpu_count()
n_processes = np.max([n_cpu, np.ceil(n_cpu / 2)]).astype('int')
else:
n_processes = nodes
# update numbr of chunks
n_chunks = n_processes

# temporary file
tmpdir = tempfile.TemporaryDirectory(prefix="lisi_")
dir_path = tmpdir.name + '/'
Expand All @@ -269,8 +244,16 @@ def lisi_graph_py(
cpp_file_path = root / 'knn_graph/knn_graph.o' # create POSIX path to file to execute compiled cpp-code
# comment: POSIX path needs to be converted to string - done below with 'as_posix()'
# create evenly split chunks if n_obs is divisible by n_chunks (doesn't really make sense on 2nd thought)
n_splits = n_chunks - 1
args_int = [cpp_file_path.as_posix(), mtx_file_path, dir_path, str(n_neighbors), str(n_splits), str(subset)]
args_int = [
cpp_file_path.as_posix(),
mtx_file_path,
dir_path,
str(n_neighbors),
str(n_cores - 1), # number of splits
str(subset)
]
if verbose:
print(f'call {" ".join(args_int)}')
try:
subprocess.run(args_int)
except Exception as e:
Expand All @@ -281,13 +264,12 @@ def lisi_graph_py(
if verbose:
print("LISI score estimation")

# do the simpson call
if multiprocessing is not None:
if n_cores > 1:

if verbose:
print(f"{n_processes} processes started.")
pool = mp.Pool(processes=n_processes)
count = np.arange(0, n_processes)
print(f"{n_cores} processes started.")
pool = mp.Pool(processes=n_cores)
count = np.arange(0, n_cores)

# create argument list for each worker
results = pool.starmap(
Expand All @@ -302,27 +284,20 @@ def lisi_graph_py(
pool.close()
pool.join()

simpson_est_batch = 1 / np.concatenate(results)
simpson_estimate_batch = np.concatenate(results)

else:
simpson_estimate_batch = compute_simpson_index_graph(
input_path=dir_path,
batch_labels=batch,
n_batches=n_batches,
perplexity=perplexity,
n_neighbors=n_neighbors,
chunk_no=None
n_neighbors=n_neighbors
)
simpson_est_batch = 1 / simpson_estimate_batch

tmpdir.cleanup()

# extract results
d = {batch_key: simpson_est_batch}

lisi_estimate = pd.DataFrame(data=d, index=np.arange(0, len(simpson_est_batch)))

return lisi_estimate
return 1 / simpson_estimate_batch


# LISI core functions
Expand Down Expand Up @@ -427,8 +402,6 @@ def compute_simpson_index_graph(
P = np.zeros(n_neighbors)
logU = np.log(perplexity)

if chunk_no is None:
chunk_no = 0
# check if the target file is not empty
if os.stat(input_path + '_indices_' + str(chunk_no) + '.txt').st_size == 0:
print("File has no entries. Doing nothing.")
Expand Down
6 changes: 4 additions & 2 deletions scib/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def metrics(
ilisi_=False,
clisi_=False,
subsample=0.5,
n_cores=1,
type_=None,
verbose=False,
):
Expand Down Expand Up @@ -293,6 +294,7 @@ def metrics(
whether to compute iLISI using :func:`~scib.metrics.ilisi_graph`
:param subsample:
subsample fraction for LISI scores
:param n_cores: number of cores to be used for LISI functions
:param `type_`:
one of 'full', 'embed' or 'knn' (used for kBET and LISI scores)
"""
Expand Down Expand Up @@ -456,7 +458,7 @@ def metrics(
type_=type_,
subsample=subsample * 100,
scale=True,
multiprocessing=True,
n_cores=n_cores,
verbose=verbose
)
else:
Expand All @@ -470,7 +472,7 @@ def metrics(
type_=type_,
subsample=subsample * 100,
scale=True,
multiprocessing=True,
n_cores=n_cores,
verbose=verbose
)
else:
Expand Down
10 changes: 6 additions & 4 deletions tests/metrics/test_clisi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from tests.common import *


def test_clisi_full(adata):
score = scib.me.clisi_graph(
adata,
batch_key='batch',
label_key='celltype',
scale=True,
type_='full'
type_='full',
verbose=True
)

LOGGER.info(f"score: {score}")
Expand All @@ -21,7 +21,8 @@ def test_clisi_embed(adata_neighbors):
batch_key='batch',
label_key='celltype',
scale=True,
type_='embed'
type_='embed',
verbose=True
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 0.982, diff=1e-2)
Expand All @@ -33,7 +34,8 @@ def test_clisi_knn(adata_neighbors):
batch_key='batch',
label_key='celltype',
scale=True,
type_='graph'
type_='graph',
verbose=True
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 0.982, diff=1e-2)
9 changes: 6 additions & 3 deletions tests/metrics/test_ilisi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ def test_ilisi_full(adata):
adata,
batch_key='batch',
scale=True,
type_='full'
type_='full',
verbose=True
)

LOGGER.info(f"score: {score}")
Expand All @@ -19,7 +20,8 @@ def test_ilisi_embed(adata_neighbors):
adata_neighbors,
batch_key='batch',
scale=True,
type_='embed'
type_='embed',
verbose=True
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 0.238, diff=1e-2)
Expand All @@ -30,7 +32,8 @@ def test_ilisi_knn(adata_neighbors):
adata_neighbors,
batch_key='batch',
scale=True,
type_='graph'
type_='graph',
verbose=True
)
LOGGER.info(f"score: {score}")
assert_near_exact(score, 0.238, diff=1e-2)