diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 2d6d11ae0..91ad8c732 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -23,3 +23,4 @@ dependencies: - cython - gmsh - pyvkfft +- mpi4py diff --git a/pytential/qbx/__init__.py b/pytential/qbx/__init__.py index 03ee21f73..9cf48e1e2 100644 --- a/pytential/qbx/__init__.py +++ b/pytential/qbx/__init__.py @@ -76,6 +76,18 @@ class _not_provided: # noqa: N801 pass +class _LevelToOrderWrapper: + """ + Helper functor to convert a constant integer fmm order into a picklable and + callable object. + """ + def __init__(self, fmm_order): + self.fmm_order = fmm_order + + def __call__(self, kernel, kernel_args, tree, level): + return self.fmm_order + + class QBXLayerPotentialSource(LayerPotentialSourceBase): """A source discretization for a QBX layer potential. @@ -131,7 +143,8 @@ def __init__( order to be used on a given *level* of *tree* with *kernel*, where *kernel* is the :class:`sumpy.kernel.Kernel` being evaluated, and *kernel_args* is a set of *(key, value)* tuples with evaluated - kernel arguments. May not be given if *fmm_order* is given. + kernel arguments. May not be given if *fmm_order* is given. If used in + the distributed setting, this argument must be picklable. :arg fmm_backend: a string denoting the desired FMM backend to use, either `"sumpy"` or `"fmmlib"`. Only used if *fmm_order* or *fmm_level_to_order* are provided. 
@@ -204,9 +217,8 @@ def __init__( else: assert isinstance(fmm_order, int) and not isinstance(fmm_order, bool) - # pylint: disable-next=function-redefined - def fmm_level_to_order(kernel, kernel_args, tree, level): - return fmm_order + fmm_level_to_order = _LevelToOrderWrapper(fmm_order) + assert isinstance(fmm_level_to_order, bool) or callable(fmm_level_to_order) if _max_leaf_refine_weight is None: diff --git a/pytential/qbx/distributed.py b/pytential/qbx/distributed.py new file mode 100644 index 000000000..8cde810d4 --- /dev/null +++ b/pytential/qbx/distributed.py @@ -0,0 +1,826 @@ +__copyright__ = "Copyright (C) 2022 Hao Gao" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +from pytential.qbx import QBXLayerPotentialSource +from arraycontext import PyOpenCLArrayContext, unflatten +from typing import Any, Dict +from dataclasses import dataclass +import numpy as np +import pyopencl as cl +from boxtree.tools import DeviceDataRecord +from pytools import memoize_method + + +@dataclass +class GlobalQBXFMMGeometryData: + """A trimmed-down version of :class:`QBXFMMGeometryData` to be broadcasted for + the distributed implementation. Each rank should have the same global geometry + data. + """ + global_traversal: Any + centers: Any + expansion_radii: Any + global_qbx_centers: Any + qbx_center_to_target_box: Any + non_qbx_box_target_lists: Any + center_to_tree_targets: Any + + +class LocalQBXFMMGeometryData(DeviceDataRecord): + """A subset of the global geometry data used by each rank to calculate potentials + using FMM. Each rank should have its own version of the local geometry data. + """ + def non_qbx_box_target_lists(self): + return self._non_qbx_box_target_lists + + def traversal(self): + return self.local_trav + + def tree(self): + return self.traversal().tree + + def centers(self): + return self._local_centers + + @property + def ncenters(self): + return self._local_centers.shape[1] + + def global_qbx_centers(self): + return self._global_qbx_centers + + def expansion_radii(self): + return self._expansion_radii + + def qbx_center_to_target_box(self): + return self._local_qbx_center_to_target_box + + def center_to_tree_targets(self): + return self._local_center_to_tree_targets + + def qbx_targets(self): + return self._qbx_targets + + def qbx_center_to_target_box_source_level(self, source_level): + return self._qbx_center_to_target_box_source_level[source_level] + + @memoize_method + def build_rotation_classes_lists(self): + with cl.CommandQueue(self.cl_context) as queue: + trav = self.traversal().to_device(queue) + tree = self.tree().to_device(queue) + + from boxtree.rotation_classes import RotationClassesBuilder + return 
RotationClassesBuilder(self.cl_context)( + queue, trav, tree)[0].get(queue) + + def eval_qbx_targets(self): + return self.qbx_targets() + + @memoize_method + def m2l_rotation_lists(self): + return self.build_rotation_classes_lists().from_sep_siblings_rotation_classes + + @memoize_method + def m2l_rotation_angles(self): + return (self + .build_rotation_classes_lists() + .from_sep_siblings_rotation_class_to_angle) + + def src_idx_all_ranks(self): + return self._src_idx_all_ranks + + +# {{{ Traversal builder + +class QBXFMMGeometryDataTraversalBuilder: + # Could we use QBXFMMGeometryDataCodeContainer instead? + def __init__(self, context, well_sep_is_n_away=1, from_sep_smaller_crit=None, + _from_sep_smaller_min_nsources_cumul=0): + from boxtree.traversal import FMMTraversalBuilder + self.traversal_builder = FMMTraversalBuilder( + context, + well_sep_is_n_away=well_sep_is_n_away, + from_sep_smaller_crit=from_sep_smaller_crit) + self._from_sep_smaller_min_nsources_cumul = ( + _from_sep_smaller_min_nsources_cumul) + + def __call__(self, queue, tree, **kwargs): + trav, evt = self.traversal_builder( + queue, tree, + _from_sep_smaller_min_nsources_cumul=( + self._from_sep_smaller_min_nsources_cumul), + **kwargs) + + return trav, evt + +# }}} + + +def broadcast_global_geometry_data( + comm, actx, traversal_builder, global_geometry_data): + """Broadcasts useful fields of global geometry data from the root rank to the + worker ranks, so that each rank can form local geometry data independently. + + This function should be called collectively by all ranks in `comm`. + + :arg comm: MPI communicator. + :arg traversal_builder: a :class:`QBXFMMGeometryDataTraversalBuilder` object, + used for constructing the global traversal object from the broadcasted global + tree. This argument is significant on all ranks. + :arg global_geometry_data: an object of :class:`ToHostTransferredGeoDataWrapper`, + the global geometry data on host memory. 
This argument is only significant on + the root rank. + :returns: a :class:`GlobalQBXFMMGeometryData` object on each worker + rank, representing the broadcasted subset of the global geometry data, used + for constructing the local geometry data independently. See + :func:`compute_local_geometry_data`. + """ + mpi_rank = comm.Get_rank() + queue = actx.queue + + global_traversal = None + global_tree = None + centers = None + expansion_radii = None + global_qbx_centers = None + qbx_center_to_target_box = None + non_qbx_box_target_lists = None + center_to_tree_targets = None + + # {{{ Broadcast necessary fields from the root rank to worker ranks + + if mpi_rank == 0: + global_traversal = global_geometry_data.traversal() + global_tree = global_traversal.tree + + centers = global_geometry_data.centers() + expansion_radii = global_geometry_data.expansion_radii() + global_qbx_centers = global_geometry_data.global_qbx_centers() + qbx_center_to_target_box = global_geometry_data.qbx_center_to_target_box() + non_qbx_box_target_lists = global_geometry_data.non_qbx_box_target_lists() + center_to_tree_targets = global_geometry_data.center_to_tree_targets() + + global_tree = comm.bcast(global_tree, root=0) + centers = comm.bcast(centers, root=0) + expansion_radii = comm.bcast(expansion_radii, root=0) + global_qbx_centers = comm.bcast(global_qbx_centers, root=0) + qbx_center_to_target_box = comm.bcast(qbx_center_to_target_box, root=0) + non_qbx_box_target_lists = comm.bcast(non_qbx_box_target_lists, root=0) + center_to_tree_targets = comm.bcast(center_to_tree_targets, root=0) + + # }}} + + # {{{ Each rank constructs the global traversal object independently + + global_tree_dev = global_tree.to_device(queue).with_queue(queue) + if mpi_rank != 0: + global_traversal, _ = traversal_builder(queue, global_tree_dev) + + if global_tree_dev.targets_have_extent: + global_traversal = global_traversal.merge_close_lists(queue) + + global_traversal = global_traversal.get(queue) + + # }}} + + 
return GlobalQBXFMMGeometryData( + global_traversal, + centers, + expansion_radii, + global_qbx_centers, + qbx_center_to_target_box, + non_qbx_box_target_lists, + center_to_tree_targets) + + +def compute_local_geometry_data( + actx, comm, global_geometry_data, boxes_time, + traversal_builder): + """Compute the local geometry data of the current rank from the global geometry + data. + + :arg comm: MPI communicator. + :arg global_geometry_data: Global geometry data from which the local geometry + data is generated. + :arg boxes_time: Predicated cost of each box. Used by partitioning to improve + load balancing. + :arg traversal_builder: Used to construct local traversal. + """ + queue = actx.queue + + global_traversal = global_geometry_data.global_traversal + global_tree = global_traversal.tree + centers = global_geometry_data.centers + ncenters = len(centers[0]) + expansion_radii = global_geometry_data.expansion_radii + global_qbx_centers = global_geometry_data.global_qbx_centers + qbx_center_to_target_box = global_geometry_data.qbx_center_to_target_box + non_qbx_box_target_lists = global_geometry_data.non_qbx_box_target_lists + center_to_tree_targets = global_geometry_data.center_to_tree_targets + + # {{{ Generate local tree and local traversal + + from boxtree.distributed.partition import partition_work + responsible_boxes_list = partition_work(boxes_time, global_traversal, comm) + + from boxtree.distributed.local_tree import generate_local_tree + local_tree, src_idx, tgt_idx = generate_local_tree( + queue, global_traversal, responsible_boxes_list, comm) + + src_idx_all_ranks = comm.gather(src_idx, root=0) + tgt_idx_all_ranks = comm.gather(tgt_idx, root=0) + + from boxtree.distributed.local_traversal import generate_local_travs + local_trav = generate_local_travs( + queue, local_tree, traversal_builder, + # TODO: get whether to merge close lists from root instead of + # hard-coding? 
+ merge_close_lists=True).get(queue=queue) + + # }}} + + # {{{ Form non_qbx_box_target_lists + + from boxtree.distributed.local_tree import LocalTreeGeneratorCodeContainer + code = LocalTreeGeneratorCodeContainer( + queue.context, + global_tree.dimensions, + global_tree.particle_id_dtype, + global_tree.coord_dtype) + + box_target_starts = cl.array.to_device( + queue, non_qbx_box_target_lists.box_target_starts) + box_target_counts_nonchild = cl.array.to_device( + queue, non_qbx_box_target_lists.box_target_counts_nonchild) + nfiltered_targets = non_qbx_box_target_lists.nfiltered_targets + targets = non_qbx_box_target_lists.targets + + particle_mask = cl.array.zeros( + queue, (nfiltered_targets,), dtype=global_tree.particle_id_dtype) + + responsible_boxes_mask = np.zeros(global_tree.nboxes, dtype=np.int8) + responsible_boxes_mask[responsible_boxes_list] = 1 + responsible_boxes_mask = cl.array.to_device(queue, responsible_boxes_mask) + + code.particle_mask_kernel()( + responsible_boxes_mask, + box_target_starts, + box_target_counts_nonchild, + particle_mask) + + particle_scan = cl.array.empty( + queue, (nfiltered_targets + 1,), + dtype=global_tree.particle_id_dtype) + particle_scan[0] = 0 + code.mask_scan_kernel()(particle_mask, particle_scan) + + local_box_target_starts = particle_scan[box_target_starts] + + lobal_box_target_counts_all_zeros = cl.array.zeros( + queue, (global_tree.nboxes,), dtype=global_tree.particle_id_dtype) + + local_box_target_counts_nonchild = cl.array.if_positive( + responsible_boxes_mask, + box_target_counts_nonchild, + lobal_box_target_counts_all_zeros) + + local_nfiltered_targets = int(particle_scan[-1].get(queue)) + + particle_mask = particle_mask.get().astype(bool) + particle_mask_all_ranks = comm.gather(particle_mask, root=0) + local_targets = np.empty((global_tree.dimensions,), dtype=object) + for idimension in range(global_tree.dimensions): + local_targets[idimension] = targets[idimension][particle_mask] + + from boxtree.tree import 
FilteredTargetListsInTreeOrder + non_qbx_box_target_lists = FilteredTargetListsInTreeOrder( + nfiltered_targets=local_nfiltered_targets, + box_target_starts=local_box_target_starts.get(), + box_target_counts_nonchild=local_box_target_counts_nonchild.get(), + targets=local_targets, + unfiltered_from_filtered_target_indices=None) + + # }}} + + tgt_mask = np.zeros((global_tree.ntargets,), dtype=bool) + tgt_mask[tgt_idx] = True + + tgt_mask_user_order = tgt_mask[global_tree.sorted_target_ids] + centers_mask = tgt_mask_user_order[:ncenters] + centers_scan = np.empty( + (ncenters + 1,), dtype=global_tree.particle_id_dtype) + centers_scan[1:] = np.cumsum( + centers_mask.astype(global_tree.particle_id_dtype)) + centers_scan[0] = 0 + + # {{{ local centers + + nlocal_centers = np.sum(centers_mask.astype(np.int32)) + centers_dims = centers.shape[0] + local_centers = np.empty( + (centers_dims, nlocal_centers), dtype=centers[0].dtype) + for idims in range(centers_dims): + local_centers[idims, :] = centers[idims][centers_mask] + + # }}} + + # {{{ local global_qbx_centers + + local_global_qbx_centers = centers_scan[ + global_qbx_centers[centers_mask[global_qbx_centers]]] + + # }}} + + # {{{ local expansion_radii + + local_expansion_radii = expansion_radii[centers_mask] + + # }}} + + # {{{ local qbx_center_to_target_box + + # Transform local qbx_center_to_target_box to global indexing + local_qbx_center_to_target_box = global_traversal.target_boxes[ + qbx_center_to_target_box[centers_mask]] + + # Transform local_qbx_center_to_target_box to local target_boxes indexing + global_boxes_to_target_boxes = np.ones( + (global_tree.nboxes,), dtype=local_tree.particle_id_dtype) + # make sure accessing invalid position raises an error + global_boxes_to_target_boxes *= -1 + global_boxes_to_target_boxes[local_trav.target_boxes] = \ + np.arange(local_trav.target_boxes.shape[0]) + local_qbx_center_to_target_box = \ + global_boxes_to_target_boxes[local_qbx_center_to_target_box] + + # }}} + + # 
{{{ local_qbx_targets and local center_to_tree_targets + + starts = center_to_tree_targets.starts + lists = center_to_tree_targets.lists + local_starts = np.empty((nlocal_centers + 1,), dtype=starts.dtype) + local_lists = np.empty(lists.shape, dtype=lists.dtype) + + qbx_target_mask = np.zeros((global_tree.ntargets,), dtype=bool) + current_start = 0 # index into local_lists + ilocal_center = 0 + local_starts[0] = 0 + + for icenter in range(ncenters): + # skip the current center if the current rank is not responsible for + # processing it + if not centers_mask[icenter]: + continue + + current_center_targets = lists[starts[icenter]:starts[icenter + 1]] + qbx_target_mask[current_center_targets] = True + current_stop = current_start + starts[icenter + 1] - starts[icenter] + local_starts[ilocal_center + 1] = current_stop + local_lists[current_start:current_stop] = \ + lists[starts[icenter]:starts[icenter + 1]] + + current_start = current_stop + ilocal_center += 1 + + qbx_target_mask_all_ranks = comm.gather(qbx_target_mask, root=0) + + local_lists = local_lists[:current_start] + + qbx_target_scan = np.empty( + (global_tree.ntargets + 1,), dtype=lists.dtype + ) + qbx_target_scan[0] = 0 + qbx_target_scan[1:] = np.cumsum(qbx_target_mask.astype(lists.dtype)) + nlocal_qbx_target = qbx_target_scan[-1] + + local_qbx_targets = np.empty( + (global_tree.dimensions, nlocal_qbx_target), + dtype=global_tree.targets[0].dtype + ) + for idim in range(global_tree.dimensions): + local_qbx_targets[idim, :] = global_tree.targets[idim][qbx_target_mask] + + local_lists = qbx_target_scan[local_lists] + + from pytential.qbx.geometry import CenterToTargetList + local_center_to_tree_targets = CenterToTargetList( + starts=local_starts, + lists=local_lists) + + # }}} + + # }}} + + # {{{ Construct qbx_center_to_target_box_source_level + + # This is modified from pytential.geometry.QBXFMMGeometryData. + # qbx_center_to_target_box_source_level but on host using Numpy instead of + # PyOpenCL. 
+ + tree = local_trav.tree + + qbx_center_to_target_box_source_level = np.empty( + (tree.nlevels,), dtype=object) + + for source_level in range(tree.nlevels): + sep_smaller = local_trav.from_sep_smaller_by_level[source_level] + + target_box_to_target_box_source_level = np.empty( + len(local_trav.target_boxes), + dtype=tree.box_id_dtype) + target_box_to_target_box_source_level.fill(-1) + target_box_to_target_box_source_level[sep_smaller.nonempty_indices] = ( + np.arange(sep_smaller.num_nonempty_lists, + dtype=tree.box_id_dtype)) + + qbx_center_to_target_box_source_level[source_level] = ( + target_box_to_target_box_source_level[ + local_qbx_center_to_target_box]) + + # }}} + + return LocalQBXFMMGeometryData( + cl_context=queue.context, + local_tree=local_tree, + local_trav=local_trav, + _local_centers=local_centers, + _global_qbx_centers=local_global_qbx_centers, + src_idx=src_idx, + tgt_idx=tgt_idx, + _src_idx_all_ranks=src_idx_all_ranks, + _tgt_idx_all_ranks=tgt_idx_all_ranks, + particle_mask=particle_mask_all_ranks, + qbx_target_mask=qbx_target_mask_all_ranks, + _non_qbx_box_target_lists=non_qbx_box_target_lists, + _local_qbx_center_to_target_box=local_qbx_center_to_target_box, + _expansion_radii=local_expansion_radii, + _qbx_targets=local_qbx_targets, + _local_center_to_tree_targets=local_center_to_tree_targets, + _qbx_center_to_target_box_source_level=( + qbx_center_to_target_box_source_level)) + + +def distribute_geo_data(comm, actx, insn, bound_expr, evaluate, + global_geo_data_device): + geo_data_cache = bound_expr._geo_data_cache + + if insn in geo_data_cache: + return geo_data_cache[insn] + + boxes_time = None + global_geo_data = None + + if comm.Get_rank() == 0: + # Use the cost model to estimate execution time for partitioning + from pytential.qbx.cost import AbstractQBXCostModel, QBXCostModel + + # FIXME: If the expansion wrangler is not FMMLib, the argument + # 'uses_pde_expansions' might be different + cost_model = QBXCostModel() + + import warnings + 
warnings.warn( + "Kernel-specific calibration parameters are not supplied when " + "using distributed FMM.", + stacklevel=2) + # TODO: supply better default calibration parameters + calibration_params = AbstractQBXCostModel.get_unit_calibration_params() + + kernel_args = {} + for arg_name, arg_expr in insn.kernel_arguments.items(): + kernel_args[arg_name] = evaluate(arg_expr) + + boxes_time, _ = cost_model.qbx_cost_per_box( + actx.queue, global_geo_data_device, insn.target_kernels[0], + kernel_args, calibration_params) + boxes_time = boxes_time.get() + + from pytential.qbx.utils import ToHostTransferredGeoDataWrapper + global_geo_data = ToHostTransferredGeoDataWrapper(global_geo_data_device) + + # {{{ Construct a traversal builder + + # NOTE: The distributed implementation relies on building the same traversal + # objects as the one on the root rank. This means here the traversal builder + # should use the same parameters as `QBXFMMGeometryData.traversal`. To make + # it consistent across ranks, we broadcast the parameters here. + + trav_param = None + if comm.Get_rank() == 0: + trav_param = { + "well_sep_is_n_away": + global_geo_data.geo_data.code_getter.build_traversal + .well_sep_is_n_away, + "from_sep_smaller_crit": + global_geo_data.geo_data.code_getter.build_traversal. + from_sep_smaller_crit, + "_from_sep_smaller_min_nsources_cumul": + global_geo_data.geo_data.lpot_source. 
+ _from_sep_smaller_min_nsources_cumul} + trav_param = comm.bcast(trav_param, root=0) + + traversal_builder = QBXFMMGeometryDataTraversalBuilder( + actx.context, + well_sep_is_n_away=trav_param["well_sep_is_n_away"], + from_sep_smaller_crit=trav_param["from_sep_smaller_crit"], + _from_sep_smaller_min_nsources_cumul=trav_param[ + "_from_sep_smaller_min_nsources_cumul"]) + + # }}} + + # {{{ Broadcast the subset of the global geometry data to worker ranks + + global_geo_data = broadcast_global_geometry_data( + comm, actx, traversal_builder, global_geo_data) + + # }}} + + # {{{ Compute the local geometry data from the global geometry data + + if comm.Get_rank() != 0: + boxes_time = np.empty( + global_geo_data.global_traversal.tree.nboxes, dtype=np.float64) + + comm.Bcast(boxes_time, root=0) + + local_geo_data = compute_local_geometry_data( + actx, comm, global_geo_data, boxes_time, traversal_builder) + + # }}} + + geo_data_cache[insn] = (global_geo_data, local_geo_data) + + return global_geo_data, local_geo_data + + +class DistributedQBXLayerPotentialSource(QBXLayerPotentialSource): + def __init__(self, comm, cl_context, *args, **kwargs): + """ + :arg comm: MPI communicator. + :arg cl_context: This argument is necessary because although the root rank + can deduce the CL context from density, worker ranks do not have a valid + density, so we specify there explicitly. + + `*args` and `**kwargs` will be forwarded to + `QBXLayerPotentialSource.__init__` on the root rank. + + Currently, `fmm_backend` has to be set to `"fmmlib"`. + """ + self.comm = comm + self._cl_context = cl_context + + # "_from_sep_smaller_min_nsources_cumul" can only be 0 for the distributed + # implementation. If not, the potential contribution of a list 3 box may be + # computed particle-to-particle instead of using its multipole expansion. + # However, the source particles may not be distributed to the target rank. 
+ if "_from_sep_smaller_min_nsources_cumul" not in kwargs: + kwargs["_from_sep_smaller_min_nsources_cumul"] = 0 + elif kwargs["_from_sep_smaller_min_nsources_cumul"] != 0: + raise ValueError( + "_from_sep_smaller_min_nsources_cumul has to be 0 for the " + "distributed implementation") + + # Only fmmlib is supported + assert kwargs["fmm_backend"] == "fmmlib" + + if self.comm.Get_rank() == 0: + super().__init__(*args, **kwargs) + else: + self.fmm_backend = "fmmlib" + self._use_target_specific_qbx = kwargs.get( + "_use_target_specific_qbx", None) + self.qbx_order = kwargs.get("qbx_order", None) + self.fmm_level_to_order = kwargs.get("fmm_level_to_order", None) + self.expansion_factory = kwargs.get("expansion_factory", None) + + @property + def cl_context(self): + return self._cl_context + + def get_local_fmm_expansion_wrangler_extra_kwargs( + self, actx, src_idx, target_kernels, tree_user_source_ids, arguments, + evaluator): + mpi_rank = self.comm.Get_rank() + + kernel_extra_kwargs = {} + source_extra_kwargs = {} + + if mpi_rank == 0: + kernel_extra_kwargs, source_extra_kwargs = \ + self.get_fmm_expansion_wrangler_extra_kwargs( + actx, target_kernels, tree_user_source_ids, + arguments, evaluator) + + # kernel_extra_kwargs contains information like helmholtz k, which should be + # picklable and cheap to broadcast + kernel_extra_kwargs = self.comm.bcast(kernel_extra_kwargs, root=0) + + # Broadcast the keys in `source_extra_kwargs` to worker ranks + source_arg_names = None + if mpi_rank == 0: + source_arg_names = list(source_extra_kwargs.keys()) + source_arg_names = self.comm.bcast(source_arg_names, root=0) + + for arg_name in source_arg_names: + if arg_name != "dsource_vec": + raise NotImplementedError + + # Broadcast the global source array to worker ranks + global_array_host = None + if mpi_rank == 0: + global_array_host = actx.to_numpy(source_extra_kwargs[arg_name]) + global_array_host = self.comm.bcast(global_array_host, root=0) + + # Compute the local source array 
independently on each worker rank + local_array_host = np.empty_like(global_array_host) + for idim, global_array_idim in enumerate(global_array_host): + local_array_host[idim] = global_array_idim[src_idx] + + source_extra_kwargs[arg_name] = actx.from_numpy(local_array_host) + + return kernel_extra_kwargs, source_extra_kwargs + + def exec_compute_potential_insn(self, actx, insn, bound_expr, evaluate, + return_timing_data): + extra_args = {} + + # Broadcast whether to use direct evaluation or FMM + use_direct = True + if self.comm.Get_rank() == 0: + use_direct = self.fmm_level_to_order is False + use_direct = self.comm.bcast(use_direct, root=0) + + if use_direct: + func = self.exec_compute_potential_insn_direct + extra_args["return_timing_data"] = return_timing_data + else: + func = self.exec_compute_potential_insn_fmm + extra_args["fmm_driver"] = None + + if self.comm.Get_rank() == 0: + return self._dispatch_compute_potential_insn( + actx, insn, bound_expr, evaluate, func, extra_args) + else: + return func(actx, insn, bound_expr, evaluate, **extra_args) + + def exec_compute_potential_insn_direct(self, *args, **kwargs): + # FIXME: The current implementation executes on a single rank. + import warnings + warnings.warn( + "The distributed implementation does not support direct " + "(non-FMM) evaluation", + stacklevel=2) + + if self.comm.Get_rank() == 0: + return super().exec_compute_potential_insn_direct(*args, **kwargs) + else: + results = [] + timing_data = {} + return results, timing_data + + def exec_compute_potential_insn_fmm(self, actx: PyOpenCLArrayContext, + insn, bound_expr, evaluate, fmm_driver): + """ + :returns: a tuple ``(assignments, extra_outputs)``, where *assignments* + is a list of tuples containing pairs ``(name, value)`` representing + assignments to be performed in the evaluation context. + *extra_outputs* is data that *fmm_driver* may return + (such as timing data), passed through unmodified. 
+ """ + from pytential.qbx import get_flat_strengths_from_densities + from meshmode.discretization import Discretization + + global_geo_data_device = None + output_and_expansion_dtype = None + flat_strengths = [] + + if self.comm.Get_rank() == 0: + target_name_and_side_to_number, target_discrs_and_qbx_sides = ( + self.get_target_discrs_and_qbx_sides(insn, bound_expr)) + + global_geo_data_device = self.qbx_fmm_geometry_data( + bound_expr.places, + insn.source.geometry, + target_discrs_and_qbx_sides) + + global_geo_data, local_geo_data = distribute_geo_data( + self.comm, actx, insn, bound_expr, evaluate, global_geo_data_device) + + tree_indep = self._tree_indep_data_for_wrangler( + target_kernels=insn.target_kernels, + source_kernels=insn.source_kernels) + + user_source_ids = None + if self.comm.Get_rank() == 0: + assert global_geo_data_device is not None + user_source_ids = global_geo_data_device.tree().user_source_ids + + kernel_extra_kwargs, source_extra_kwargs = ( + self.get_local_fmm_expansion_wrangler_extra_kwargs( + actx, local_geo_data.src_idx, + insn.target_kernels + insn.source_kernels, + user_source_ids, insn.kernel_arguments, evaluate)) + + if self.comm.Get_rank() == 0: + flat_strengths = get_flat_strengths_from_densities( + actx, bound_expr.places, evaluate, insn.densities, + dofdesc=insn.source) + + output_and_expansion_dtype = ( + self.get_fmm_output_and_expansion_dtype( + insn.source_kernels, flat_strengths[0])) + else: + flat_strengths = [actx.empty(0, dtype=int)] + + output_and_expansion_dtype = self.comm.bcast( + output_and_expansion_dtype, root=0) + + from pytential.qbx.fmmlib import ( + QBXFMMLibExpansionWrangler, DistributedQBXFMMLibExpansionWrangler) + + wrangler_cls = None + if tree_indep.wrangler_cls == QBXFMMLibExpansionWrangler: + wrangler_cls = DistributedQBXFMMLibExpansionWrangler + else: + raise NotImplementedError + + wrangler = wrangler_cls( + self._cl_context, self.comm, tree_indep, + local_geo_data, global_geo_data, + 
output_and_expansion_dtype, + self.qbx_order, + self.fmm_level_to_order, + source_extra_kwargs, + kernel_extra_kwargs, + self._use_target_specific_qbx) + + if self.comm.Get_rank() == 0: + assert global_geo_data_device is not None + from pytential.qbx.geometry import target_state + if actx.to_numpy(actx.np.any( + actx.thaw(global_geo_data_device.user_target_to_center()) + == target_state.FAILED)): + raise RuntimeError("geometry has failed targets") + + # {{{ geometry data inspection hook + + if self.geometry_data_inspector is not None: + perform_fmm = self.geometry_data_inspector( + insn, bound_expr, global_geo_data_device) + if not perform_fmm: + return [(o.name, 0) for o in insn.outputs] + + # }}} + + # Execute global QBX. + from pytential.qbx.fmm import drive_fmm + timing_data: Dict[str, Any] = {} + all_potentials_on_every_target = drive_fmm( + wrangler, flat_strengths, timing_data) + + if self.comm.Get_rank() == 0: + assert global_geo_data_device is not None + results = [] + + for o in insn.outputs: + target_side_number = target_name_and_side_to_number[ + o.target_name, o.qbx_forced_limit] + target_discr, _ = target_discrs_and_qbx_sides[target_side_number] + target_slice = slice( + *global_geo_data_device.target_info().target_discr_starts[ + target_side_number:target_side_number+2]) + + result = all_potentials_on_every_target[ + o.target_kernel_index][target_slice] + + if isinstance(target_discr, Discretization): + template_ary = actx.thaw(target_discr.nodes()[0]) + result = unflatten(template_ary, result, actx, strict=False) + + results.append((o.name, result)) + + return results, timing_data + else: + results = [(o.name, None) for o in insn.outputs] + return results, timing_data diff --git a/pytential/qbx/fmm.py b/pytential/qbx/fmm.py index 34fdfa5c9..71e983163 100644 --- a/pytential/qbx/fmm.py +++ b/pytential/qbx/fmm.py @@ -88,6 +88,28 @@ def wrangler_cls(self): return QBXExpansionWrangler +def _reorder_and_finalize_potentials( + wrangler, non_qbx_potentials, 
qbx_potentials, template_ary): + nqbtl = wrangler.geo_data.non_qbx_box_target_lists() + + all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary) + + for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials): + ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i + + all_potentials_in_tree_order += qbx_potentials + + def reorder_and_finalize_potentials(x): + # "finalize" gives host FMMs (like FMMlib) a chance to turn the + # potential back into a CL array. + return wrangler.finalize_potentials(x[ + wrangler.geo_data.traversal().tree.sorted_target_ids], template_ary) + + from pytools.obj_array import obj_array_vectorize + return obj_array_vectorize( + reorder_and_finalize_potentials, all_potentials_in_tree_order) + + class QBXExpansionWrangler(SumpyExpansionWrangler): """A specialized implementation of the :class:`boxtree.fmm.ExpansionWranglerInterface` for the QBX FMM. @@ -382,6 +404,17 @@ def eval_target_specific_qbx_locals(self, src_weight_vecs): return (self.full_output_zeros(template_ary), SumpyTimingFuture(template_ary.queue, events=())) + def gather_non_qbx_potentials(self, non_qbx_potentials): + return non_qbx_potentials + + def gather_qbx_potentials(self, qbx_potentials): + return qbx_potentials + + def reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary): + return _reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary) + # }}} @@ -416,10 +449,7 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None): See also :func:`boxtree.fmm.drive_fmm`. 
""" wrangler = expansion_wrangler - - geo_data = wrangler.geo_data traversal = wrangler.traversal - tree = traversal.tree template_ary = src_weight_vecs[0] @@ -433,6 +463,9 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None): src_weight_vecs = [wrangler.reorder_sources(weight) for weight in src_weight_vecs] + src_weight_vecs = wrangler.distribute_source_weights( + src_weight_vecs, wrangler.geo_data.src_idx_all_ranks()) + # {{{ construct local multipoles mpole_exps, timing_future = wrangler.form_multipoles( @@ -455,6 +488,12 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None): # }}} + # {{{ Communicate mpoles + + wrangler.communicate_mpoles(mpole_exps) + + # }}} + # {{{ direct evaluation from neighbor source boxes ("list 1") non_qbx_potentials, timing_future = wrangler.eval_direct( @@ -581,23 +620,11 @@ def drive_fmm(expansion_wrangler, src_weight_vecs, timing_data=None): # {{{ reorder potentials - nqbtl = geo_data.non_qbx_box_target_lists() - - all_potentials_in_tree_order = wrangler.full_output_zeros(template_ary) + non_qbx_potentials = wrangler.gather_non_qbx_potentials(non_qbx_potentials) + qbx_potentials = wrangler.gather_qbx_potentials(qbx_potentials) - for ap_i, nqp_i in zip(all_potentials_in_tree_order, non_qbx_potentials): - ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i - - all_potentials_in_tree_order += qbx_potentials - - def reorder_and_finalize_potentials(x): - # "finalize" gives host FMMs (like FMMlib) a chance to turn the - # potential back into a CL array. 
- return wrangler.finalize_potentials(x[tree.sorted_target_ids], template_ary) - - from pytools.obj_array import obj_array_vectorize - result = obj_array_vectorize( - reorder_and_finalize_potentials, all_potentials_in_tree_order) + result = wrangler.reorder_and_finalize_potentials( + non_qbx_potentials, qbx_potentials, template_ary) # }}} diff --git a/pytential/qbx/fmmlib.py b/pytential/qbx/fmmlib.py index adb1b5298..db144ca2b 100644 --- a/pytential/qbx/fmmlib.py +++ b/pytential/qbx/fmmlib.py @@ -30,6 +30,7 @@ Kernel, FMMLibTreeIndependentDataForWrangler, FMMLibExpansionWrangler) +from boxtree.distributed.calculation import DistributedFMMLibExpansionWrangler from sumpy.kernel import ( LaplaceKernel, HelmholtzKernel, AxisTargetDerivative, DirectionalSourceDerivative) @@ -149,6 +150,20 @@ def wrangler_cls(self): # {{{ fmmlib expansion wrangler +def boxtree_fmm_level_to_order(fmm_level_to_order, helmholtz_k): + def inner_fmm_level_to_order(tree, level): + if helmholtz_k == 0: + return fmm_level_to_order( + LaplaceKernel(tree.dimensions), + frozenset(), tree, level) + else: + return fmm_level_to_order( + HelmholtzKernel(tree.dimensions), + frozenset([("k", helmholtz_k)]), tree, level) + + return inner_fmm_level_to_order + + class QBXFMMLibExpansionWrangler(FMMLibExpansionWrangler): def __init__(self, tree_indep, geo_data, dtype, qbx_order, fmm_level_to_order, @@ -157,9 +172,9 @@ def __init__(self, tree_indep, geo_data, dtype, _use_target_specific_qbx=None): # FMMLib is CPU-only. This wrapper gets the geometry out of # OpenCL-land. 
- - from pytential.qbx.utils import ToHostTransferredGeoDataWrapper - geo_data = ToHostTransferredGeoDataWrapper(geo_data) + if hasattr(geo_data, "_setup_actx"): + from pytential.qbx.utils import ToHostTransferredGeoDataWrapper + geo_data = ToHostTransferredGeoDataWrapper(geo_data) self.geo_data = geo_data self.qbx_order = qbx_order @@ -178,17 +193,8 @@ def __init__(self, tree_indep, geo_data, dtype, tree_indep.source_deriv_name]], order="F") - def inner_fmm_level_to_order(tree, level): - if helmholtz_k == 0: - return fmm_level_to_order( - LaplaceKernel(tree.dimensions), - frozenset(), tree, level) - else: - return fmm_level_to_order( - HelmholtzKernel(tree.dimensions), - frozenset([("k", helmholtz_k)]), tree, level) - - super().__init__( + FMMLibExpansionWrangler.__init__( + self, tree_indep, geo_data.traversal(), @@ -196,7 +202,8 @@ def inner_fmm_level_to_order(tree, level): dipole_vec=dipole_vec, dipoles_already_reordered=True, - fmm_level_to_order=inner_fmm_level_to_order, + fmm_level_to_order=boxtree_fmm_level_to_order( + fmm_level_to_order, helmholtz_k), rotation_data=geo_data) # {{{ data vector helpers @@ -222,6 +229,9 @@ def full_output_zeros(self, template_ary): np.zeros(self.tree.ntargets, self.tree_indep.dtype) for k in self.tree_indep.outputs]) + def eval_qbx_output_zeros(self, template_ary): + return self.full_output_zeros(template_ary) + def reorder_sources(self, source_array): if isinstance(source_array, cl.array.Array): source_array = source_array.get() @@ -547,7 +557,7 @@ def translate_box_local_to_qbx_local(self, local_exps): @log_process(logger) @return_timing_data def eval_qbx_expansions(self, qbx_expansions): - output = self.full_output_zeros(template_ary=qbx_expansions) + output = self.eval_qbx_output_zeros(template_ary=qbx_expansions) geo_data = self.geo_data ctt = geo_data.center_to_tree_targets() @@ -555,7 +565,7 @@ def eval_qbx_expansions(self, qbx_expansions): qbx_centers = geo_data.centers() qbx_radii = geo_data.expansion_radii() - 
all_targets = geo_data.all_targets() + all_targets = geo_data.eval_qbx_targets() taeval = self.tree_indep.get_expn_eval_routine("ta") @@ -583,8 +593,11 @@ def eval_qbx_expansions(self, qbx_expansions): @return_timing_data def eval_target_specific_qbx_locals(self, src_weight_vecs): src_weights, = src_weight_vecs + output = self.eval_qbx_output_zeros(template_ary=src_weights) + noutput_targets = len(output[0]) + if not self.tree_indep.using_tsqbx: - return self.full_output_zeros(template_ary=src_weights) + return output geo_data = self.geo_data trav = geo_data.traversal() @@ -600,9 +613,9 @@ def eval_target_specific_qbx_locals(self, src_weight_vecs): ifgrad = self.tree_indep.ifgrad # Create temporary output arrays for potential / gradient. - pot = np.zeros(self.tree.ntargets, np.complex128) if ifpot else None + pot = np.zeros(noutput_targets, np.complex128) if ifpot else None grad = ( - np.zeros((self.dim, self.tree.ntargets), np.complex128) + np.zeros((self.dim, noutput_targets), np.complex128) if ifgrad else None) ts.eval_target_specific_qbx_locals( @@ -612,7 +625,7 @@ def eval_target_specific_qbx_locals(self, src_weight_vecs): ifdipole=ifdipole, order=self.qbx_order, sources=self._get_single_sources_array(), - targets=geo_data.all_targets(), + targets=geo_data.eval_qbx_targets(), centers=self._get_single_centers_array(), qbx_centers=geo_data.global_qbx_centers(), qbx_center_to_target_box=geo_data.qbx_center_to_target_box(), @@ -629,15 +642,153 @@ def eval_target_specific_qbx_locals(self, src_weight_vecs): pot=pot, grad=grad) - output = self.full_output_zeros(template_ary=src_weights) self.add_potgrad_onto_output(output, slice(None), pot, grad) return output + def gather_non_qbx_potentials(self, non_qbx_potentials): + return non_qbx_potentials + + def gather_qbx_potentials(self, qbx_potentials): + return qbx_potentials + + def reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary): + from pytential.qbx.fmm import 
_reorder_and_finalize_potentials + return _reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary) + def finalize_potentials(self, potential, template_ary): potential = super().finalize_potentials(potential, template_ary) return cl.array.to_device(template_ary.queue, potential) # }}} + +class DistributedQBXFMMLibExpansionWrangler( + QBXFMMLibExpansionWrangler, DistributedFMMLibExpansionWrangler): + MPITags = { + "non_qbx_potentials": 0, + "qbx_potentials": 1 + } + + def __init__( + self, context, comm, tree_indep, local_geo_data, global_geo_data, dtype, + qbx_order, fmm_level_to_order, + source_extra_kwargs, + kernel_extra_kwargs, + _use_target_specific_qbx=None, + communicate_mpoles_via_allreduce=False): + self.global_geo_data = global_geo_data + + QBXFMMLibExpansionWrangler.__init__( + self, tree_indep, local_geo_data, dtype, qbx_order, fmm_level_to_order, + source_extra_kwargs, kernel_extra_kwargs, + _use_target_specific_qbx=_use_target_specific_qbx) + + # This is blatantly copied from QBXFMMLibExpansionWrangler, is it worthwhile + # to refactor this? 
+ if tree_indep.k_name is None: + helmholtz_k = 0 + else: + helmholtz_k = kernel_extra_kwargs[tree_indep.k_name] + + DistributedFMMLibExpansionWrangler.__init__( + self, context, comm, tree_indep, + local_geo_data.local_trav, global_geo_data.global_traversal, + boxtree_fmm_level_to_order(fmm_level_to_order, helmholtz_k), + communicate_mpoles_via_allreduce=communicate_mpoles_via_allreduce) + + def reorder_sources(self, source_array): + if self.comm.Get_rank() == 0: + return super().reorder_sources(source_array) + else: + return None + + def eval_qbx_output_zeros(self, template_ary): + from pytools.obj_array import make_obj_array + ctt = self.geo_data.center_to_tree_targets() + output = make_obj_array([np.zeros(len(ctt.lists), self.tree_indep.dtype) + for k in self.tree_indep.outputs]) + return output + + def full_output_zeros(self, template_ary): + """This includes QBX and non-QBX targets.""" + + from pytools.obj_array import make_obj_array + return make_obj_array([ + np.zeros(self.global_traversal.tree.ntargets, self.tree_indep.dtype) + for k in self.tree_indep.outputs]) + + def _gather_tgt_potentials(self, ntargets, potentials, mask, mpi_tag): + mpi_rank = self.comm.Get_rank() + mpi_size = self.comm.Get_size() + + if mpi_rank == 0: + from pytools.obj_array import make_obj_array + potentials_all_rank = make_obj_array([ + np.zeros(ntargets, self.tree_indep.dtype) + for k in self.tree_indep.outputs]) + + for irank in range(mpi_size): + if irank == 0: + potentials_cur_rank = potentials + else: + potentials_cur_rank = self.comm.recv(source=irank, tag=mpi_tag) + + for idim in range(len(self.tree_indep.outputs)): + potentials_all_rank[idim][mask[irank]] = \ + potentials_cur_rank[idim] + + return potentials_all_rank + else: + self.comm.send(potentials, dest=0, tag=mpi_tag) + return None + + def gather_non_qbx_potentials(self, non_qbx_potentials): + ntargets = 0 + if self.comm.Get_rank() == 0: + nqbtl = self.global_geo_data.non_qbx_box_target_lists + ntargets = 
nqbtl.nfiltered_targets + + return self._gather_tgt_potentials( + ntargets, non_qbx_potentials, + self.geo_data.particle_mask, self.MPITags["non_qbx_potentials"]) + + def gather_qbx_potentials(self, qbx_potentials): + ntargets = 0 + if self.comm.Get_rank() == 0: + ntargets = self.global_traversal.tree.ntargets + + return self._gather_tgt_potentials( + ntargets, qbx_potentials, + self.geo_data.qbx_target_mask, self.MPITags["qbx_potentials"]) + + def reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary): + mpi_rank = self.comm.Get_rank() + + if mpi_rank == 0: + all_potentials_in_tree_order = self.full_output_zeros(template_ary) + + nqbtl = self.global_geo_data.non_qbx_box_target_lists + + for ap_i, nqp_i in zip( + all_potentials_in_tree_order, non_qbx_potentials): + ap_i[nqbtl.unfiltered_from_filtered_target_indices] = nqp_i + + all_potentials_in_tree_order += qbx_potentials + + def _reorder_and_finalize_potentials(x): + # "finalize" gives host FMMs (like FMMlib) a chance to turn the + # potential back into a CL array. 
+ return self.finalize_potentials( + x[self.global_traversal.tree.sorted_target_ids], template_ary) + + from pytools.obj_array import with_object_array_or_scalar + return with_object_array_or_scalar( + _reorder_and_finalize_potentials, all_potentials_in_tree_order) + else: + return None + # vim: foldmethod=marker diff --git a/pytential/qbx/geometry.py b/pytential/qbx/geometry.py index 71e08be66..5c5bb304b 100644 --- a/pytential/qbx/geometry.py +++ b/pytential/qbx/geometry.py @@ -881,6 +881,9 @@ def m2l_rotation_angles(self): .build_rotation_classes_lists() .from_sep_siblings_rotation_class_to_angle) + def src_idx_all_ranks(self): + return None + # {{{ plotting (for debugging) def plot(self, draw_circles=False, draw_center_numbers=False, diff --git a/pytential/qbx/utils.py b/pytential/qbx/utils.py index 3a51a5e65..5ed223cdc 100644 --- a/pytential/qbx/utils.py +++ b/pytential/qbx/utils.py @@ -479,6 +479,9 @@ def all_targets(self): """All (not just non-QBX) targets packaged into a single array.""" return np.array(list(self.tree().targets)) + def eval_qbx_targets(self): + return self.all_targets() + def m2l_rotation_lists(self): # Already on host return self.geo_data.m2l_rotation_lists() @@ -487,6 +490,9 @@ def m2l_rotation_angles(self): # Already on host return self.geo_data.m2l_rotation_angles() + def src_idx_all_ranks(self): + return None + # }}} # vim: foldmethod=marker:filetype=pyopencl diff --git a/pytential/symbolic/execution.py b/pytential/symbolic/execution.py index 3efbee43c..b086d80eb 100644 --- a/pytential/symbolic/execution.py +++ b/pytential/symbolic/execution.py @@ -358,6 +358,88 @@ def exec_compute_potential_insn( return result + +class DistributedEvaluationMapper(EvaluationMapper): + def __init__(self, comm, bound_expr, actx, context=None, timing_data=None): + self.comm = comm + + if self.comm.Get_rank() == 0: + super().__init__(bound_expr, actx, context, timing_data) + else: + self.bound_expr = bound_expr + self.array_context = actx + self.context = 
context + self.places = None + self.timing_data = timing_data + + def exec_assign(self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate): + if self.comm.Get_rank() == 0: + return super().exec_assign(actx, insn, bound_expr, evaluate) + else: + return [(name, None) for name in insn.names] + + def exec_compute_potential_insn( + self, actx: PyOpenCLArrayContext, insn, bound_expr, evaluate): + from pytential.qbx.distributed import DistributedQBXLayerPotentialSource + + mpi_rank = self.comm.Get_rank() + use_target_specific_qbx = None + fmm_backend = None + qbx_order = None + fmm_level_to_order = None + expansion_factory = None + + if mpi_rank == 0: + source: DistributedQBXLayerPotentialSource = \ + bound_expr.places.get_geometry(insn.source.geometry) + if not isinstance(source, DistributedQBXLayerPotentialSource): + raise TypeError("Distributed execution mapper can only process" + "distributed layer potential source") + + use_target_specific_qbx = source._use_target_specific_qbx + fmm_backend = source.fmm_backend + qbx_order = source.qbx_order + fmm_level_to_order = source.fmm_level_to_order + expansion_factory = source.expansion_factory + + use_target_specific_qbx = self.comm.bcast( + use_target_specific_qbx, root=0) + fmm_backend = self.comm.bcast(fmm_backend, root=0) + qbx_order = self.comm.bcast(qbx_order, root=0) + fmm_level_to_order = self.comm.bcast(fmm_level_to_order, root=0) + expansion_factory = self.comm.bcast(expansion_factory, root=0) + + assert isinstance(fmm_backend, str) + + if mpi_rank != 0: + source = DistributedQBXLayerPotentialSource( + self.comm, + actx.context, + qbx_order=qbx_order, + fmm_level_to_order=fmm_level_to_order, + _use_target_specific_qbx=use_target_specific_qbx, + fmm_backend=fmm_backend, + expansion_factory=expansion_factory) + + return_timing_data = self.timing_data is not None + result, timing_data = ( + source.exec_compute_potential_insn( + actx, insn, bound_expr, evaluate, return_timing_data)) + + if return_timing_data: + # 
The compiler ensures this. + assert insn not in self.timing_data + + self.timing_data[insn] = timing_data + + return result + + def __call__(self, expr, *args, **kwargs): + if self.comm.Get_rank() == 0: + return super().__call__(expr, *args, **kwargs) + else: + return None + # }}} @@ -881,6 +963,55 @@ def __call__(self, *args, **kwargs): return self.eval(kwargs, array_context=array_context) +class DistributedBoundExpression(BoundExpression): + def __init__(self, comm, places, sym_op_expr): + self.comm = comm + self._code = None + self._geo_data_cache = {} + + if self.comm.Get_rank() == 0: + super().__init__(places, sym_op_expr) + self._code = super().code + + self._code = self.comm.bcast(self._code, root=0) + + @property + def code(self): + return self._code + + def cost_per_stage(self, calibration_params, **kwargs): + if self.comm.Get_rank() == 0: + return super().cost_per_stage(calibration_params, **kwargs) + else: + raise RuntimeError("Cost model is not available on worker ranks") + + def cost_per_box(self, calibration_params, **kwargs): + if self.comm.Get_rank() == 0: + return super().cost_per_box(calibration_params, **kwargs) + else: + raise RuntimeError("Cost model is not available on worker ranks") + + def scipy_op( + self, actx: PyOpenCLArrayContext, arg_name, dtype, + domains=None, **extra_args): + raise NotImplementedError + + def eval(self, context=None, timing_data=None, + array_context: Optional[PyOpenCLArrayContext] = None): + if context is None: + context = {} + + array_context = _find_array_context_from_args_in_context( + context, array_context) + + exec_mapper = DistributedEvaluationMapper( + self.comm, self, array_context, context, timing_data=timing_data) + return execute(self.code, exec_mapper) + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + def bind(places, expr, auto_where=None): """ :arg places: a :class:`pytential.collection.GeometryCollection`. 
@@ -905,6 +1036,28 @@ def bind(places, expr, auto_where=None):
     expr = _prepare_expr(places, expr, auto_where=auto_where)
     return BoundExpression(places, expr)
 
