Skip to content
Merged
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]:
loc=location_to_source_location(node.loc),
)

def visit_HorizontalIf(self, node: HorizontalIf) -> gtir.FieldIfStmt:
def visit_HorizontalIf(self, node: HorizontalIf) -> gtir.HorizontalRestriction:
def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]:
if (level == LevelMarker.START and bound.offset <= -10000) or (
level == LevelMarker.END and bound.offset >= 10000
Expand Down
5 changes: 2 additions & 3 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,8 +1140,7 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref:
raise AssertionError(f"Missing '{symbol}' symbol definition")

def visit_Index(self, node: ast.Index):
index = self.visit(node.value)
return index
return self.visit(node.value)

def _eval_new_spatial_index(
self, index_nodes: Sequence[nodes.Expr], field_axes: Optional[Set[Literal["I", "J", "K"]]]
Expand Down Expand Up @@ -2364,7 +2363,7 @@ def run(self, backend_name: str):
parameters=[
parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
],
computations=init_computations + computations if init_computations else computations,
computations=init_computations + computations,

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init_computations comes from _make_init_compuations() above which will always return a list. The returned list might be empty if there are no "init computations". In that case, the concatenation still works. Imo no need to have an if / else here.

externals=self.resolved_externals,
docstring=inspect.getdoc(self.definition) or "",
loc=nodes.Location.from_ast_node(self.ast_root.body[0]),
Expand Down
76 changes: 41 additions & 35 deletions src/gt4py/cartesian/gtc/gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,25 +99,11 @@ def no_write_and_read_with_offset_of_same_field(
) -> None:
if isinstance(instance.left, FieldAccess):
offset_reads = (
(
eve.walk_values(instance.right)
.filter(_cartesian_fieldaccess)
.filter(lambda acc: acc.offset.i != 0 or acc.offset.j != 0)
.getattr("name")
.to_set()
)
| (
eve.walk_values(instance.right)
.filter(_absolutekindex_fieldaccess)
.getattr("name")
.to_set()
)
| (
eve.walk_values(instance.right)
.filter(_variablek_fieldaccess)
.getattr("name")
.to_set()
)
eve.walk_values(instance.right)
.filter(_cartesian_fieldaccess)
.filter(lambda acc: acc.offset.i != 0 or acc.offset.j != 0)
.getattr("name")
.to_set()
)
if instance.left.name in offset_reads:
raise ValueError("Self-assignment with offset is illegal.")
Comment thread
romanc marked this conversation as resolved.
Outdated
Expand Down Expand Up @@ -250,6 +236,42 @@ def _no_write_and_read_with_horizontal_offset(
f"Illegal write and read with horizontal offset detected for {non_tmp_fields}."
)

@datamodels.root_validator
@classmethod
def _vertical_offset_in_parallel(cls: type[VerticalLoop], instance: VerticalLoop) -> None:
"""
Read and write of the same field in a parallel loop with non-zero offsets is not allowed in parallel.
"""
if instance.loop_order != common.LoopOrder.PARALLEL:
return

# gather all writes as a mapping of id(node) -> node
writes: dict[int, FieldAccess] = dict()
for left in eve.walk_values(instance.body).if_isinstance(ParAssignStmt).getattr("left"):
if isinstance(left, FieldAccess):
writes[id(left)] = left

# check that we don't have read/write of the same field with non-zero offsets
for node in eve.walk_values(instance.body).if_isinstance(FieldAccess):
if id(node) in writes:
# this is the write access - skip it
continue

for write_access in writes.values():
if node.name == write_access.name:
if isinstance(node.offset, (VariableKOffset, AbsoluteKIndex)) or isinstance(
write_access.offset, (VariableKOffset, AbsoluteKIndex)
):
raise ValueError(
"Not allowed to read and write with `VariableKOffset` and/or `AbsoluteKIndex` in PARALLEL loops."
)

# For cartesian offsets, we allow it if both are equal (e.g. 0)
if node.offset.k != write_access.offset.k:
raise ValueError(
"Not allowed to read and write with k-offsets in PARALLEL loops."
)


class Argument(eve.Node):
name: str
Expand Down Expand Up @@ -281,22 +303,6 @@ def _cartesian_fieldaccess(node) -> bool:
)


def _variablek_fieldaccess(node) -> bool:
return (
isinstance(node, FieldAccess)
and isinstance(node.offset, VariableKOffset)
and not isinstance(node.offset, AbsoluteKIndex)
)


def _absolutekindex_fieldaccess(node) -> bool:
return (
isinstance(node, FieldAccess)
and isinstance(node.offset, AbsoluteKIndex)
and not isinstance(node.offset, VariableKOffset)
)


# TODO(havogt): either move to eve or will be removed in the attr-based eve if a List[Node] is represented as a CollectionNode
def _written_and_read_with_offset(stmts: List[Stmt]) -> Set[str]:
"""Return a list of names that are written to and read with offset."""
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/eve/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class EveError:
This base class has to be always inherited together with a standard
exception, and thus it should not be used as direct superclass
for custom exceptions. Inherit directly from :class:`EveTypeError`,
:class:`EveTypeError`, etc. instead.
:class:`EveValueError`, etc. instead.

"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1472,3 +1472,71 @@ def test_upcasting_stencil(

test_upcasting_stencil(input, index_array, output)
assert (input == output).all()


def test_k_offsets_in_parallel_loops() -> None:
with pytest.raises(ValueError, match="read and write with k-offsets in PARALLEL"):

@gtscript.stencil(backend="debug")
def self_assign_offset_parallel(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(1, None):
field = field[0, 0, -1] * 2

with pytest.raises(ValueError, match="read and write with k-offsets in PARALLEL"):

@gtscript.stencil(backend="debug")
def self_assign_offset_parallel_temp(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(1, None):
tmp = field[0, 0, -1]
field = tmp * 2

with pytest.raises(
ValueError, match="read and write with `VariableKOffset` and/or `AbsoluteKIndex`"
):

@gtscript.stencil(backend="debug")
def mixed_read_write(field: Field[np.int32]):
with computation(PARALLEL), interval(...):
level = field.at(K=1)
field = 2 * level

with pytest.raises(
ValueError, match="read and write with `VariableKOffset` and/or `AbsoluteKIndex`"
):

@gtscript.stencil(backend="debug")
def mixed_read_write(field: Field[np.int32], offset: int = -1):
with computation(PARALLEL), interval(1, None):
bottom = field[0, 0, offset]
field = field + 2 * bottom

# center reads and writes are allowed
@gtscript.stencil(backend="debug")
def self_assignment_center_read_parallel(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(...):
field = field[0, 0, 0] * 2

@gtscript.stencil(backend="debug")
def self_assignment_center_write_parallel(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(...):
field[0, 0, 0] = field * 2

# not mixing reads and writes are allowed too (e.g. index fields)
@gtscript.stencil(backend="debug")
def self_assignment_center_parallel(field: Field[np.float32], index: Field[np.int32]) -> None:
with computation(PARALLEL), interval(1, None):
field = index + index[0, 0, -1] * 2


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_self_assignment_in_forward(backend: str) -> None:
@gtscript.stencil(backend=backend)
def self_assignment_parallel(field: Field[np.int32]) -> None:
with computation(FORWARD), interval(1, None):
field = field[0, 0, -1] * 2

@gtscript.stencil(backend=backend)
def self_assignment_2_parallel(field: Field[np.int32]) -> None:
with computation(FORWARD), interval(1, None):
tmp = field[0, 0, -1]
field = tmp * 2
Comment thread
FlorianDeconinck marked this conversation as resolved.
Loading