Skip to content

Commit

Permalink
Fix dace backend using program_to_fencil transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Sep 17, 2024
1 parent bfd8e16 commit 160a616
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 2 deletions.
31 changes: 31 additions & 0 deletions src/gt4py/next/iterator/transforms/program_to_fencil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm


def program_to_fencil(node: itir.Program) -> itir.FencilDefinition:
assert not node.declarations
closures = []
for stmt in node.body:
assert isinstance(stmt, itir.SetAt)
assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop")
stencil, domain = stmt.expr.fun.args
inputs = stmt.expr.args
assert all(isinstance(inp, itir.SymRef) for inp in inputs)
closures.append(
itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs)
)

return itir.FencilDefinition(
id=node.id,
function_definitions=node.function_definitions,
params=node.params,
closures=closures,
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from gt4py.next import common
from gt4py.next.ffront import decorator
from gt4py.next.iterator import transforms as itir_transforms
from gt4py.next.iterator.transforms import program_to_fencil
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_specifications as ts

Expand Down Expand Up @@ -95,7 +96,7 @@ def preprocess_program(
node = itir_type_inference.infer(node, offset_provider=offset_provider)

if isinstance(node, itir.Program):
fencil_definition = node
fencil_definition = program_to_fencil.program_to_fencil(node)
tmps = node.declarations
assert all(isinstance(tmp, itir.Temporary) for tmp in tmps)
else:
Expand Down Expand Up @@ -388,7 +389,8 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG:
itir_tmp = itir_transforms.apply_common_transforms(
self.itir, offset_provider=offset_provider
)
for closure in itir_tmp.closures: # type: ignore[union-attr]
itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp)
for closure in itir_tmp_fencil.closures:
shifts = itir_transforms.trace_shifts.TraceShifts.apply(closure)
for k, v in shifts.items():
if not isinstance(k, str):
Expand Down

0 comments on commit 160a616

Please sign in to comment.