+
+def bind_distributed(comm, places, expr, auto_where=None):
+    """Distributed version of `bind`.
+
+    Overall, this function accepts the same arguments as the non-distributed
+    version on the root rank, with the addition of an MPI communicator. On
+    worker ranks, only the `comm` argument is significant.
+
+    :arg comm: MPI communicator.
+    :arg places: a :class:`pytential.collection.GeometryCollection`. Only
+        significant on the root rank. Worker ranks may simply pass `None`.
+    """
+    if comm.Get_rank() == 0:
+        from pytential import GeometryCollection
+        if not isinstance(places, GeometryCollection):
+            places = GeometryCollection(places, auto_where=auto_where)
+        auto_where = places.auto_where
+
+        expr = _prepare_expr(places, expr, auto_where=auto_where)
+
+    return DistributedBoundExpression(comm, places, expr)
+
 # }}}
 
 
diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py
index 9447573f4..3a1d6371f 100644
--- a/pytential/symbolic/primitives.py
+++ b/pytential/symbolic/primitives.py
@@ -1081,6 +1081,11 @@ def __init__(self, from_dd, to_dd, operand):
     def __getinitargs__(self):
         return (self.from_dd, self.to_dd, self.operand)
 
+    def __getnewargs__(self):
+        # Since this class defines `__new__`, `__getnewargs__` is needed to support
+        # unpickling.
+ return (self.from_dd, self.to_dd, self.operand) + mapper_method = intern("map_interpolation") diff --git a/requirements.txt b/requirements.txt index 88777f180..cbf4ef4cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ git+https://github.com/inducer/modepy.git#egg=modepy git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/loopy.git#egg=loopy -git+https://github.com/inducer/boxtree.git#egg=boxtree +git+https://github.com/gaohao95/boxtree.git@dist-pytential#egg=boxtree git+https://github.com/inducer/arraycontext.git#egg=arraycontext git+https://github.com/inducer/meshmode.git#egg=meshmode git+https://github.com/inducer/sumpy.git#egg=sumpy diff --git a/test/test_cost_model.py b/test/test_cost_model.py index 9ca455d54..1d6cf6624 100644 --- a/test/test_cost_model.py +++ b/test/test_cost_model.py @@ -638,6 +638,18 @@ def eval_target_specific_qbx_locals(self, src_weight_vecs): return pot, self.timing_future(ops) + def gather_non_qbx_potentials(self, non_qbx_potentials): + return non_qbx_potentials + + def gather_qbx_potentials(self, qbx_potentials): + return qbx_potentials + + def reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary): + from pytential.qbx.fmm import _reorder_and_finalize_potentials + return _reorder_and_finalize_potentials( + self, non_qbx_potentials, qbx_potentials, template_ary) + # }}} diff --git a/test/test_distributed.py b/test/test_distributed.py new file mode 100644 index 000000000..ba74dedad --- /dev/null +++ b/test/test_distributed.py @@ -0,0 +1,337 @@ +__copyright__ = "Copyright (C) 2024 Hao Gao" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import numpy as np +import numpy.linalg as la +import pyopencl as cl +import pyopencl.clmath +from pyopencl.tools import ( # noqa + pytest_generate_tests_for_pyopencl as pytest_generate_tests) +from arraycontext import flatten, unflatten +from meshmode.array_context import PyOpenCLArrayContext + +from meshmode.mesh.generation import make_curve_mesh, ellipse +from sumpy.visualization import FieldPlotter +from pytential import bind, sym, GeometryCollection +from boxtree.tools import run_mpi + +import pytest +from functools import partial +import sys +import os + +from mpi4py import MPI +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +import logging +logger = logging.getLogger(__name__) + + +# {{{ test off-surface eval + +def _test_off_surface_eval(ctx_factory, use_fmm, do_plot=False): + logging.basicConfig(level=logging.INFO) + + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + actx = PyOpenCLArrayContext(queue) + + # prevent cache 'splosion + from sympy.core.cache import clear_cache + clear_cache() + + places = None + op = None + sigma = None + + if rank == 0: + nelements = 30 + target_order = 8 + qbx_order = 3 + + if use_fmm: + fmm_order = qbx_order + else: + fmm_order = False + + 
mesh = make_curve_mesh(partial(ellipse, 3), + np.linspace(0, 1, nelements+1), + target_order) + + from pytential.qbx.distributed import DistributedQBXLayerPotentialSource + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + + pre_density_discr = Discretization( + actx, mesh, InterpolatoryQuadratureSimplexGroupFactory(target_order)) + layer_pot_source = DistributedQBXLayerPotentialSource( + comm, + cl_ctx, + pre_density_discr, + fine_order=4*target_order, + qbx_order=qbx_order, + fmm_order=fmm_order, + fmm_backend="fmmlib") + + from pytential.target import PointsTarget + fplot = FieldPlotter(np.zeros(2), extent=0.54, npoints=30) + targets = PointsTarget(fplot.points) + + places = GeometryCollection((layer_pot_source, targets)) + + from sumpy.kernel import LaplaceKernel + op = sym.D(LaplaceKernel(2), sym.var("sigma"), qbx_forced_limit=-2) + + sigma = layer_pot_source.density_discr.zeros(actx) + 1 + + from pytential.symbolic.execution import bind_distributed + bound_op = bind_distributed(comm, places, op) + fld_in_vol = bound_op.eval(context={"sigma": sigma}, array_context=actx) + + if rank == 0: + # test against shared memory result + from pytential.qbx import QBXLayerPotentialSource + qbx = QBXLayerPotentialSource( + pre_density_discr, + 4 * target_order, + qbx_order, + fmm_order=fmm_order, + _from_sep_smaller_min_nsources_cumul=0 + ) + + places = GeometryCollection((qbx, targets)) + fld_in_vol_single_node = bind(places, op)(actx, sigma=sigma) + + linf_err = ( + cl.array.max(cl.clmath.fabs(fld_in_vol - fld_in_vol_single_node)) + / cl.array.max(cl.clmath.fabs(fld_in_vol_single_node))) + + print("l_inf error:", linf_err) + assert linf_err < 1e-13 + + +@pytest.mark.mpi +@pytest.mark.parametrize("num_processes, use_fmm", [ + (4, True), +]) +@pytest.mark.skipif(sys.version_info < (3, 5), + reason="distributed implementation requires 3.5 or higher") +def 
test_off_surface_eval( + num_processes, use_fmm, do_plot=False): + pytest.importorskip("mpi4py") + + newenv = os.environ.copy() + newenv["PYTEST"] = "1" + newenv["OMP_NUM_THREADS"] = "1" + newenv["POCL_MAX_PTHREAD_COUNT"] = "1" + newenv["use_fmm"] = str(use_fmm) + newenv["do_plot"] = str(do_plot) + + run_mpi(__file__, num_processes, newenv) + +# }}} + + +# {{{ compare on-surface urchin geometry against single-rank result + +def single_layer_wrapper(kernel): + u_sym = sym.var("u") + return sym.S(kernel, u_sym, qbx_forced_limit=-1) + + +def double_layer_wrapper(kernel): + u_sym = sym.var("u") + return sym.D(kernel, u_sym, qbx_forced_limit="avg") + + +def _test_urchin_against_single_rank(ctx_factory, m, n, op_wrapper, use_tsqbx): + logging.basicConfig(level=logging.INFO) + + qbx_order = 3 + fmm_order = 10 + target_order = 8 + est_rel_interp_tolerance = 1e-10 + _expansion_stick_out_factor = 0.5 + + ctx = ctx_factory() + queue = cl.CommandQueue(ctx) + actx = PyOpenCLArrayContext(queue) + + # prevent cache 'splosion + from sympy.core.cache import clear_cache + clear_cache() + + if rank == 0: + from meshmode.mesh.generation import generate_urchin + mesh = generate_urchin(target_order, m, n, est_rel_interp_tolerance) + d = mesh.ambient_dim + + from sumpy.kernel import LaplaceKernel + k_sym = LaplaceKernel(d) + op = op_wrapper(k_sym) + + from meshmode.discretization import Discretization + from meshmode.discretization.poly_element import \ + InterpolatoryQuadratureSimplexGroupFactory + + pre_density_discr = Discretization( + actx, mesh, + InterpolatoryQuadratureSimplexGroupFactory(target_order)) + + params = { + "qbx_order": qbx_order, + "fmm_order": fmm_order, + "fmm_backend": "fmmlib", + "_from_sep_smaller_min_nsources_cumul": 0, + "_expansions_in_tree_have_extent": True, + "_expansion_stick_out_factor": _expansion_stick_out_factor, + "_use_target_specific_qbx": use_tsqbx + } + + from pytential.qbx.distributed import DistributedQBXLayerPotentialSource + qbx = 
DistributedQBXLayerPotentialSource( + comm, + ctx, + density_discr=pre_density_discr, + fine_order=4 * target_order, + # knl_specific_calibration_params="constant_one", + **params) + + places = GeometryCollection(qbx) + density_discr = places.get_discretization(places.auto_source.geometry) + + # {{{ compute values of a solution to the PDE + + nodes_host = actx.to_numpy( + flatten(density_discr.nodes(), actx)).reshape(d, -1) + + center = np.array([3, 1, 2])[:d] + diff = nodes_host - center[:, np.newaxis] + dist_squared = np.sum(diff ** 2, axis=0) + dist = np.sqrt(dist_squared) + if d == 2: + u = np.log(dist) + grad_u = diff / dist_squared + elif d == 3: + u = 1 / dist + grad_u = -diff / dist ** 3 + else: + raise RuntimeError("Unsupported dimension") + + # }}} + + u_dev = unflatten( + actx.thaw(density_discr.nodes()[0]), + actx.from_numpy(u), + actx, strict=False) + grad_u_dev = unflatten( + density_discr.nodes(), + actx.from_numpy(grad_u.ravel()), actx, strict=False) + + context = {"u": u_dev, "grad_u": grad_u_dev} + else: + places = None + op = None + context = {"u": None, "grad_u": None} + + from pytential.symbolic.execution import bind_distributed + bound_op = bind_distributed(comm, places, op) + distributed_result = bound_op.eval(context=context, array_context=actx) + + if rank == 0: + from pytential.qbx import QBXLayerPotentialSource + qbx = QBXLayerPotentialSource( + density_discr=pre_density_discr, + fine_order=4 * target_order, + **params) + places = GeometryCollection(qbx) + + context = {"u": u_dev, "grad_u": grad_u_dev} + single_node_result = bind(places, op)(actx, **context) + + distributed_result = actx.to_numpy(flatten(distributed_result, actx)) + single_node_result = actx.to_numpy(flatten(single_node_result, actx)) + + linf_err = la.norm(distributed_result - single_node_result, ord=np.inf) + print("l_inf error:", linf_err) + assert linf_err < 1e-13 + + +@pytest.mark.mpi +@pytest.mark.parametrize("num_processes, m, n, op_wrapper, use_tsqbx", [ + (4, 1, 
3, "single_layer_wrapper", True), + (4, 1, 3, "single_layer_wrapper", False), + (4, 1, 3, "double_layer_wrapper", True), + (4, 1, 3, "double_layer_wrapper", False), +]) +@pytest.mark.skipif(sys.version_info < (3, 5), + reason="distributed implementation requires 3.5 or higher") +def test_urchin_against_single_rank( + num_processes, m, n, op_wrapper, use_tsqbx): + pytest.importorskip("mpi4py") + + newenv = os.environ.copy() + newenv["PYTEST"] = "2" + newenv["OMP_NUM_THREADS"] = "1" + newenv["POCL_MAX_PTHREAD_COUNT"] = "1" + newenv["m"] = str(m) + newenv["n"] = str(n) + newenv["op_wrapper"] = op_wrapper + newenv["use_tsqbx"] = str(use_tsqbx) + + run_mpi(__file__, num_processes, newenv) + +# }}} + + +if __name__ == "__main__": + if "PYTEST" in os.environ: + if os.environ["PYTEST"] == "1": + # Run "test_off_surface_eval" test case + use_fmm = (os.environ["use_fmm"] == "True") + do_plot = (os.environ["do_plot"] == "True") + + _test_off_surface_eval(cl.create_some_context, use_fmm, do_plot=do_plot) + elif os.environ["PYTEST"] == "2": + # Run "test_urchin_against_single_rank" test case + m = int(os.environ["m"]) + n = int(os.environ["n"]) + op_wrapper_str = os.environ["op_wrapper"] + use_tsqbx = (os.environ["use_tsqbx"] == "True") + + if op_wrapper_str == "single_layer_wrapper": + op_wrapper = single_layer_wrapper + elif op_wrapper_str == "double_layer_wrapper": + op_wrapper = double_layer_wrapper + else: + raise ValueError("unknown op wrapper") + + _test_urchin_against_single_rank( + cl.create_some_context, m, n, op_wrapper, use_tsqbx) + else: + if len(sys.argv) > 1: + # You can test individual routines by typing + # $ python test_distributed.py 'test_off_surface_eval(4, True, True)' + exec(sys.argv[1])