diff --git a/command_line/image_average.py b/command_line/image_average.py index b1bbe0ba2..7bec46548 100644 --- a/command_line/image_average.py +++ b/command_line/image_average.py @@ -10,6 +10,7 @@ import copy import sys from builtins import range +import numpy as np import libtbx.load_env from libtbx import easy_mp, option_parser @@ -284,7 +285,7 @@ def run(argv=None): iterable = iterable[command_line.options.skip_images :] if ( command_line.options.num_images_max is not None - and command_line.options.num_images_max < iterable + and command_line.options.num_images_max < len(iterable) ): iterable = iterable[: command_line.options.num_images_max] assert len(iterable) >= 2, "Need more than one image to average" @@ -308,14 +309,39 @@ def run(argv=None): # chop the list into pieces, depending on rank. This assigns each process # events such that the get every Nth event where N is the number of processes iterable = [i for n, i in enumerate(iterable) if (n + rank) % size == 0] - results = [worker(iterable)] - results = comm.gather(results, root=0) + r_nfail, r_nmemb, r_max_img, r_sum_distance, r_sum_img, r_ssq_img, r_sum_wavelength = worker( + iterable + ) + + nfail = np.array([0]) + nmemb = np.array([0]) + sum_distance = np.array([0.0]) + sum_wavelength = np.array([0.0]) + comm.Reduce(np.array([r_nfail]), nfail) + comm.Reduce(np.array([r_nmemb]), nmemb) + comm.Reduce(np.array([r_sum_distance]), sum_distance) + comm.Reduce(np.array([r_sum_wavelength]), sum_wavelength) + nfail = int(nfail[0]) + nmemb = int(nmemb) + sum_distance = float(sum_distance[0]) + sum_wavelength = float(sum_wavelength[0]) + + def reduce_image(data, op=MPI.SUM): + result = [] + for panel_data in data: + panel_data = panel_data.as_numpy_array() + reduced_data = np.zeros(panel_data.shape).astype(panel_data.dtype) + comm.Reduce(panel_data, reduced_data, op=op) + result.append(flex.double(reduced_data)) + return result + + max_img = reduce_image(r_max_img, MPI.MAX) + sum_img = reduce_image(r_sum_img) + ssq_img = reduce_image(r_ssq_img) + if rank != 0: return - results_set = [] - for r in results: - results_set.extend(r) - results = results_set + avg_img = tuple(s / nmemb for s in sum_img) else: if command_line.options.nproc == 1: results = [worker(iterable)] @@ -325,38 +351,38 @@ def run(argv=None): func=worker, iterable=iterable, processes=command_line.options.nproc ) - nfail = 0 - nmemb = 0 - for ( - i, - ( - r_nfail, - r_nmemb, - r_max_img, - r_sum_distance, - r_sum_img, - r_ssq_img, - r_sum_wavelength, - ), - ) in enumerate(results): - nfail += r_nfail - nmemb += r_nmemb - if i == 0: - max_img = r_max_img - sum_distance = r_sum_distance - sum_img = r_sum_img - ssq_img = r_ssq_img - sum_wavelength = r_sum_wavelength - else: - for p in range(len(sum_img)): - sel = (r_max_img[p] > max_img[p]).as_1d() - max_img[p].set_selected(sel, r_max_img[p].select(sel)) + nfail = 0 + nmemb = 0 + for ( + i, + ( + r_nfail, + r_nmemb, + r_max_img, + r_sum_distance, + r_sum_img, + r_ssq_img, + r_sum_wavelength, + ), + ) in enumerate(results): + nfail += r_nfail + nmemb += r_nmemb + if i == 0: + max_img = r_max_img + sum_distance = r_sum_distance + sum_img = r_sum_img + ssq_img = r_ssq_img + sum_wavelength = r_sum_wavelength + else: + for p in range(len(sum_img)): + sel = (r_max_img[p] > max_img[p]).as_1d() + max_img[p].set_selected(sel, r_max_img[p].select(sel)) - sum_img[p] += r_sum_img[p] - ssq_img[p] += r_ssq_img[p] + sum_img[p] += r_sum_img[p] + ssq_img[p] += r_ssq_img[p] - sum_distance += r_sum_distance - sum_wavelength += r_sum_wavelength + sum_distance += r_sum_distance + sum_wavelength += r_sum_wavelength # Early exit if no statistics were accumulated. if command_line.options.verbose: diff --git a/tests/command_line/test_average.py b/tests/command_line/test_average.py index 2900875d6..1e60488e8 100644 --- a/tests/command_line/test_average.py +++ b/tests/command_line/test_average.py @@ -6,18 +6,29 @@ import dxtbx +import pytest + + +@pytest.mark.parametrize("use_mpi", [True, False]) +def test_average(dials_regression, tmpdir, use_mpi): + # Only allow MPI tests if we've got MPI capabilities + if use_mpi: + pytest.importorskip("mpi4py") -def test_average(dials_regression, tmpdir): data = os.path.join( dials_regression, "image_examples", "SACLA_MPCCD_Cheetah", "run266702-0-subset.h5", ) + if use_mpi: + command = "mpirun" + mpargs = "-n 2 dxtbx.image_average".split() + else: + command = "dxtbx.image_average" + mpargs = "-n 2".split() result = procrunner.run( - ["dxtbx.image_average"] - + "-v -a avg.cbf -s stddev.cbf -m max.cbf".split() - + [data], + [command] + mpargs + "-v -a avg.cbf -s stddev.cbf -m max.cbf".split() + [data], working_directory=tmpdir, ) assert not result.returncode and not result.stderr