From 75b47417d44164b423f6c7e959d3b6cff115c642 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Wed, 11 Dec 2024 10:54:30 -0800 Subject: [PATCH 01/18] Added an MPI all_reduce for quantities based on SUM operation to communicator.py --- ndsl/comm/communicator.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index ff270df5..e6cdc3b3 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -2,6 +2,7 @@ from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np +from mpi4py import MPI import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer @@ -93,6 +94,25 @@ def _device_synchronize(): # this is a method so we can profile it separately from other device syncs device_synchronize() + def _create_all_reduce_quantity( + self, input_metadata: QuantityMetadata, input_data + ) -> Quantity: + """Create a Quantity for all_reduce data and metadata""" + all_reduce_quantity = Quantity( + input_data, + dims=input_metadata.dims, + units=input_metadata.units, + origin=tuple([0 for dim in input_metadata.dims]), + gt4py_backend=input_metadata.gt4py_backend, + allow_mismatch_float_precision=True, + ) + return all_reduce_quantity + + def all_reduce_sum(self, quantity: Quantity): + reduced_quantity_data = self.comm.allreduce(quantity.data,MPI.SUM) + all_reduce_quantity = self._create_all_reduce_quantity(quantity.metadata, reduced_quantity_data) + return all_reduce_quantity + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( numpy_module.zeros, recvbuf From 4c8632c9a5b1aba1c391d6d878cc0d62f72d8027 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Wed, 11 Dec 2024 13:41:55 -0800 Subject: [PATCH 02/18] linted --- ndsl/comm/communicator.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index e6cdc3b3..862844b3 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -2,7 +2,7 @@ from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np -from mpi4py import MPI +from mpi4py import MPI import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer @@ -109,10 +109,12 @@ def _create_all_reduce_quantity( return all_reduce_quantity def all_reduce_sum(self, quantity: Quantity): - reduced_quantity_data = self.comm.allreduce(quantity.data,MPI.SUM) - all_reduce_quantity = self._create_all_reduce_quantity(quantity.metadata, reduced_quantity_data) + reduced_quantity_data = self.comm.allreduce(quantity.data, MPI.SUM) + all_reduce_quantity = self._create_all_reduce_quantity( + quantity.metadata, reduced_quantity_data + ) return all_reduce_quantity - + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( numpy_module.zeros, recvbuf From a2fac9f0df32dda33eb661335e230780ab661a46 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Thu, 12 Dec 2024 21:41:16 -0800 Subject: [PATCH 03/18] Add initial skeleton of pytest test for all reduce --- tests/mpi/test_mpi_all_reduce_sum.py | 49 ++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/mpi/test_mpi_all_reduce_sum.py diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py new file mode 100644 index 00000000..25effc72 --- /dev/null +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -0,0 +1,49 @@ +import pytest + +from tests.mpi.mpi_comm import MPI +# from ndsl.typing import Communicator + +from ndsl import ( + CubedSphereCommunicator, + CubedSpherePartitioner, + Quantity, + TilePartitioner, +) + +from ndsl.comm.partitioner import Partitioner + +@pytest.fixture +def layout(): + if MPI is not None: + size = MPI.COMM_WORLD.Get_size() + ranks_per_tile = size // 6 + ranks_per_edge = int(ranks_per_tile ** 0.5) + return (ranks_per_edge, ranks_per_edge) + else: + return (1, 1) + +@pytest.fixture(params=[0.1, 1.0]) +def edge_interior_ratio(request): + return request.param + +@pytest.fixture +def tile_partitioner(layout, edge_interior_ratio: float): + return TilePartitioner(layout, edge_interior_ratio=edge_interior_ratio) + +@pytest.fixture +def cube_partitioner(tile_partitioner): + return CubedSpherePartitioner(tile_partitioner) + +@pytest.fixture() +def communicator(cube_partitioner): + return CubedSphereCommunicator( + comm=MPI.COMM_WORLD, + partitioner=cube_partitioner, + ) + +def test_all_reduce_sum( + communicator, +): + print("Communicator rank = ", communicator.rank) + print("Communicator size = ", communicator.size) + assert True \ No newline at end of file From 8c5b5d5bd4979cac763866989f3760341df5b171 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Fri, 13 Dec 2024 09:43:34 -0800 Subject: [PATCH 04/18] Added assertion tests for 1, 2 and 3D quantities passed through mpi_allreduce_sum --- tests/mpi/test_mpi_all_reduce_sum.py | 55 +++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 25effc72..f03787ee 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -1,7 +1,6 @@ import pytest - +import numpy as np from tests.mpi.mpi_comm import MPI -# from ndsl.typing import Communicator from ndsl import ( CubedSphereCommunicator, @@ -10,8 +9,11 @@ TilePartitioner, ) +from ndsl.quantity import Quantity from ndsl.comm.partitioner import Partitioner +from ndsl.dsl.typing import Float + @pytest.fixture def layout(): if MPI is not None: @@ -44,6 +46,49 @@ def communicator(cube_partitioner): def test_all_reduce_sum( communicator, ): - print("Communicator rank = ", communicator.rank) - print("Communicator size = ", communicator.size) - assert True \ No newline at end of file + + backend = "numpy" + base_array = np.array([i for i in range(5)], dtype=Float) + + testQuantity_1D = Quantity( + data=base_array, + dims=["K"], + units="Some 1D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5*5)], dtype=Float) + base_array = base_array.reshape(5,5) + + testQuantity_2D = Quantity( + data=base_array, + dims=["I","J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5*5*5)], dtype=Float) + base_array = base_array.reshape(5,5,5) + + testQuantity_3D = Quantity( + data=base_array, + dims=["I","J","K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + + # print("Communicator rank = ", communicator.rank) + # print("Communicator size = ", communicator.size) + # print("nsize = ", nsize) + + global_sum_q = communicator.all_reduce_sum(testQuantity_1D) + assert global_sum_q.metadata == testQuantity_1D.metadata + assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce_sum(testQuantity_2D) + assert global_sum_q.metadata == testQuantity_2D.metadata + assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce_sum(testQuantity_3D) + assert global_sum_q.metadata == testQuantity_3D.metadata + assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() \ No newline at end of file From fb4e74010615f15f962a8a0572ed48f1267b5581 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Fri, 13 Dec 2024 09:48:09 -0800 Subject: [PATCH 05/18] Linted --- tests/mpi/test_mpi_all_reduce_sum.py | 64 ++++++++++++++-------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index f03787ee..728ec4f8 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -1,6 +1,5 @@ -import pytest import numpy as np -from tests.mpi.mpi_comm import MPI +import pytest from ndsl import ( CubedSphereCommunicator, @@ -8,11 +7,9 @@ Quantity, TilePartitioner, ) - -from ndsl.quantity import Quantity -from ndsl.comm.partitioner import Partitioner - from ndsl.dsl.typing import Float +from tests.mpi.mpi_comm import MPI + @pytest.fixture def layout(): @@ -24,18 +21,22 @@ def layout(): else: return (1, 1) + @pytest.fixture(params=[0.1, 1.0]) def edge_interior_ratio(request): return request.param + @pytest.fixture def tile_partitioner(layout, edge_interior_ratio: float): return TilePartitioner(layout, edge_interior_ratio=edge_interior_ratio) + @pytest.fixture def cube_partitioner(tile_partitioner): return CubedSpherePartitioner(tile_partitioner) + @pytest.fixture() def communicator(cube_partitioner): return CubedSphereCommunicator( @@ -43,43 +44,40 @@ def communicator(cube_partitioner): partitioner=cube_partitioner, ) + def test_all_reduce_sum( - communicator, + communicator, ): - + backend = "numpy" base_array = np.array([i for i in range(5)], dtype=Float) testQuantity_1D = Quantity( - data=base_array, - dims=["K"], - units="Some 1D unit", - gt4py_backend=backend, - ) - - base_array = np.array([i for i in range(5*5)], dtype=Float) - base_array = base_array.reshape(5,5) + data=base_array, + dims=["K"], + units="Some 1D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) testQuantity_2D = Quantity( - data=base_array, - dims=["I","J"], - units="Some 2D unit", - gt4py_backend=backend, - ) + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) - base_array = np.array([i for i in range(5*5*5)], dtype=Float) - base_array = base_array.reshape(5,5,5) + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) testQuantity_3D = Quantity( - data=base_array, - dims=["I","J","K"], - units="Some 3D unit", - gt4py_backend=backend, - ) - - # print("Communicator rank = ", communicator.rank) - # print("Communicator size = ", communicator.size) - # print("nsize = ", nsize) + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) global_sum_q = communicator.all_reduce_sum(testQuantity_1D) assert global_sum_q.metadata == testQuantity_1D.metadata @@ -91,4 +89,4 @@ def test_all_reduce_sum( global_sum_q = communicator.all_reduce_sum(testQuantity_3D) assert global_sum_q.metadata == testQuantity_3D.metadata - assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() \ No newline at end of file + assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() From 34f82fb5ac6275f6a3de1b71794cda25d7fc3495 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Fri, 13 Dec 2024 10:55:18 -0800 Subject: [PATCH 06/18] Added pytest.mark to skip test if mpi4py isn't available --- tests/mpi/test_mpi_all_reduce_sum.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 728ec4f8..caddfa5f 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -45,6 +45,9 @@ def communicator(cube_partitioner): ) +@pytest.mark.skipif( + MPI is None, reason="mpi4py is not available or pytest was not run in parallel" +) def test_all_reduce_sum( communicator, ): From b4a6a5421149b43bf7c64ea2ca4bf86fe9ed2724 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Mon, 16 Dec 2024 08:45:42 -0800 Subject: [PATCH 07/18] lint changes --- ndsl/comm/communicator.py | 7 ++- tests/mpi/test_mpi_all_reduce_sum.py | 84 ++++++++++++++-------------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 862844b3..5f19b2eb 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -102,9 +102,10 @@ def _create_all_reduce_quantity( input_data, dims=input_metadata.dims, units=input_metadata.units, - origin=tuple([0 for dim in input_metadata.dims]), + origin=input_metadata.origin, + extent=input_metadata.extent, gt4py_backend=input_metadata.gt4py_backend, - allow_mismatch_float_precision=True, + allow_mismatch_float_precision=False, ) return all_reduce_quantity @@ -114,6 +115,8 @@ def all_reduce_sum(self, quantity: Quantity): quantity.metadata, reduced_quantity_data ) return all_reduce_quantity + # quantity.data = reduced_quantity_data + # return quantity def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index caddfa5f..9ba01e0a 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -52,44 +52,46 @@ def test_all_reduce_sum( communicator, ): - backend = "numpy" - base_array = np.array([i for i in range(5)], dtype=Float) - - testQuantity_1D = Quantity( - data=base_array, - dims=["K"], - units="Some 1D unit", - gt4py_backend=backend, - ) - - base_array = np.array([i for i in range(5 * 5)], dtype=Float) - base_array = base_array.reshape(5, 5) - - testQuantity_2D = Quantity( - data=base_array, - dims=["I", "J"], - units="Some 2D unit", - gt4py_backend=backend, - ) - - base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) - base_array = base_array.reshape(5, 5, 5) - - testQuantity_3D = Quantity( - data=base_array, - dims=["I", "J", "K"], - units="Some 3D unit", - gt4py_backend=backend, - ) - - global_sum_q = communicator.all_reduce_sum(testQuantity_1D) - assert global_sum_q.metadata == testQuantity_1D.metadata - assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() - - global_sum_q = communicator.all_reduce_sum(testQuantity_2D) - assert global_sum_q.metadata == testQuantity_2D.metadata - assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() - - global_sum_q = communicator.all_reduce_sum(testQuantity_3D) - assert global_sum_q.metadata == testQuantity_3D.metadata - assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() + backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] + + for backend in backends: + base_array = np.array([i for i in range(5)], dtype=Float) + + testQuantity_1D = Quantity( + data=base_array, + dims=["K"], + units="Some 1D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) + + testQuantity_2D = Quantity( + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) + + testQuantity_3D = Quantity( + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + + global_sum_q = communicator.all_reduce_sum(testQuantity_1D) + assert global_sum_q.metadata == testQuantity_1D.metadata + assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce_sum(testQuantity_2D) + assert global_sum_q.metadata == testQuantity_2D.metadata + assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce_sum(testQuantity_3D) + assert global_sum_q.metadata == testQuantity_3D.metadata + assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() From f5ce8831a7918ed45be3fb89f962c558ccdc7035 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Mon, 16 Dec 2024 08:45:42 -0800 Subject: [PATCH 08/18] Addressed PR comments and added additional CPU backends to unit test --- ndsl/comm/communicator.py | 7 ++- tests/mpi/test_mpi_all_reduce_sum.py | 84 ++++++++++++++-------------- 2 files changed, 48 insertions(+), 43 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 862844b3..5f19b2eb 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -102,9 +102,10 @@ def _create_all_reduce_quantity( input_data, dims=input_metadata.dims, units=input_metadata.units, - origin=tuple([0 for dim in input_metadata.dims]), + origin=input_metadata.origin, + extent=input_metadata.extent, gt4py_backend=input_metadata.gt4py_backend, - allow_mismatch_float_precision=True, + allow_mismatch_float_precision=False, ) return all_reduce_quantity @@ -114,6 +115,8 @@ def all_reduce_sum(self, quantity: Quantity): quantity.metadata, reduced_quantity_data ) return all_reduce_quantity + # quantity.data = reduced_quantity_data + # return quantity def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index caddfa5f..9ba01e0a 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -52,44 +52,46 @@ def test_all_reduce_sum( communicator, ): - backend = "numpy" - base_array = np.array([i for i in range(5)], dtype=Float) - - testQuantity_1D = Quantity( - data=base_array, - dims=["K"], - units="Some 1D unit", - gt4py_backend=backend, - ) - - base_array = np.array([i for i in range(5 * 5)], dtype=Float) - base_array = base_array.reshape(5, 5) - - testQuantity_2D = Quantity( - data=base_array, - dims=["I", "J"], - units="Some 2D unit", - gt4py_backend=backend, - ) - - base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) - base_array = base_array.reshape(5, 5, 5) - - testQuantity_3D = Quantity( - data=base_array, - dims=["I", "J", "K"], - units="Some 3D unit", - gt4py_backend=backend, - ) - - global_sum_q = communicator.all_reduce_sum(testQuantity_1D) - assert global_sum_q.metadata == testQuantity_1D.metadata - assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() - - global_sum_q = communicator.all_reduce_sum(testQuantity_2D) - assert global_sum_q.metadata == testQuantity_2D.metadata - assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() - - global_sum_q = communicator.all_reduce_sum(testQuantity_3D) - assert global_sum_q.metadata == testQuantity_3D.metadata - assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() + backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] + + for backend in backends: + base_array = np.array([i for i in range(5)], dtype=Float) + + testQuantity_1D = Quantity( + data=base_array, + dims=["K"], + units="Some 1D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) + + testQuantity_2D = Quantity( + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) + + testQuantity_3D = Quantity( + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + + global_sum_q = communicator.all_reduce_sum(testQuantity_1D) + assert global_sum_q.metadata == testQuantity_1D.metadata + assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce_sum(testQuantity_2D) + assert global_sum_q.metadata == testQuantity_2D.metadata + assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() + + global_sum_q = communicator.all_reduce_sum(testQuantity_3D) + assert global_sum_q.metadata == testQuantity_3D.metadata + assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() From 2e669dbae2fccce6c65dac33db9acf6a5ec564ac Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Wed, 18 Dec 2024 11:36:07 -0800 Subject: [PATCH 09/18] Added setters for various Quantity properties to enable setting of Quantity metadata and data properties. --- ndsl/comm/communicator.py | 30 +++++++++++++----- ndsl/quantity.py | 30 ++++++++++++++++++ tests/mpi/test_mpi_all_reduce_sum.py | 47 ++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 8 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 5f19b2eb..55efa961 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -109,14 +109,28 @@ def _create_all_reduce_quantity( ) return all_reduce_quantity - def all_reduce_sum(self, quantity: Quantity): - reduced_quantity_data = self.comm.allreduce(quantity.data, MPI.SUM) - all_reduce_quantity = self._create_all_reduce_quantity( - quantity.metadata, reduced_quantity_data - ) - return all_reduce_quantity - # quantity.data = reduced_quantity_data - # return quantity + def all_reduce_sum( + self, input_quantity: Quantity, output_quantity: Quantity = None + ): + reduced_quantity_data = self.comm.allreduce(input_quantity.data, MPI.SUM) + if output_quantity is None: + all_reduce_quantity = self._create_all_reduce_quantity( + input_quantity.metadata, reduced_quantity_data + ) + return all_reduce_quantity + else: + if output_quantity.data.shape != input_quantity.data.shape: + raise TypeError("Shapes not matching") + + output_quantity.metadata.dims = input_quantity.metadata.dims + output_quantity.metadata.units = input_quantity.metadata.units + output_quantity.metadata.origin = input_quantity.metadata.origin + output_quantity.metadata.extent = input_quantity.metadata.extent + output_quantity.metadata.gt4py_backend = ( + input_quantity.metadata.gt4py_backend + ) + + output_quantity.data = reduced_quantity_data def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( diff --git a/ndsl/quantity.py b/ndsl/quantity.py index b95a9aad..80bb4d06 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity.py @@ -458,10 +458,20 @@ def units(self) -> str: """units of the quantity""" return self.metadata.units + @units.setter + def units(self, newUnits): + if type(newUnits) is str: + self.metadata.units = newUnits + @property def gt4py_backend(self) -> Union[str, None]: return self.metadata.gt4py_backend + @gt4py_backend.setter + def gt4py_backend(self, newBackend): + if type(newBackend) is Union[str, None]: + self.metadata.gt4py_backend = newBackend + @property def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @@ -471,6 +481,11 @@ def dims(self) -> Tuple[str, ...]: """names of each dimension""" return self.metadata.dims + @dims.setter + def dims(self, newDims): + if type(newDims) is Tuple: + self.metadata.dims = newDims + @property def values(self) -> np.ndarray: warnings.warn( @@ -492,16 +507,31 @@ def data(self) -> Union[np.ndarray, cupy.ndarray]: """the underlying array of data""" return self._data + @data.setter + def data(self, inputData): + if type(inputData) in [np.ndarray, cupy.ndarray]: + self._data = inputData + @property def origin(self) -> Tuple[int, ...]: """the start of the computational domain""" return self.metadata.origin + @origin.setter + def origin(self, newOrigin): + if type(newOrigin) is Tuple: + self.metadata.origin = newOrigin + @property def extent(self) -> Tuple[int, ...]: """the shape of the computational domain""" return self.metadata.extent + @extent.setter + def extent(self, newExtent): + if type(newExtent) is Tuple: + self.metadata.extent = newExtent + @property def data_array(self) -> xr.DataArray: return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 9ba01e0a..9c2b3a3a 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -95,3 +95,50 @@ def test_all_reduce_sum( global_sum_q = communicator.all_reduce_sum(testQuantity_3D) assert global_sum_q.metadata == testQuantity_3D.metadata assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() + + base_array = np.array([i for i in range(5)], dtype=Float) + testQuantity_1D_out = Quantity( + data=base_array, + dims=["K"], + units="New 1D unit", + gt4py_backend=backend, + origin=(8,), + extent=(7,), + ) + + base_array = np.array([i for i in range(5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5) + + testQuantity_2D_out = Quantity( + data=base_array, + dims=["I", "J"], + units="Some 2D unit", + gt4py_backend=backend, + ) + + base_array = np.array([i for i in range(5 * 5 * 5)], dtype=Float) + base_array = base_array.reshape(5, 5, 5) + + testQuantity_3D_out = Quantity( + data=base_array, + dims=["I", "J", "K"], + units="Some 3D unit", + gt4py_backend=backend, + ) + communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out) + assert testQuantity_1D_out.metadata == testQuantity_1D_out.metadata + assert ( + testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) + ).all() + + communicator.all_reduce_sum(testQuantity_2D, testQuantity_2D_out) + assert testQuantity_2D_out.metadata == testQuantity_2D.metadata + assert ( + testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size) + ).all() + + communicator.all_reduce_sum(testQuantity_3D, testQuantity_3D_out) + assert testQuantity_3D_out.metadata == testQuantity_3D.metadata + assert ( + testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size) + ).all() From fd2fa97beb979b7bbd25eeca273b269e742f3de6 Mon Sep 17 00:00:00 2001 From: Christopher Kung Date: Thu, 19 Dec 2024 08:42:56 -0800 Subject: [PATCH 10/18] Added function in QuantityMetadata class that allows copying of Metadata properties from one class to another. Subsequent Quantity setters that performed the copying of QuantityMetadata properties were removed --- ndsl/comm/communicator.py | 8 +------ ndsl/quantity.py | 34 ++++++++-------------------- tests/mpi/test_mpi_all_reduce_sum.py | 2 +- 3 files changed, 11 insertions(+), 33 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 55efa961..ead24896 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -122,13 +122,7 @@ def all_reduce_sum( if output_quantity.data.shape != input_quantity.data.shape: raise TypeError("Shapes not matching") - output_quantity.metadata.dims = input_quantity.metadata.dims - output_quantity.metadata.units = input_quantity.metadata.units - output_quantity.metadata.origin = input_quantity.metadata.origin - output_quantity.metadata.extent = input_quantity.metadata.extent - output_quantity.metadata.gt4py_backend = ( - input_quantity.metadata.gt4py_backend - ) + input_quantity.metadata.duplicate_metadata(output_quantity.metadata) output_quantity.data = reduced_quantity_data diff --git a/ndsl/quantity.py b/ndsl/quantity.py index 80bb4d06..a38a7a5d 100644 --- a/ndsl/quantity.py +++ b/ndsl/quantity.py @@ -53,6 +53,15 @@ def np(self) -> NumpyModule: f"quantity underlying data is of unexpected type {self.data_type}" ) + def duplicate_metadata(self, metadata_copy): + metadata_copy.origin = self.origin + metadata_copy.extent = self.extent + metadata_copy.dims = self.dims + metadata_copy.units = self.units + metadata_copy.data_type = self.data_type + metadata_copy.dtype = self.dtype + metadata_copy.gt4py_backend = self.gt4py_backend + @dataclasses.dataclass class QuantityHaloSpec: @@ -458,20 +467,10 @@ def units(self) -> str: """units of the quantity""" return self.metadata.units - @units.setter - def units(self, newUnits): - if type(newUnits) is str: - self.metadata.units = newUnits - @property def gt4py_backend(self) -> Union[str, None]: return self.metadata.gt4py_backend - @gt4py_backend.setter - def gt4py_backend(self, newBackend): - if type(newBackend) is Union[str, None]: - self.metadata.gt4py_backend = newBackend - @property def attrs(self) -> dict: return dict(**self._attrs, units=self._metadata.units) @@ -481,11 +480,6 @@ def dims(self) -> Tuple[str, ...]: """names of each dimension""" return self.metadata.dims - @dims.setter - def dims(self, newDims): - if type(newDims) is Tuple: - self.metadata.dims = newDims - @property def values(self) -> np.ndarray: warnings.warn( @@ -517,21 +511,11 @@ def origin(self) -> Tuple[int, ...]: """the start of the computational domain""" return self.metadata.origin - @origin.setter - def origin(self, newOrigin): - if type(newOrigin) is Tuple: - self.metadata.origin = newOrigin - @property def extent(self) -> Tuple[int, ...]: """the shape of the computational domain""" return self.metadata.extent - @extent.setter - def extent(self, newExtent): - if type(newExtent) is Tuple: - self.metadata.extent = newExtent - @property def data_array(self) -> xr.DataArray: return xr.DataArray(self.view[:], dims=self.dims, attrs=self.attrs) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 9c2b3a3a..858a7f94 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -126,7 +126,7 @@ def test_all_reduce_sum( gt4py_backend=backend, ) communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out) - assert testQuantity_1D_out.metadata == testQuantity_1D_out.metadata + assert testQuantity_1D_out.metadata == testQuantity_1D.metadata assert ( testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) ).all() From cc620c6283dcd35dcf6829a37d411009e6a011c6 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 10:42:42 -0500 Subject: [PATCH 11/18] Add `Allreduce` and all MPI OP --- ndsl/comm/caching_comm.py | 12 +++++++++--- ndsl/comm/comm_abc.py | 26 +++++++++++++++++++++++++- ndsl/comm/communicator.py | 19 +++++++++++++++---- ndsl/comm/mpi.py | 31 ++++++++++++++++++++++++++++--- ndsl/comm/null_comm.py | 10 +++++++--- 5 files changed, 84 insertions(+), 14 deletions(-) diff --git a/ndsl/comm/caching_comm.py b/ndsl/comm/caching_comm.py index 36587d73..42f92ea2 100644 --- a/ndsl/comm/caching_comm.py +++ b/ndsl/comm/caching_comm.py @@ -5,7 +5,7 @@ import numpy as np -from ndsl.comm.comm_abc import Comm, Request +from ndsl.comm.comm_abc import Comm, ReductionOperator, Request T = TypeVar("T") @@ -147,9 +147,12 @@ def Split(self, color, key) -> "CachingCommReader": new_data = self._data.get_split() return CachingCommReader(data=new_data) - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: return self._data.get_generic_obj() + def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + raise NotImplementedError("CachingCommReader.Allreduce") + @classmethod def load(cls, file: BinaryIO) -> "CachingCommReader": data = CachingCommData.load(file) @@ -229,7 +232,10 @@ def Split(self, color, key) -> "CachingCommWriter": def dump(self, file: BinaryIO): self._data.dump(file) - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: result = self._comm.allreduce(sendobj, op) self._data.generic_obj_buffers.append(copy.deepcopy(result)) return result + + def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + raise NotImplementedError("CachingCommWriter.Allreduce") diff --git a/ndsl/comm/comm_abc.py b/ndsl/comm/comm_abc.py index 77f56586..8192560b 100644 --- a/ndsl/comm/comm_abc.py +++ b/ndsl/comm/comm_abc.py @@ -1,10 +1,30 @@ import abc +import enum from typing import List, Optional, TypeVar T = TypeVar("T") +@enum.unique +class ReductionOperator(enum.Enum): + OP_NULL = enum.auto() + MAX = enum.auto() + MIN = enum.auto() + SUM = enum.auto() + PROD = enum.auto() + LAND = enum.auto() + BAND = enum.auto() + LOR = enum.auto() + BOR = enum.auto() + LXOR = enum.auto() + BXOR = enum.auto() + MAXLOC = enum.auto() + MINLOC = enum.auto() + REPLACE = enum.auto() + NO_OP = enum.auto() + + class Request(abc.ABC): @abc.abstractmethod def wait(self): @@ -69,5 +89,9 @@ def Split(self, color, key) -> "Comm": ... @abc.abstractmethod - def allreduce(self, sendobj: T, op=None) -> T: + def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: + ... + + @abc.abstractmethod + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ... diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index ead24896..ed1f264d 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -2,11 +2,11 @@ from typing import List, Mapping, Optional, Sequence, Tuple, Union, cast import numpy as np -from mpi4py import MPI import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer from ndsl.comm.boundary import Boundary +from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater from ndsl.performance.timer import NullTimer, Timer @@ -109,10 +109,13 @@ def _create_all_reduce_quantity( ) return all_reduce_quantity - def all_reduce_sum( - self, input_quantity: Quantity, output_quantity: Quantity = None + def all_reduce( + self, + input_quantity: Quantity, + op: ReductionOperator, + output_quantity: Quantity = None, ): - reduced_quantity_data = self.comm.allreduce(input_quantity.data, MPI.SUM) + reduced_quantity_data = self.comm.allreduce(input_quantity.data, op) if output_quantity is None: all_reduce_quantity = self._create_all_reduce_quantity( input_quantity.metadata, reduced_quantity_data @@ -126,6 +129,14 @@ def all_reduce_sum( output_quantity.data = reduced_quantity_data + def all_reduce_per_element( + self, + input_quantity: Quantity, + output_quantity: Quantity, + op: ReductionOperator, + ): + self.comm.Allreduce(input_quantity.data, output_quantity.data, op) + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( numpy_module.zeros, recvbuf diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 6f47c791..b3b834b4 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -1,10 +1,11 @@ try: + import mpi4py from mpi4py import MPI except ImportError: MPI = None -from typing import List, Optional, TypeVar, cast +from typing import Dict, List, Optional, TypeVar, cast -from ndsl.comm.comm_abc import Comm, Request +from ndsl.comm.comm_abc import Comm, ReductionOperator, Request from ndsl.logging import ndsl_log @@ -12,6 +13,24 @@ class MPIComm(Comm): + _op_mapping: Dict[ReductionOperator, mpi4py.MPI.Op] = { + ReductionOperator.OP_NULL: mpi4py.MPI.OP_NULL, + ReductionOperator.MAX: mpi4py.MPI.MAX, + ReductionOperator.MIN: mpi4py.MPI.MIN, + ReductionOperator.SUM: mpi4py.MPI.SUM, + ReductionOperator.PROD: mpi4py.MPI.PROD, + ReductionOperator.LAND: mpi4py.MPI.LAND, + ReductionOperator.BAND: mpi4py.MPI.BAND, + ReductionOperator.LOR: mpi4py.MPI.LOR, + ReductionOperator.BOR: mpi4py.MPI.BOR, + ReductionOperator.LXOR: mpi4py.MPI.LXOR, + ReductionOperator.BXOR: mpi4py.MPI.BXOR, + ReductionOperator.MAXLOC: mpi4py.MPI.MAXLOC, + ReductionOperator.MINLOC: mpi4py.MPI.MINLOC, + ReductionOperator.REPLACE: mpi4py.MPI.REPLACE, + ReductionOperator.NO_OP: mpi4py.MPI.NO_OP, + } + def __init__(self): if MPI is None: raise RuntimeError("MPI not available") @@ -72,8 +91,14 @@ def Split(self, color, key) -> "Comm": ) return self._comm.Split(color, key) - def allreduce(self, sendobj: T, op=None) -> T: + def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ndsl_log.debug( "allreduce on rank %s with operator %s", self._comm.Get_rank(), op ) return self._comm.allreduce(sendobj, op) + + def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + ndsl_log.debug( + "allreduce on rank %s with operator %s", self._comm.Get_rank(), op + ) + return self._comm.Allreduce(sendobj, recvobj, self._op_mapping[op]) diff --git a/ndsl/comm/null_comm.py b/ndsl/comm/null_comm.py index 7e0c07fa..5ca92359 100644 --- a/ndsl/comm/null_comm.py +++ b/ndsl/comm/null_comm.py @@ -1,7 +1,7 @@ import copy -from typing import Any, Mapping +from typing import Any, Mapping, Optional -from ndsl.comm.comm_abc import Comm, Request +from ndsl.comm.comm_abc import Comm, ReductionOperator, Request class NullAsyncResult(Request): @@ -91,5 +91,9 @@ def Split(self, color, key): self._split_comms[color].append(new_comm) return new_comm - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op: Optional[ReductionOperator] = None) -> Any: return self._fill_value + + def Allreduce(self, sendobj, recvobj, op: ReductionOperator) -> Any: + recvobj = sendobj + return recvobj From 0e8089eed9c909bf91cc1a117ecf12cf6cfe7397 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 10:45:06 -0500 Subject: [PATCH 12/18] Update utest --- tests/mpi/test_mpi_all_reduce_sum.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 858a7f94..4a15ad53 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -7,6 +7,7 @@ Quantity, TilePartitioner, ) +from ndsl.comm.comm_abc import ReductionOperator from ndsl.dsl.typing import Float from tests.mpi.mpi_comm import MPI @@ -48,10 +49,7 @@ def communicator(cube_partitioner): @pytest.mark.skipif( MPI is None, reason="mpi4py is not available or pytest was not run in parallel" ) -def test_all_reduce_sum( - communicator, -): - +def test_all_reduce(communicator): backends = ["dace:cpu", "gt:cpu_kfirst", "numpy"] for backend in backends: @@ -84,15 +82,15 @@ def test_all_reduce_sum( gt4py_backend=backend, ) - global_sum_q = communicator.all_reduce_sum(testQuantity_1D) + global_sum_q = communicator.all_reduce(testQuantity_1D, ReductionOperator.SUM) assert global_sum_q.metadata == testQuantity_1D.metadata assert (global_sum_q.data == (testQuantity_1D.data * communicator.size)).all() - global_sum_q = communicator.all_reduce_sum(testQuantity_2D) + global_sum_q = communicator.all_reduce(testQuantity_2D, ReductionOperator.SUM) assert global_sum_q.metadata == testQuantity_2D.metadata assert (global_sum_q.data == (testQuantity_2D.data * communicator.size)).all() - global_sum_q = communicator.all_reduce_sum(testQuantity_3D) + global_sum_q = communicator.all_reduce(testQuantity_3D, ReductionOperator.SUM) assert global_sum_q.metadata == testQuantity_3D.metadata assert (global_sum_q.data == (testQuantity_3D.data * communicator.size)).all() @@ -125,19 +123,25 @@ def test_all_reduce_sum( units="Some 3D unit", gt4py_backend=backend, ) - communicator.all_reduce_sum(testQuantity_1D, testQuantity_1D_out) + communicator.all_reduce( + testQuantity_1D, ReductionOperator.SUM, testQuantity_1D_out + ) assert testQuantity_1D_out.metadata == testQuantity_1D.metadata assert ( testQuantity_1D_out.data == (testQuantity_1D.data * communicator.size) ).all() - communicator.all_reduce_sum(testQuantity_2D, testQuantity_2D_out) + communicator.all_reduce( + testQuantity_2D, ReductionOperator.SUM, testQuantity_2D_out + ) assert testQuantity_2D_out.metadata == testQuantity_2D.metadata assert ( testQuantity_2D_out.data == (testQuantity_2D.data * communicator.size) ).all() - communicator.all_reduce_sum(testQuantity_3D, testQuantity_3D_out) + communicator.all_reduce( + testQuantity_3D, ReductionOperator.SUM, testQuantity_3D_out + ) assert testQuantity_3D_out.metadata == testQuantity_3D.metadata assert ( testQuantity_3D_out.data == (testQuantity_3D.data * communicator.size) From 2188c75cb2b005a76859332bcc906987e0495116 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 11:01:14 -0500 Subject: [PATCH 13/18] Fix `local_comm` --- ndsl/comm/local_comm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ndsl/comm/local_comm.py b/ndsl/comm/local_comm.py index 5ebfb47d..1ae10177 100644 --- a/ndsl/comm/local_comm.py +++ b/ndsl/comm/local_comm.py @@ -189,8 +189,14 @@ def Split(self, color, key): self._split_comms[color].append(new_comm) return new_comm - def allreduce(self, sendobj, op=None) -> Any: + def allreduce(self, sendobj, op=None, recvobj=None) -> Any: raise NotImplementedError( - "sendrecv fundamentally cannot be written for LocalComm, " + "allreduce fundamentally cannot be written for LocalComm, " + "as it requires synchronicity" + ) + + def Allreduce(self, sendobj, recvobj, op) -> Any: + raise NotImplementedError( + "Allreduce fundamentally cannot be written for LocalComm, " "as it requires synchronicity" ) From f8cc2ce97617ee68df87652c5d75abf16b718db8 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 11:08:38 -0500 Subject: [PATCH 14/18] Fix utest --- ndsl/comm/mpi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index b3b834b4..3873cc52 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -95,7 +95,7 @@ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ndsl_log.debug( "allreduce on rank %s with operator %s", self._comm.Get_rank(), op ) - return self._comm.allreduce(sendobj, op) + return self._comm.allreduce(sendobj, self._op_mapping[op]) def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ndsl_log.debug( From 7ad271f878ef8c605cba02e07828ee596055eed0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 11:30:04 -0500 Subject: [PATCH 15/18] Enforce `comm_abc.Comm` into Communicator --- ndsl/comm/communicator.py | 30 +++++++++++++++++++--------- tests/mpi/test_mpi_all_reduce_sum.py | 3 ++- tests/mpi/test_mpi_halo_update.py | 5 +++-- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index ed1f264d..c952c022 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -6,6 +6,7 @@ import ndsl.constants as constants from ndsl.buffer import array_buffer, device_synchronize, recv_buffer, send_buffer from ndsl.comm.boundary import Boundary +from ndsl.comm.comm_abc import Comm as CommABC from ndsl.comm.comm_abc import ReductionOperator from ndsl.comm.partitioner import CubedSpherePartitioner, Partitioner, TilePartitioner from ndsl.halo.updater import HaloUpdater, HaloUpdateRequest, VectorInterfaceHaloUpdater @@ -45,7 +46,11 @@ def to_numpy(array, dtype=None) -> np.ndarray: class Communicator(abc.ABC): def __init__( - self, comm, partitioner, force_cpu: bool = False, timer: Optional[Timer] = None + self, + comm: CommABC, + partitioner, + force_cpu: bool = False, + timer: Optional[Timer] = None, ): self.comm = comm self.partitioner: Partitioner = partitioner @@ -62,7 +67,7 @@ def tile(self) -> "TileCommunicator": @abc.abstractmethod def from_layout( cls, - comm, + comm: CommABC, layout: Tuple[int, int], force_cpu: bool = False, timer: Optional[Timer] = None, @@ -138,15 +143,17 @@ def all_reduce_per_element( self.comm.Allreduce(input_quantity.data, output_quantity.data, op) def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: + with ( + send_buffer(numpy_module.zeros, sendbuf) as send, + recv_buffer(numpy_module.zeros, recvbuf) as recv, + ): self.comm.Scatter(send, recv, **kwargs) def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with send_buffer(numpy_module.zeros, sendbuf) as send, recv_buffer( - numpy_module.zeros, recvbuf - ) as recv: + with ( + send_buffer(numpy_module.zeros, sendbuf) as send, + recv_buffer(numpy_module.zeros, recvbuf) as recv, + ): self.comm.Gather(send, recv, **kwargs) def scatter( @@ -753,7 +760,7 @@ class CubedSphereCommunicator(Communicator): def __init__( self, - comm, + comm: CommABC, partitioner: CubedSpherePartitioner, force_cpu: bool = False, timer: Optional[Timer] = None, @@ -766,6 +773,11 @@ def __init__( force_cpu: Force all communication to go through central memory. timer: Time communication operations. """ + if not issubclass(type(comm), CommABC): + raise TypeError( + "Communictor needs to be instantiated with communication subsytem" + f" derived from `comm_abc.Comm`, got {type(comm)}." + ) if comm.Get_size() != partitioner.total_ranks: raise ValueError( f"was given a partitioner for {partitioner.total_ranks} ranks but a " diff --git a/tests/mpi/test_mpi_all_reduce_sum.py b/tests/mpi/test_mpi_all_reduce_sum.py index 4a15ad53..bec096dd 100644 --- a/tests/mpi/test_mpi_all_reduce_sum.py +++ b/tests/mpi/test_mpi_all_reduce_sum.py @@ -8,6 +8,7 @@ TilePartitioner, ) from ndsl.comm.comm_abc import ReductionOperator +from ndsl.comm.mpi import MPIComm from ndsl.dsl.typing import Float from tests.mpi.mpi_comm import MPI @@ -41,7 +42,7 @@ def cube_partitioner(tile_partitioner): @pytest.fixture() def communicator(cube_partitioner): return CubedSphereCommunicator( - comm=MPI.COMM_WORLD, + comm=MPIComm(), partitioner=cube_partitioner, ) diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index ab11b16e..1e6aaefc 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -8,6 +8,7 @@ Quantity, TilePartitioner, ) +from ndsl.comm.mpi import MPIComm from ndsl.comm._boundary_utils import get_boundary_slice from ndsl.constants import ( BOUNDARY_TYPES, @@ -39,7 +40,7 @@ def layout(): if MPI is not None: size = MPI.COMM_WORLD.Get_size() ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile ** 0.5) + ranks_per_edge = int(ranks_per_tile**0.5) return (ranks_per_edge, ranks_per_edge) else: return (1, 1) @@ -176,7 +177,7 @@ def extent(n_points, dims, nz, ny, nx): @pytest.fixture() def communicator(cube_partitioner): return CubedSphereCommunicator( - comm=MPI.COMM_WORLD, + comm=MPIComm(), partitioner=cube_partitioner, ) From 07cd0f32cb5e203ae98c0946a092c5edc9e7f7c0 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 11:58:52 -0500 Subject: [PATCH 16/18] Fix `comm` object in serial utest --- tests/dsl/test_compilation_config.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 62049d91..95ca7f74 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -9,6 +9,7 @@ CubedSpherePartitioner, RunMode, TilePartitioner, + NullComm, ) @@ -33,8 +34,7 @@ def test_check_communicator_valid( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int((sqrt(size / 6))))) ) - comm = unittest.mock.MagicMock() - comm.Get_size.return_value = size + comm = NullComm(rank=0, total_ranks=size) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -52,8 +52,7 @@ def test_check_communicator_invalid( nx: int, ny: int, use_minimal_caching: bool, run_mode: RunMode ): partitioner = CubedSpherePartitioner(TilePartitioner((nx, ny))) - comm = unittest.mock.MagicMock() - comm.Get_size.return_value = nx * ny * 6 + comm = NullComm(rank=0, total_ranks=nx * ny * 6) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig( run_mode=run_mode, use_minimal_caching=use_minimal_caching @@ -91,9 +90,7 @@ def test_get_decomposition_info_from_comm( partitioner = CubedSpherePartitioner( TilePartitioner((int(sqrt(size / 6)), int(sqrt(size / 6)))) ) - comm = unittest.mock.MagicMock() - comm.Get_rank.return_value = rank - comm.Get_size.return_value = size + comm = NullComm(rank=rank, total_ranks=size) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) config = CompilationConfig(use_minimal_caching=True, run_mode=RunMode.Run) ( @@ -133,8 +130,7 @@ def test_determine_compiling_equivalent( TilePartitioner((sqrt(size / 6), sqrt(size / 6))) ) comm = unittest.mock.MagicMock() - comm.Get_rank.return_value = rank - comm.Get_size.return_value = size + comm = NullComm(rank=rank, total_ranks=size) cubed_sphere_comm = CubedSphereCommunicator(comm, partitioner) assert ( config.determine_compiling_equivalent(rank, cubed_sphere_comm.partitioner) From 224e6e24ecfc09e65bf8f0ec1f9b3f3b0d4c7ed2 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Sun, 22 Dec 2024 13:50:37 -0500 Subject: [PATCH 17/18] Lint + `MPIComm` on testing architecture --- ndsl/comm/communicator.py | 16 ++++++---------- ndsl/stencils/testing/conftest.py | 11 +++++------ ndsl/stencils/testing/test_translate.py | 17 +++++++++-------- tests/dsl/test_compilation_config.py | 2 +- tests/mpi/test_mpi_halo_update.py | 4 ++-- 5 files changed, 23 insertions(+), 27 deletions(-) diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index c952c022..1ea4f5a3 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -143,18 +143,14 @@ def all_reduce_per_element( self.comm.Allreduce(input_quantity.data, output_quantity.data, op) def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): - with ( - send_buffer(numpy_module.zeros, sendbuf) as send, - recv_buffer(numpy_module.zeros, recvbuf) as recv, - ): - self.comm.Scatter(send, recv, **kwargs) + with send_buffer(numpy_module.zeros, sendbuf) as send: + with recv_buffer(numpy_module.zeros, recvbuf) as recv: + self.comm.Scatter(send, recv, **kwargs) def _Gather(self, numpy_module, sendbuf, recvbuf, **kwargs): - with ( - send_buffer(numpy_module.zeros, sendbuf) as send, - recv_buffer(numpy_module.zeros, recvbuf) as recv, - ): - self.comm.Gather(send, recv, **kwargs) + with send_buffer(numpy_module.zeros, sendbuf) as send: + with recv_buffer(numpy_module.zeros, recvbuf) as recv: + self.comm.Gather(send, recv, **kwargs) def scatter( self, diff --git a/ndsl/stencils/testing/conftest.py b/ndsl/stencils/testing/conftest.py index af5bb6a6..9810fb4a 100644 --- a/ndsl/stencils/testing/conftest.py +++ b/ndsl/stencils/testing/conftest.py @@ -13,7 +13,7 @@ CubedSphereCommunicator, TileCommunicator, ) -from ndsl.comm.mpi import MPI +from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.namelist import Namelist @@ -308,7 +308,7 @@ def compute_grid_data(grid, namelist, backend, layout, topology_mode): npx=namelist.npx, npy=namelist.npy, npz=namelist.npz, - communicator=get_communicator(MPI.COMM_WORLD, layout, topology_mode), + communicator=get_communicator(MPIComm(), layout, topology_mode), backend=backend, ) @@ -360,13 +360,12 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): metafunc.config ) # get MPI environment - comm = MPI.COMM_WORLD - mpi_rank = comm.Get_rank() + comm = MPIComm() savepoint_cases = parallel_savepoint_cases( metafunc, data_path, namelist_filename, - mpi_rank, + comm.Get_rank(), backend=backend, comm=comm, ) @@ -376,7 +375,7 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str): def get_communicator(comm, layout, topology_mode): - if (MPI.COMM_WORLD.Get_size() > 1) and (topology_mode == "cubed-sphere"): + if (comm.Get_size() > 1) and (topology_mode == "cubed-sphere"): partitioner = CubedSpherePartitioner(TilePartitioner(layout)) communicator = CubedSphereCommunicator(comm, partitioner) else: diff --git a/ndsl/stencils/testing/test_translate.py b/ndsl/stencils/testing/test_translate.py index db8e6047..64ae5f62 100644 --- a/ndsl/stencils/testing/test_translate.py +++ b/ndsl/stencils/testing/test_translate.py @@ -8,7 +8,7 @@ import ndsl.dsl.gt4py_utils as gt_utils from ndsl.comm.communicator import CubedSphereCommunicator, TileCommunicator -from ndsl.comm.mpi import MPI +from ndsl.comm.mpi import MPI, MPIComm from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner from ndsl.dsl.dace.dace_config import DaceConfig from ndsl.dsl.stencil import CompilationConfig, StencilConfig @@ -288,18 +288,19 @@ def test_parallel_savepoint( multimodal_metric, xy_indices=True, ): - if MPI.COMM_WORLD.Get_size() % 6 != 0: + mpi_comm = MPIComm() + if mpi_comm.Get_size() % 6 != 0: layout = ( - int(MPI.COMM_WORLD.Get_size() ** 0.5), - int(MPI.COMM_WORLD.Get_size() ** 0.5), + int(mpi_comm.Get_size() ** 0.5), + int(mpi_comm.Get_size() ** 0.5), ) - communicator = get_tile_communicator(MPI.COMM_WORLD, layout) + communicator = get_tile_communicator(mpi_comm, layout) else: layout = ( - int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5), - int((MPI.COMM_WORLD.Get_size() // 6) ** 0.5), + int((mpi_comm.Get_size() // 6) ** 0.5), + int((mpi_comm.Get_size() // 6) ** 0.5), ) - communicator = get_communicator(MPI.COMM_WORLD, layout) + communicator = get_communicator(mpi_comm, layout) if case.testobj is None: pytest.xfail( f"no translate object available for savepoint {case.savepoint_name}" diff --git a/tests/dsl/test_compilation_config.py b/tests/dsl/test_compilation_config.py index 95ca7f74..fa323b06 100644 --- a/tests/dsl/test_compilation_config.py +++ b/tests/dsl/test_compilation_config.py @@ -7,9 +7,9 @@ CompilationConfig, CubedSphereCommunicator, CubedSpherePartitioner, + NullComm, RunMode, TilePartitioner, - NullComm, ) diff --git a/tests/mpi/test_mpi_halo_update.py b/tests/mpi/test_mpi_halo_update.py index 1e6aaefc..b6c38e95 100644 --- a/tests/mpi/test_mpi_halo_update.py +++ b/tests/mpi/test_mpi_halo_update.py @@ -8,8 +8,8 @@ Quantity, TilePartitioner, ) -from ndsl.comm.mpi import MPIComm from ndsl.comm._boundary_utils import get_boundary_slice +from ndsl.comm.mpi import MPIComm from ndsl.constants import ( BOUNDARY_TYPES, EDGE_BOUNDARY_TYPES, @@ -40,7 +40,7 @@ def layout(): if MPI is not None: size = MPI.COMM_WORLD.Get_size() ranks_per_tile = size // 6 - ranks_per_edge = int(ranks_per_tile**0.5) + ranks_per_edge = int(ranks_per_tile ** 0.5) return (ranks_per_edge, ranks_per_edge) else: return (1, 1) From 760578c4e8cf505c96e030f521eb3addf463b871 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 30 Dec 2024 11:21:39 -0500 Subject: [PATCH 18/18] Add in_place option for Allreduce --- ndsl/comm/comm_abc.py | 3 +++ ndsl/comm/communicator.py | 5 +++++ ndsl/comm/mpi.py | 14 +++++++++++--- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/ndsl/comm/comm_abc.py b/ndsl/comm/comm_abc.py index 8192560b..45596f1e 100644 --- a/ndsl/comm/comm_abc.py +++ b/ndsl/comm/comm_abc.py @@ -95,3 +95,6 @@ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: @abc.abstractmethod def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: ... + + def Allreduce_inplace(self, obj: T, op: ReductionOperator) -> T: + ... diff --git a/ndsl/comm/communicator.py b/ndsl/comm/communicator.py index 1ea4f5a3..ba980d19 100644 --- a/ndsl/comm/communicator.py +++ b/ndsl/comm/communicator.py @@ -142,6 +142,11 @@ def all_reduce_per_element( ): self.comm.Allreduce(input_quantity.data, output_quantity.data, op) + def all_reduce_per_element_in_place( + self, quantity: Quantity, op: ReductionOperator + ): + self.comm.Allreduce_inplace(quantity.data, op) + def _Scatter(self, numpy_module, sendbuf, recvbuf, **kwargs): with send_buffer(numpy_module.zeros, sendbuf) as send: with recv_buffer(numpy_module.zeros, recvbuf) as recv: diff --git a/ndsl/comm/mpi.py b/ndsl/comm/mpi.py index 3873cc52..6b3ff17f 100644 --- a/ndsl/comm/mpi.py +++ b/ndsl/comm/mpi.py @@ -97,8 +97,16 @@ def allreduce(self, sendobj: T, op: Optional[ReductionOperator] = None) -> T: ) return self._comm.allreduce(sendobj, self._op_mapping[op]) - def Allreduce(self, sendobj: T, recvobj: T, op: ReductionOperator) -> T: + def Allreduce(self, sendobj_or_inplace: T, recvobj: T, op: ReductionOperator) -> T: ndsl_log.debug( - "allreduce on rank %s with operator %s", self._comm.Get_rank(), op + "Allreduce on rank %s with operator %s", self._comm.Get_rank(), op + ) + return self._comm.Allreduce(sendobj_or_inplace, recvobj, self._op_mapping[op]) + + def Allreduce_inplace(self, recvobj: T, op: ReductionOperator) -> T: + ndsl_log.debug( + "Allreduce (in place) on rank %s with operator %s", + self._comm.Get_rank(), + op, ) - return self._comm.Allreduce(sendobj, recvobj, self._op_mapping[op]) + return self._comm.Allreduce(mpi4py.MPI.IN_PLACE, recvobj, self._op_mapping[op])