Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/frontend/defir_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]:
loc=location_to_source_location(node.loc),
)

def visit_HorizontalIf(self, node: HorizontalIf) -> gtir.FieldIfStmt:
def visit_HorizontalIf(self, node: HorizontalIf) -> gtir.HorizontalRestriction:
def make_bound_or_level(bound: AxisBound, level) -> Optional[common.AxisBound]:
if (level == LevelMarker.START and bound.offset <= -10000) or (
level == LevelMarker.END and bound.offset >= 10000
Expand Down
5 changes: 2 additions & 3 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,8 +1140,7 @@ def visit_Name(self, node: ast.Name) -> nodes.Ref:
raise AssertionError(f"Missing '{symbol}' symbol definition")

def visit_Index(self, node: ast.Index):
index = self.visit(node.value)
return index
return self.visit(node.value)

def _eval_new_spatial_index(
self, index_nodes: Sequence[nodes.Expr], field_axes: Optional[Set[Literal["I", "J", "K"]]]
Expand Down Expand Up @@ -2364,7 +2363,7 @@ def run(self, backend_name: str):
parameters=[
parameter_decls[item.name] for item in api_signature if item.name in parameter_decls
],
computations=init_computations + computations if init_computations else computations,
computations=init_computations + computations,
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

init_computations comes from _make_init_compuations() above which will always return a list. The returned list might be empty if there are no "init computations". In that case, the concatenation still works. Imo no need to have an if / else here.

externals=self.resolved_externals,
docstring=inspect.getdoc(self.definition) or "",
loc=nodes.Location.from_ast_node(self.ast_root.body[0]),
Expand Down
91 changes: 55 additions & 36 deletions src/gt4py/cartesian/gtc/gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,28 +99,14 @@ def no_write_and_read_with_offset_of_same_field(
) -> None:
if isinstance(instance.left, FieldAccess):
offset_reads = (
(
eve.walk_values(instance.right)
.filter(_cartesian_fieldaccess)
.filter(lambda acc: acc.offset.i != 0 or acc.offset.j != 0)
.getattr("name")
.to_set()
)
| (
eve.walk_values(instance.right)
.filter(_absolutekindex_fieldaccess)
.getattr("name")
.to_set()
)
| (
eve.walk_values(instance.right)
.filter(_variablek_fieldaccess)
.getattr("name")
.to_set()
)
eve.walk_values(instance.right)
.filter(_cartesian_fieldaccess)
.filter(lambda acc: acc.offset.i != 0 or acc.offset.j != 0)
.getattr("name")
.to_set()
)
if instance.left.name in offset_reads:
raise ValueError("Self-assignment with offset is illegal.")
raise ValueError("Self-assignment with offset in I or J is illegal.")

_dtype_validation = common.assign_stmt_dtype_validation(strict=False)

Expand Down Expand Up @@ -250,6 +236,55 @@ def _no_write_and_read_with_horizontal_offset(
f"Illegal write and read with horizontal offset detected for {non_tmp_fields}."
)

@datamodels.root_validator
@classmethod
def _vertical_offset_in_parallel(cls: type[VerticalLoop], instance: VerticalLoop) -> None:
"""
In a parallel vertical loop we disallow writing and reading the same field with a non-zero offset.

To write and read with non-zero offset creates a race condition in parallel. There's an
exception for vertical loops with size one, e.g. `interval(0, 1)`.
"""

def _size_one(interval: Interval) -> bool:
if interval.start.level != interval.end.level:
# if the levels (start/end) aren't the same, we don't know at this stage
return False

return abs(interval.end.offset - interval.start.offset) == 1

if instance.loop_order != common.LoopOrder.PARALLEL or _size_one(instance.interval):
return

# gather all writes as a mapping of id(node) -> node
writes: dict[int, FieldAccess] = dict()
for left in eve.walk_values(instance.body).if_isinstance(ParAssignStmt).getattr("left"):
if isinstance(left, FieldAccess):
writes[id(left)] = left

# check that we don't have a write and reads of the same field with non-zero offsets
for node in eve.walk_values(instance.body).if_isinstance(FieldAccess):
if id(node) in writes:
# this is the write access - skip it
continue

for write_access in writes.values():
if node.name == write_access.name:
if isinstance(node.offset, (VariableKOffset, AbsoluteKIndex)) or isinstance(
write_access.offset, (VariableKOffset, AbsoluteKIndex)
):
raise ValueError(
"Not allowed to write and read with `VariableKOffset` and/or "
f"`AbsoluteKIndex` in PARALLEL loops: `{node.name}`"
)

# For cartesian offsets, we allow it if both offsets are equal (e.g. 0)
if node.offset.k != write_access.offset.k:
raise ValueError(
"Not allowed to write and read with k-offsets in PARALLEL "
f"loops: `{node.name}`"
)


class Argument(eve.Node):
name: str
Expand Down Expand Up @@ -281,22 +316,6 @@ def _cartesian_fieldaccess(node) -> bool:
)


def _variablek_fieldaccess(node) -> bool:
return (
isinstance(node, FieldAccess)
and isinstance(node.offset, VariableKOffset)
and not isinstance(node.offset, AbsoluteKIndex)
)


def _absolutekindex_fieldaccess(node) -> bool:
return (
isinstance(node, FieldAccess)
and isinstance(node.offset, AbsoluteKIndex)
and not isinstance(node.offset, VariableKOffset)
)


# TODO(havogt): either move to eve or will be removed in the attr-based eve if a List[Node] is represented as a CollectionNode
def _written_and_read_with_offset(stmts: List[Stmt]) -> Set[str]:
"""Return a list of names that are written to and read with offset."""
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/eve/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class EveError:
This base class has to be always inherited together with a standard
exception, and thus it should not be used as direct superclass
for custom exceptions. Inherit directly from :class:`EveTypeError`,
:class:`EveTypeError`, etc. instead.
:class:`EveValueError`, etc. instead.

"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1472,3 +1472,97 @@ def test_upcasting_stencil(

test_upcasting_stencil(input, index_array, output)
assert (input == output).all()


def test_no_write_and_read_with_horizontal_offset() -> None:
with pytest.raises(ValueError, match="Self-assignment with offset in I or J is illegal."):

@gtscript.stencil(backend="debug")
def self_assign_offset(field: Field[np.float64]) -> None:
with computation(PARALLEL), interval(...):
field = (field[I - 1] + field[I + 1]) / 2

with pytest.raises(ValueError, match="Illegal write and read with horizontal offset"):

@gtscript.stencil(backend="debug")
def self_assign_offset(field: Field[np.float64]) -> None:
with computation(PARALLEL), interval(...):
tmp = (field[J - 1] + field[J + 1]) / 2
field = tmp * 2


def test_k_offsets_in_parallel_loops() -> None:
with pytest.raises(ValueError, match="write and read with k-offsets in PARALLEL"):

@gtscript.stencil(backend="debug")
def self_assign_offset_parallel(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(1, None):
field = field[K - 1] * 2

with pytest.raises(ValueError, match="write and read with k-offsets in PARALLEL"):

@gtscript.stencil(backend="debug")
def self_assign_offset_parallel_temp(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(1, None):
tmp = field[K - 1]
field = tmp * 2

with pytest.raises(
ValueError, match="write and read with `VariableKOffset` and/or `AbsoluteKIndex`"
):

@gtscript.stencil(backend="debug")
def mixed_read_write(field: Field[np.int32]):
with computation(PARALLEL), interval(...):
level = field.at(K=1)
field = 2 * level

with pytest.raises(
ValueError, match="write and read with `VariableKOffset` and/or `AbsoluteKIndex`"
):

@gtscript.stencil(backend="debug")
def mixed_read_write(field: Field[np.int32], offset: int = -1):
with computation(PARALLEL), interval(1, None):
bottom = field[0, 0, offset]
field = field + 2 * bottom

# center reads and writes are allowed
@gtscript.stencil(backend="debug")
def self_assignment_center_read_parallel(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(...):
field = field[0, 0, 0] * 2

@gtscript.stencil(backend="debug")
def self_assignment_center_write_parallel(field: Field[np.int32]) -> None:
with computation(PARALLEL), interval(...):
field[0, 0, 0] = field * 2

# not mixing reads and writes are allowed (e.g. index fields)
@gtscript.stencil(backend="debug")
def self_assignment_center_parallel(field: Field[np.float32], index: Field[np.int32]) -> None:
with computation(PARALLEL), interval(1, None):
field = index + index[K - 1] * 2

# parallel intervals of static size 1 are allowed
@gtscript.stencil(backend="debug")
def the_stencil(field: Field[np.bool_]) -> None:
with computation(PARALLEL):
with interval(0, 1):
field = field[K + 1]
with interval(-1, None):
field = field[K - 1]


@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_self_assignment_in_forward(backend: str) -> None:
@gtscript.stencil(backend=backend)
def self_assignment_parallel(field: Field[np.int32]) -> None:
with computation(FORWARD), interval(1, None):
field = field[K - 1] * 2

@gtscript.stencil(backend=backend)
def self_assignment_2_parallel(field: Field[np.int32]) -> None:
with computation(FORWARD), interval(1, None):
tmp = field[K - 1]
field = tmp * 2
Comment thread
FlorianDeconinck marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -254,5 +254,7 @@ def stencil_with_invalid_temporary_access_end(field_a: gs.Field[float], field_b:
)
def test_invalid_temporary_access(definition):
builder = StencilBuilder(definition, backend=from_name("numpy"))
with pytest.raises(TypeError, match="Invalid access with offset in k to temporary field tmp."):
with pytest.raises(
ValueError, match="Not allowed to write and read with k-offsets in PARALLEL loops: `tmp`"
):
k_boundary = compute_k_boundary(builder.gtir_pipeline.full(skip=[prune_unused_parameters]))