Python frontend stability and inline storage specification (#1711)
This PR adds new syntax for inline storage specification with the `@` operator, enabling statements such as `a = np.ones(M) @ dace.StorageType.CPU_ThreadLocal`.
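
For example, a minimal sketch of the new syntax inside a `@dace.program` (the program and symbol names are illustrative, not part of the PR):

```python
import dace
import numpy as np

M = dace.symbol('M')

@dace.program
def thread_local_fill(out: dace.float64[M]):
    # The `@` operator attaches a storage type to the newly created array.
    a = np.ones(M) @ dace.StorageType.CPU_ThreadLocal
    out[:] = a
```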

This PR also fixes multiple minor issues in the Python frontend:
* `WarpTiling` did not respect sequential map schedules
* Non-sequence inputs to `numpy.fill` variants (e.g., `numpy.zeros(N)`) were not supported
* NumPy replacement syntax errors would sometimes lack source information
* Type inference for nested scopes in the Python frontend was incorrect
* Dynamic thread block scheduling did not support multi-dimensional maps
* Default schedule inference now uses dynamic thread blocks if they exist
* Type hints with storage types were not adhered to by the Python frontend (see the sketch after this list)
* Validation issue #1562
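
A hedged sketch of the storage type-hint fix, assuming the `@` operator is also accepted in annotations (names illustrative):

```python
import dace
import numpy as np

N = dace.symbol('N')

@dace.program
def annotated_storage(x: dace.float64[N]):
    # The storage type attached to the annotation is now applied to the
    # array descriptor created for `a`, rather than being silently ignored.
    a: dace.float64[N] @ dace.StorageType.CPU_ThreadLocal = np.zeros(N)
    a += x
    return a
```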

The following issues were added as skipped tests, with fixes deferred to future PRs:
* Dynamic map range related issues: fix deferred to #1696
* Dynamic thread block scheduling would not pass its object to nested functions: fix deferred to a future PR; see #1189 for more information (a sketch of the pattern follows this list)
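
For context, a hedged sketch of the deferred nested-function pattern (this is the shape of the skipped tests, not working code; the `@`-based schedule syntax and names are illustrative):

```python
import dace

N = dace.symbol('N')

@dace.program
def callee(a: dace.float64[1]):
    a += 1

@dace.program
def caller(A: dace.float64[N] @ dace.StorageType.GPU_Global):
    for i in dace.map[0:N] @ dace.ScheduleType.GPU_Device:
        for j in dace.map[0:1] @ dace.ScheduleType.GPU_ThreadBlock_Dynamic:
            # Calling a nested dace.program from inside a dynamic
            # thread-block map is the case deferred to #1189.
            callee(A[i:i + 1])
```
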
tbennun authored Oct 29, 2024
1 parent d8ddc75 commit 7cb93f2
Showing 15 changed files with 370 additions and 136 deletions.
74 changes: 38 additions & 36 deletions dace/codegen/targets/cuda.py
@@ -23,8 +23,8 @@
 from dace.codegen.targets.target import IllegalCopy, TargetCodeGenerator, make_absolute
 from dace.config import Config
 from dace.frontend import operations
-from dace.sdfg import (SDFG, ScopeSubgraphView, SDFGState, has_dynamic_map_inputs,
-                       is_array_stream_view, is_devicelevel_gpu, nodes, scope_contains_scope)
+from dace.sdfg import (SDFG, ScopeSubgraphView, SDFGState, has_dynamic_map_inputs, is_array_stream_view,
+                       is_devicelevel_gpu, nodes, scope_contains_scope)
 from dace.sdfg import utils as sdutil
 from dace.sdfg.graph import MultiConnectorEdge
 from dace.sdfg.state import ControlFlowRegion, StateSubgraphView
@@ -68,6 +68,7 @@ def __init__(self, frame_codegen: 'DaCeCodeGenerator', sdfg: SDFG):
         dispatcher = self._dispatcher

         self.create_grid_barrier = False
+        self.dynamic_tbmap_type = None
         self.extra_nsdfg_args = []
         CUDACodeGen._in_device_code = False
         self._cpu_codegen: Optional['CPUCodeGen'] = None
@@ -892,8 +893,8 @@ def increment(streams):

         return max_streams, max_events

-    def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.StorageType,
-                   dst_node: nodes.Node, dst_storage: dtypes.StorageType, dst_schedule: dtypes.ScheduleType,
+    def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.StorageType, dst_node: nodes.Node,
+                   dst_storage: dtypes.StorageType, dst_schedule: dtypes.ScheduleType,
                    edge: Tuple[nodes.Node, str, nodes.Node, str, Memlet], sdfg: SDFG, cfg: ControlFlowRegion,
                    dfg: StateSubgraphView, callsite_stream: CodeIOStream) -> None:
         u, uconn, v, vconn, memlet = edge
@@ -1163,11 +1164,8 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St
                         copysize=', '.join(_topy(copy_shape)),
                         is_async='true' if state_dfg.out_degree(dst_node) == 0 else 'false',
                         accum=accum or '::Copy',
-                        args=', '.join(
-                            [src_expr] + _topy(src_strides) + [dst_expr] + _topy(dst_strides) + custom_reduction
-                        )
-                    ),
-                    cfg, state_id, [src_node, dst_node])
+                        args=', '.join([src_expr] + _topy(src_strides) + [dst_expr] + _topy(dst_strides) +
+                                       custom_reduction)), cfg, state_id, [src_node, dst_node])
             else:
                 callsite_stream.write(
                     (' {func}<{type}, {bdims}, {copysize}, ' +
@@ -1236,8 +1234,12 @@ def _begin_streams(self, sdfg, state):
                 result.add(e.dst._cuda_stream)
         return result

-    def generate_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState,
-                       function_stream: CodeIOStream, callsite_stream: CodeIOStream,
+    def generate_state(self,
+                       sdfg: SDFG,
+                       cfg: ControlFlowRegion,
+                       state: SDFGState,
+                       function_stream: CodeIOStream,
+                       callsite_stream: CodeIOStream,
                        generate_state_footer: bool = False) -> None:
         # Two modes: device-level state and if this state has active streams
         if CUDACodeGen._in_device_code:
@@ -1361,8 +1363,7 @@ def generate_devicelevel_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state:
                                       "&& threadIdx.x == 0) "
                                       "{ // sub-graph begin", cfg, state.block_id)
             elif write_scope == 'block':
-                callsite_stream.write("if (threadIdx.x == 0) "
-                                      "{ // sub-graph begin", cfg, state.block_id)
+                callsite_stream.write("if (threadIdx.x == 0) " "{ // sub-graph begin", cfg, state.block_id)
             else:
                 callsite_stream.write("{ // subgraph begin", cfg, state.block_id)
         else:
@@ -1985,16 +1986,13 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S

         # allocating shared memory for dynamic threadblock maps
         if has_dtbmap:
-            kernel_stream.write(
-                '__shared__ dace::'
-                'DynamicMap<{fine_grained}, {block_size}>'
-                '::shared_type dace_dyn_map_shared;'.format(
-                    fine_grained=('true'
-                                  if Config.get_bool('compiler', 'cuda', 'dynamic_map_fine_grained') else 'false'),
-                    block_size=functools.reduce(
-                        (lambda x, y: x * y),
-                        [int(x) for x in Config.get('compiler', 'cuda', 'dynamic_map_block_size').split(',')])), cfg,
-                state_id, node)
+            self.dynamic_tbmap_type = (
+                f'dace::DynamicMap<{"true" if Config.get_bool("compiler", "cuda", "dynamic_map_fine_grained") else "false"}, '
+                f'{functools.reduce((lambda x, y: x * y), [int(x) for x in Config.get("compiler", "cuda", "dynamic_map_block_size").split(",")])}>'
+                '::shared_type')
+            kernel_stream.write(f'__shared__ {self.dynamic_tbmap_type} dace_dyn_map_shared;', cfg, state_id, node)
+        else:
+            self.dynamic_tbmap_type = None

         # Add extra opening brace (dynamic map ranges, closed in MapExit
         # generator)
@@ -2072,8 +2070,8 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S

         # Generate conditions for this block's execution using min and max
         # element, e.g., skipping out-of-bounds threads in trailing block
-        # unless thsi is handled by another map down the line
-        if (not has_tbmap and not has_dtbmap and node.map.schedule != dtypes.ScheduleType.GPU_Persistent):
+        # unless this is handled by another map down the line
+        if ((not has_tbmap or has_dtbmap) and node.map.schedule != dtypes.ScheduleType.GPU_Persistent):
             dsym_end = [d + bs - 1 for d, bs in zip(dsym, self._block_dims)]
             minels = krange.min_element()
             maxels = krange.max_element()
@@ -2090,10 +2088,12 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S
                    condition += '%s < %s' % (v, _topy(maxel + 1))
                if len(condition) > 0:
                    self._kernel_grid_conditions.append(f'if ({condition}) {{')
-                   kernel_stream.write('if (%s) {' % condition, cfg, state_id, scope_entry)
+                   if not has_dtbmap:
+                       kernel_stream.write('if (%s) {' % condition, cfg, state_id, scope_entry)
                else:
                    self._kernel_grid_conditions.append('{')
-                   kernel_stream.write('{', cfg, state_id, scope_entry)
+                   if not has_dtbmap:
+                       kernel_stream.write('{', cfg, state_id, scope_entry)

         self._dispatcher.dispatch_subgraph(sdfg,
                                            cfg,
@@ -2112,6 +2112,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S
         self._kernel_state = None
         CUDACodeGen._in_device_code = False
         self._grid_dims = None
+        self.dynamic_tbmap_type = None

     def get_next_scope_entries(self, dfg, scope_entry):
         parent_scope_entry = dfg.entry_node(scope_entry)
@@ -2179,10 +2180,8 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco
                 current_sdfg = current_state.parent
             if not outer_scope:
                 raise ValueError(f'Failed to find the outer scope of {scope_entry}')
-            callsite_stream.write(
-                'if ({} < {}) {{'.format(outer_scope.map.params[0],
-                                         _topy(subsets.Range(outer_scope.map.range[::-1]).max_element()[0] + 1)), cfg,
-                state_id, scope_entry)
+            for cond in self._kernel_grid_conditions:
+                callsite_stream.write(cond, cfg, state_id, scope_entry)

             # NOTE: Dynamic map inputs must be defined both outside and inside the dynamic Map schedule.
             #       They define inside the schedule the bounds of the any nested Maps.
@@ -2205,8 +2204,9 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco
                     '__dace_dynmap_begin = {begin};\n'
                     '__dace_dynmap_end = {end};'.format(begin=dynmap_begin, end=dynmap_end), cfg, state_id, scope_entry)

-            # close if
-            callsite_stream.write('}', cfg, state_id, scope_entry)
+            # Close kernel grid conditions
+            for _ in self._kernel_grid_conditions:
+                callsite_stream.write('}', cfg, state_id, scope_entry)

             callsite_stream.write(
                 'dace::DynamicMap<{fine_grained}, {bsize}>::'
@@ -2215,7 +2215,7 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco
                 'auto {param}) {{'.format(fine_grained=('true' if Config.get_bool(
                     'compiler', 'cuda', 'dynamic_map_fine_grained') else 'false'),
                                           bsize=total_block_size,
-                                          kmapIdx=outer_scope.map.params[0],
+                                          kmapIdx=outer_scope.map.params[-1],
                                           param=dynmap_var), cfg, state_id, scope_entry)

             for e in dace.sdfg.dynamic_map_inputs(dfg, scope_entry):
@@ -2556,8 +2556,8 @@ def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_sco
             for cond in self._kernel_grid_conditions:
                 callsite_stream.write(cond, cfg, state_id, scope_entry)

-    def generate_node(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int,
-                      node: nodes.Node, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None:
+    def generate_node(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.Node,
+                      function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None:
         if self.node_dispatch_predicate(sdfg, dfg, node):
             # Dynamically obtain node generator according to class name
             gen = getattr(self, '_generate_' + type(node).__name__, False)
@@ -2594,6 +2594,8 @@ def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node):
         result = self._cpu_codegen.generate_nsdfg_arguments(sdfg, cfg, dfg, state, node)
         if self.create_grid_barrier:
             result.append(('cub::GridBarrier&', '__gbar', '__gbar'))
+        if self.dynamic_tbmap_type:
+            result.append((f'{self.dynamic_tbmap_type}&', 'dace_dyn_map_shared', 'dace_dyn_map_shared'))

         # Add data from nested SDFGs to kernel arguments
         result.extend([(atype, aname, aname) for atype, aname, _ in self.extra_nsdfg_args])
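
The `kmapIdx=outer_scope.map.params[-1]` change above is what enables dynamic thread-block scheduling under multi-dimensional kernel maps: the dynamic map now consumes the last (fastest-changing) kernel parameter instead of the first. A hedged sketch of a program that exercises this path (assuming the `@`-based schedule syntax; names illustrative):

```python
import dace

H, W = dace.symbol('H'), dace.symbol('W')

@dace.program
def multidim_dynamic_tb(A: dace.float64[H, W] @ dace.StorageType.GPU_Global):
    # Two-dimensional device-level map with a dynamic thread-block map
    # nested inside; previously only `params[0]` was forwarded to it.
    for i, j in dace.map[0:H, 0:W] @ dace.ScheduleType.GPU_Device:
        for k in dace.map[0:4] @ dace.ScheduleType.GPU_ThreadBlock_Dynamic:
            A[i, j] += 1
```
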
4 changes: 3 additions & 1 deletion dace/codegen/tools/type_inference.py
@@ -9,7 +9,7 @@

 import numpy as np
 import ast
-from dace import dtypes
+from dace import data, dtypes
 from dace import symbolic
 from dace.codegen import cppunparse
 from dace.symbolic import symbol, SymExpr, symstr
@@ -286,6 +286,8 @@ def _Name(t, symbols, inferred_symbols):
             inferred_type = dtypes.typeclass(inferred_type.type)
         elif isinstance(inferred_type, symbolic.symbol):
             inferred_type = inferred_type.dtype
+        elif isinstance(inferred_type, data.Data):
+            inferred_type = inferred_type.dtype
         elif t_id in inferred_symbols:
             inferred_type = inferred_symbols[t_id]
     return inferred_type
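
A hedged illustration of the effect of the new `data.Data` branch: names bound to data descriptors in the symbol map now infer to the descriptor's dtype (the expected result below is an assumption of this sketch):

```python
from dace import data, dtypes
from dace.codegen.tools.type_inference import infer_expr_type

# 'a' is a data descriptor, not a typeclass or symbol; the new branch
# unwraps it to its dtype during inference.
symbols = {'a': data.Scalar(dtypes.float64)}
print(infer_expr_type('a + 1', symbols))  # expected: dtypes.float64
```
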
2 changes: 0 additions & 2 deletions dace/dtypes.py
@@ -1,10 +1,8 @@
 # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
 """ A module that contains various DaCe type definitions. """
-from __future__ import print_function
-import ctypes
 import aenum
 import inspect
 import itertools
 import numpy
 import re
 from collections import OrderedDict
31 changes: 22 additions & 9 deletions dace/frontend/python/newast.py
@@ -1489,19 +1489,19 @@ def _symbols_from_params(self, params: List[Tuple[str, Union[str, dtypes.typecla
             else:
                 values = str(val).split(':')
                 if len(values) == 1:
-                    result[name] = symbolic.symbol(name, infer_expr_type(values[0], {**self.globals, **dyn_inputs}))
+                    result[name] = symbolic.symbol(name, infer_expr_type(values[0], {**self.defined, **dyn_inputs}))
                 elif len(values) == 2:
                     result[name] = symbolic.symbol(
                         name,
                         dtypes.result_type_of(infer_expr_type(values[0], {
-                            **self.globals,
+                            **self.defined,
                             **dyn_inputs
                         }), infer_expr_type(values[1], {
-                            **self.globals,
+                            **self.defined,
                             **dyn_inputs
                         })))
                 elif len(values) == 3:
-                    result[name] = symbolic.symbol(name, infer_expr_type(values[0], {**self.globals, **dyn_inputs}))
+                    result[name] = symbolic.symbol(name, infer_expr_type(values[0], {**self.defined, **dyn_inputs}))
                 else:
                     raise DaceSyntaxError(
                         self, None, "Invalid number of arguments in a range iterator. "
@@ -3258,18 +3258,23 @@ def visit_AnnAssign(self, node: ast.AnnAssign):
             dtype = astutils.evalnode(node.annotation, {**self.globals, **self.defined})
             if isinstance(dtype, data.Data):
                 simple_type = dtype.dtype
+                storage = dtype.storage
             else:
                 simple_type = dtype
+                storage = dtypes.StorageType.Default
             if not isinstance(simple_type, dtypes.typeclass):
                 raise TypeError
         except:
             dtype = None
+            storage = dtypes.StorageType.Default
             type_name = rname(node.annotation)
             warnings.warn('typeclass {} is not supported'.format(type_name))
         if node.value is None and dtype is not None:  # Annotating type without assignment
             self.annotated_types[rname(node.target)] = dtype
             return
-        self._visit_assign(node, node.target, None, dtype=dtype)
+        results = self._visit_assign(node, node.target, None, dtype=dtype)
+        if storage != dtypes.StorageType.Default:
+            self.sdfg.arrays[results[0][0]].storage = storage

     def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
         # Get targets (elts) and results
@@ -3563,6 +3568,8 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
             self.cfg_target.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge())
             self.last_block = output_indirection

+        return results
+
     def visit_AugAssign(self, node: ast.AugAssign):
         self._visit_assign(node, node.target, augassign_ops[type(node.op).__name__])

@@ -4623,10 +4630,16 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
         self._add_state('call_%d' % node.lineno)
         self.last_block.set_default_lineinfo(self.current_lineinfo)

-        if found_ufunc:
-            result = func(self, node, self.sdfg, self.last_block, ufunc_name, args, keywords)
-        else:
-            result = func(self, self.sdfg, self.last_block, *args, **keywords)
+        try:
+            if found_ufunc:
+                result = func(self, node, self.sdfg, self.last_block, ufunc_name, args, keywords)
+            else:
+                result = func(self, self.sdfg, self.last_block, *args, **keywords)
+        except DaceSyntaxError as ex:
+            # Attach source information to exception
+            if ex.node is None:
+                ex.node = node
+            raise

         self.last_block.set_default_lineinfo(None)

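
A hedged illustration of the `visit_Call` change: replacement errors that previously surfaced without location now re-raise with the call-site AST node attached, so messages can point at the offending line (the failing call below is deliberately invalid and purely illustrative):

```python
import dace
import numpy as np
from dace.frontend.python.common import DaceSyntaxError

@dace.program
def bad_call(A: dace.float64[10]):
    return np.einsum('nonsense', A)  # malformed subscript, illustrative

try:
    bad_call.to_sdfg()
except DaceSyntaxError as ex:
    print(ex)  # now includes source information for the call site
```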