Skip to content

Commit

Permalink
refact[next][dace]: Helper function for field operator constructor (#…
Browse files Browse the repository at this point in the history
…1743)

Includes refactoring of the code for construction of field operators, in
order to make it usable by the three lowering functions that construct
fields: `translate_as_fieldop()`, `translate_broadcast_scalar()`, and
`translate_index()`.
  • Loading branch information
edopao authored Nov 29, 2024
1 parent 3ece412 commit 8860584
Showing 1 changed file with 94 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.ffront import fbuiltins as gtx_fbuiltins
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils
from gt4py.next.iterator.ir_utils import (
common_pattern_matcher as cpm,
domain_utils,
ir_makers as im,
)
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand Down Expand Up @@ -229,40 +233,75 @@ def _get_field_layout(
return list(domain_dims), list(domain_lbs), domain_sizes


def _create_temporary_field(
def _create_field_operator(
sdfg: dace.SDFG,
state: dace.SDFGState,
domain: FieldopDomain,
node_type: ts.FieldType,
dataflow_output: gtir_dataflow.DataflowOutputEdge,
sdfg_builder: gtir_sdfg.SDFGBuilder,
input_edges: Sequence[gtir_dataflow.DataflowInputEdge],
output_edge: gtir_dataflow.DataflowOutputEdge,
) -> FieldopData:
"""Helper method to allocate a temporary field where to write the output of a field operator."""
"""
Helper method to allocate a temporary field to store the output of a field operator.
Args:
sdfg: The SDFG that represents the scope of the field data.
state: The SDFG state where to create an access node to the field data.
domain: The domain of the field operator that computes the field.
node_type: The GT4Py type of the IR node that produces this field.
sdfg_builder: The object used to build the map scope in the provided SDFG.
input_edges: List of edges to pass input data into the dataflow.
output_edge: Edge representing the dataflow output data.
Returns:
The field data descriptor, which includes the field access node in the given `state`
and the field domain offset.
"""
field_dims, field_offset, field_shape = _get_field_layout(domain)
field_indices = _get_domain_indices(field_dims, field_offset)

dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)

output_desc = dataflow_output.result.dc_node.desc(sdfg)
if isinstance(output_desc, dace.data.Array):
field_subset = sbs.Range.from_indices(field_indices)
if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
assert output_edge.result.gt_dtype == node_type.dtype
assert isinstance(dataflow_output_desc, dace.data.Scalar)
assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype)
field_dtype = output_edge.result.gt_dtype
else:
assert isinstance(node_type.dtype, itir_ts.ListType)
assert isinstance(node_type.dtype.element_type, ts.ScalarType)
assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type)
assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type
assert isinstance(dataflow_output_desc, dace.data.Array)
assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
field_dtype = output_edge.result.gt_dtype.element_type
# extend the array with the local dimensions added by the field operator (e.g. `neighbors`)
field_offset.extend(output_desc.offset)
field_shape.extend(output_desc.shape)
elif isinstance(output_desc, dace.data.Scalar):
assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype)
else:
raise ValueError(f"Cannot create field for dace type {output_desc}.")
assert output_edge.result.gt_dtype.offset_type is not None
field_dims.append(output_edge.result.gt_dtype.offset_type)
field_shape.extend(dataflow_output_desc.shape)
field_offset.extend(dataflow_output_desc.offset)
field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc)

# allocate local temporary storage
temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype)
field_node = state.add_access(temp_name)
field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype)
field_node = state.add_access(field_name)

if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType):
field_dtype = dataflow_output.result.gt_dtype
else:
assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType)
field_dtype = dataflow_output.result.gt_dtype.element_type
assert dataflow_output.result.gt_dtype.offset_type is not None
field_dims.append(dataflow_output.result.gt_dtype.offset_type)
# create map range corresponding to the field operator domain
me, mx = sdfg_builder.add_map(
"fieldop",
state,
ndrange={
dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
},
)

# here we setup the edges passing through the map entry node
for edge in input_edges:
edge.connect(me)

# and here the edge writing the dataflow result data through the map exit node
output_edge.connect(mx, field_node, field_subset)

return FieldopData(
field_node,
Expand Down Expand Up @@ -341,125 +380,27 @@ def translate_as_fieldop(
# Special usage of 'deref' as argument to fieldop expression, to pass a scalar
# value to 'as_fieldop' function. It results in broadcasting the scalar value
# over the field domain.
return translate_broadcast_scalar(node, sdfg, state, sdfg_builder)
stencil_expr = im.lambda_("a")(im.deref("a"))
stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined]
else:
raise NotImplementedError(
f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node."
)

# parse the domain of the field operator
domain = extract_domain(domain_expr)
domain_dims, domain_offsets, _ = zip(*domain)
domain_indices = _get_domain_indices(domain_dims, domain_offsets)

# visit the list of arguments to be passed to the lambda expression
stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args]

