From 4897cbf53ee83c60a369272676b8e731acf1c986 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Mon, 25 Aug 2025 23:31:42 +0200 Subject: [PATCH] Move test for GridIndexing.get_2d_compute_origin_domain() This PR is a follow-up from PR #54. It moves and simplifies the test for GridIndexing.get_2d_compute_origin_domain(). This test does not need a stencil fectory and also doesn't depend on a {Stencil,Dace}Config. --- tests/dsl/test_stencil.py | 24 ++++++++++++++++++ tests/dsl/test_stencil_factory.py | 42 ------------------------------- 2 files changed, 24 insertions(+), 42 deletions(-) diff --git a/tests/dsl/test_stencil.py b/tests/dsl/test_stencil.py index 19742b49..28db6a79 100644 --- a/tests/dsl/test_stencil.py +++ b/tests/dsl/test_stencil.py @@ -1,3 +1,4 @@ +import pytest from gt4py.storage import empty, ones from ndsl import CompilationConfig, GridIndexing, StencilConfig, StencilFactory @@ -37,3 +38,26 @@ def func(inp: Field[float], out: Field[float]): test(inp, out) exec_report = stencil_factory.exec_report() assert "func" in exec_report + + +@pytest.mark.parametrize("klevel,expected_origin_k", [(None, 0), (1, 1), (30, 30)]) +def test_grid_indexing_get_2d_compute_origin_domain( + klevel: int | None, + expected_origin_k: int, +): + indexing = GridIndexing( + domain=(12, 12, 79), + n_halo=3, + south_edge=True, + north_edge=True, + west_edge=True, + east_edge=True, + ) + + if klevel is None: + origin, domain = indexing.get_2d_compute_origin_domain() + else: + origin, domain = indexing.get_2d_compute_origin_domain(klevel) + + assert origin[2] == expected_origin_k + assert domain[2] == 1 diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index c9c1f51f..ce9de962 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -202,48 +202,6 @@ def test_stencil_factory_numpy_comparison_from_origin_domain( assert isinstance(stencil, FrozenStencil) -@pytest.mark.parametrize("enabled", [True, False]) -@pytest.mark.parametrize("klevel", [0, 1, 30]) -def test_stencil_factory_numpy_comparison_from_origin_domain_2d( - enabled: bool, klevel: int -): - backend = "numpy" - dace_config = DaceConfig(communicator=None, backend=backend) - config = StencilConfig( - compilation_config=CompilationConfig( - backend=backend, - rebuild=False, - validate_args=False, - format_source=False, - device_sync=False, - ), - compare_to_numpy=enabled, - dace_config=dace_config, - ) - indexing = GridIndexing( - domain=(12, 12, 79), - n_halo=3, - south_edge=True, - north_edge=True, - west_edge=True, - east_edge=True, - ) - if klevel > 0: - origin, domain = indexing.get_2d_compute_origin_domain(klevel=klevel) - else: - origin, domain = indexing.get_2d_compute_origin_domain() - assert domain[2] == 1 - assert origin[2] == klevel - factory = StencilFactory(config=config, grid_indexing=indexing) - stencil = factory.from_origin_domain( - func=copy_stencil, origin=origin, domain=domain - ) - if enabled: - assert isinstance(stencil, CompareToNumpyStencil) - else: - assert isinstance(stencil, FrozenStencil) - - @pytest.mark.parametrize("backend", BACKENDS) def test_stencil_factory_numpy_comparison_runs_without_exceptions(backend: str) -> None: dace_config = DaceConfig(communicator=None, backend=backend)