# represent the field operator as a mapped tasklet graph, which will range over the field domain
taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder)
input_edges, output = taskgen.visit(stencil_expr, args=stencil_args)
output_desc = output.result.dc_node.desc(sdfg)

if isinstance(node.type.dtype, itir_ts.ListType):
assert isinstance(output_desc, dace.data.Array)
# additional local dimension for neighbors
# TODO(phimuell): Investigate if we should swap the two.
output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc)
else:
assert isinstance(output_desc, dace.data.Scalar)
output_subset = sbs.Range.from_indices(domain_indices)

# create map range corresponding to the field operator domain
me, mx = sdfg_builder.add_map(
"fieldop",
state,
ndrange={
dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
},
)

# allocate local temporary storage for the result field
result_field = _create_temporary_field(sdfg, state, domain, node.type, output)

# here we setup the edges from the map entry node
for edge in input_edges:
edge.connect(me)

# and here the edge writing the result data through the map exit node
output.connect(mx, result_field.dc_node, output_subset)

return result_field

input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args)

def translate_broadcast_scalar(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_sdfg.SDFGBuilder,
) -> FieldopResult:
"""
Generates the dataflow subgraph for the 'as_fieldop' builtin function for the
special case where the argument to 'as_fieldop' is a 'deref' scalar expression,
rather than a lambda function. This case corresponds to broadcasting the scalar
value over the field domain. Therefore, it is lowered to a mapped tasklet that
just writes the scalar value out to all elements of the result field.
"""
assert isinstance(node, gtir.FunCall)
assert cpm.is_call_to(node.fun, "as_fieldop")
assert isinstance(node.type, ts.FieldType)

fun_node = node.fun
assert len(fun_node.args) == 2
stencil_expr, domain_expr = fun_node.args
assert cpm.is_ref_to(stencil_expr, "deref")

domain = extract_domain(domain_expr)
output_dims, output_offset, output_shape = _get_field_layout(domain)
output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset))

assert len(node.args) == 1
scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain)

if isinstance(node.args[0].type, ts.ScalarType):
assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr))
input_subset = (
str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0"
)
input_node = scalar_expr.dc_node
gt_dtype = node.args[0].type
elif isinstance(node.args[0].type, ts.FieldType):
assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr)
if len(node.args[0].type.dims) == 0: # zero-dimensional field
input_subset = "0"
else:
input_subset = scalar_expr.get_memlet_subset(sdfg)

input_node = scalar_expr.field
gt_dtype = node.args[0].type.dtype
else:
raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.")

output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype)
output_node = state.add_access(output)

sdfg_builder.add_mapped_tasklet(
"broadcast",
state,
map_ranges={
dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
for dim, lower_bound, upper_bound in domain
},
inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)},
code="__val = __inp",
outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)},
input_nodes={input_node.data: input_node},
output_nodes={output_node.data: output_node},
external_edges=True,
return _create_field_operator(
sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge
)

return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset)


def translate_if(
node: gtir.Node,
Expand Down Expand Up @@ -567,38 +508,44 @@ def translate_index(
index values to a transient array. The extent of the index range is taken from
the domain information that should be present in the node annex.
"""
assert cpm.is_call_to(node, "index")
assert isinstance(node.type, ts.FieldType)

assert "domain" in node.annex
domain = extract_domain(node.annex.domain)
assert len(domain) == 1
dim, lower_bound, upper_bound = domain[0]
dim, _, _ = domain[0]
dim_index = dace_gtir_utils.get_map_variable(dim)

field_dims, field_offset, field_shape = _get_field_layout(domain)
field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE))

output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE)
output_node = state.add_access(output)

sdfg_builder.add_mapped_tasklet(
index_data = sdfg.temp_data_name()
sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True)
index_node = state.add_access(index_data)
index_value = gtir_dataflow.ValueExpr(
dc_node=index_node,
gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE),
)
index_write_tasklet = sdfg_builder.add_tasklet(
"index",
state,
map_ranges={
dim_index: f"{lower_bound}:{upper_bound}",
},
inputs={},
outputs={"__val"},
code=f"__val = {dim_index}",
outputs={
"__val": dace.Memlet(
data=output_node.data,
subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)),
)
},
input_nodes={},
output_nodes={output_node.data: output_node},
external_edges=True,
)
state.add_edge(
index_write_tasklet,
"__val",
index_node,
None,
dace.Memlet(data=index_data, subset="0"),
)

return FieldopData(output_node, field_type, field_offset)
input_edges = [
gtir_dataflow.EmptyInputEdge(state, index_write_tasklet),
]
output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value)
return _create_field_operator(
sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge
)


def _get_data_nodes(
Expand Down Expand Up @@ -831,7 +778,6 @@ def translate_symbol_ref(
# Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol
__primitive_translators: list[PrimitiveTranslator] = [
translate_as_fieldop,
translate_broadcast_scalar,
translate_if,
translate_index,
translate_literal,
Expand Down

0 comments on commit 8860584

Please sign in to comment.