# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
#
# NOTE: this module was introduced as numba_cuda/numba/cuda/core/postproc.py
# by a patch that also repoints imports in sibling files:
#   - numba/cuda/compiler.py: `postproc` now comes from numba.cuda.core
#   - numba/cuda/core/compiler_machinery.py: `PostProcessor` now comes from
#     numba.cuda.core.postproc, and the `nccm` alias now targets
#     numba.cuda.core.compiler_machinery
#     (NOTE(review): that is the module importing itself - confirm intended)
#   - numba/cuda/core/ir_utils.py: `postproc` now comes from numba.cuda.core
#   - numba/cuda/core/typed_passes.py: `postproc` now comes from numba.cuda.core
#   - numba/cuda/core/untyped_passes.py: `postproc` and `bytecode` now come
#     from numba.cuda.core
from functools import cached_property
from numba.core import ir, transforms, analysis
from numba.cuda.core import ir_utils


class YieldPoint:
    """
    A single ``ir.Yield`` instruction together with the block that holds it.

    ``live_vars`` and ``weak_live_vars`` start out as ``None`` and are filled
    in by ``PostProcessor._compute_generator_info``.
    """

    def __init__(self, block, inst):
        assert isinstance(block, ir.Block)
        assert isinstance(inst, ir.Yield)
        self.block = block
        self.inst = inst
        self.live_vars = None
        self.weak_live_vars = None


class GeneratorInfo:
    """
    Per-function generator bookkeeping: the yield points and the state
    variables derived from them.
    """

    def __init__(self):
        # { index: YieldPoint }
        self.yield_points = {}
        # Ordered list of variable names
        self.state_vars = []

    def get_yield_points(self):
        """
        Return an iterable of YieldPoint instances.
        """
        return self.yield_points.values()


class VariableLifetime:
    """
    For lazily building information of variable lifetime.

    Each property is computed from ``blocks`` on first access and then cached
    on the instance (``functools.cached_property``).
    """

    def __init__(self, blocks):
        self._blocks = blocks

    @cached_property
    def cfg(self):
        """Control-flow graph computed from the blocks."""
        return analysis.compute_cfg_from_blocks(self._blocks)

    @cached_property
    def usedefs(self):
        """Use/def information; exposes ``usemap`` and ``defmap``."""
        return analysis.compute_use_defs(self._blocks)

    @cached_property
    def livemap(self):
        """Live-variable map derived from the CFG and the use/def maps."""
        return analysis.compute_live_map(
            self.cfg, self._blocks, self.usedefs.usemap, self.usedefs.defmap
        )

    @cached_property
    def deadmaps(self):
        """Dead maps; exposes ``internal``, ``escaping`` and ``combined``."""
        return analysis.compute_dead_maps(
            self.cfg, self._blocks, self.livemap, self.usedefs.defmap
        )


# other packages that define new nodes add calls for inserting dels
# format: {type:function}
# Each function is called as ``function(stmt, dead_set)`` and must return the
# set of variable names whose deletion it has already taken care of.
ir_extension_insert_dels = {}


class PostProcessor:
    """
    A post-processor for Numba IR.
    """

    def __init__(self, func_ir):
        self.func_ir = func_ir

    def run(self, emit_dels: bool = False, extend_lifetimes: bool = False):
        """
        Run the following passes over Numba IR:
        - canonicalize the CFG
        - emit explicit `del` instructions for variables
        - compute lifetime of variables
        - compute generator info (if function is a generator function)

        Parameters
        ----------
        emit_dels
            When true, ``ir.Del`` nodes are inserted into the blocks once all
            analysis is finished.
        extend_lifetimes
            Forwarded to the del insertion; when true, the dels of a block are
            collected and placed just ahead of the block terminator instead of
            right after the last use.
        """
        self.func_ir.blocks = transforms.canonicalize_cfg(self.func_ir.blocks)
        vlt = VariableLifetime(self.func_ir.blocks)
        self.func_ir.variable_lifetime = vlt

        bev = analysis.compute_live_variables(
            vlt.cfg,
            self.func_ir.blocks,
            vlt.usedefs.defmap,
            vlt.deadmaps.combined,
        )
        for offset, ir_block in self.func_ir.blocks.items():
            self.func_ir.block_entry_vars[ir_block] = bev[offset]

        if self.func_ir.is_generator:
            self.func_ir.generator_info = GeneratorInfo()
            self._compute_generator_info()
        else:
            self.func_ir.generator_info = None

        # Emit del nodes, do this last as the generator info parsing generates
        # and then strips dels as part of its analysis.
        if emit_dels:
            self._insert_var_dels(extend_lifetimes=extend_lifetimes)

    def _populate_generator_info(self):
        """
        Fill `index` for the Yield instruction and create YieldPoints.

        Indices are 1-based and assigned in block/statement iteration order.
        """
        dct = self.func_ir.generator_info.yield_points
        assert not dct, "rerunning _populate_generator_info"
        for block in self.func_ir.blocks.values():
            for inst in block.body:
                if isinstance(inst, ir.Assign):
                    yieldinst = inst.value
                    if isinstance(yieldinst, ir.Yield):
                        index = len(dct) + 1
                        yieldinst.index = index
                        yp = YieldPoint(block, yieldinst)
                        dct[yieldinst.index] = yp

    def _compute_generator_info(self):
        """
        Compute the generator's state variables as the union of live variables
        at all yield points.

        Raises
        ------
        AssertionError
            If a yield point's instruction cannot be found in its own block.
        """
        # generate del info, it's used in analysis here, strip it out at the end
        self._insert_var_dels()
        self._populate_generator_info()
        gi = self.func_ir.generator_info
        for yp in gi.get_yield_points():
            live_vars = set(self.func_ir.get_block_entry_vars(yp.block))
            weak_live_vars = set()
            # A single iterator is shared by both loops below: the first one
            # stops at the yield, the second resumes right after it.
            stmts = iter(yp.block.body)
            for stmt in stmts:
                if isinstance(stmt, ir.Assign):
                    if stmt.value is yp.inst:
                        break
                    live_vars.add(stmt.target.name)
                elif isinstance(stmt, ir.Del):
                    live_vars.remove(stmt.value)
            else:
                # Raised explicitly (instead of `assert 0`) so the check is
                # not stripped when Python runs with -O.
                raise AssertionError("couldn't find yield point")
            # Try to optimize out any live vars that are deleted immediately
            # after the yield point.
            for stmt in stmts:
                if isinstance(stmt, ir.Del):
                    name = stmt.value
                    if name in live_vars:
                        live_vars.remove(name)
                        weak_live_vars.add(name)
                else:
                    break
            yp.live_vars = live_vars
            yp.weak_live_vars = weak_live_vars

        st = set()
        for yp in gi.get_yield_points():
            st |= yp.live_vars
            st |= yp.weak_live_vars
        gi.state_vars = sorted(st)
        self.remove_dels()

    def _insert_var_dels(self, extend_lifetimes=False):
        """
        Insert del statements for each variable.

        The blocks of ``self.func_ir`` are modified in place using the dead
        maps of ``self.func_ir.variable_lifetime``; nothing is returned.

        The algorithm avoids relying on explicit knowledge on loops and
        distinguish between variables that are defined locally vs variables that
        come from incoming blocks.
        We start with simple usage (variable reference) and definition (variable
        creation) maps on each block. Propagate the liveness info to predecessor
        blocks until it stabilize, at which point we know which variables must
        exist before entering each block. Then, we compute the end of variable
        lives and insert del statements accordingly. Variables are deleted after
        the last use. Variable referenced by terminators (e.g. conditional
        branch and return) are deleted by the successors or the caller.
        """
        vlt = self.func_ir.variable_lifetime
        self._patch_var_dels(
            vlt.deadmaps.internal,
            vlt.deadmaps.escaping,
            extend_lifetimes=extend_lifetimes,
        )

    def _patch_var_dels(
        self, internal_dead_map, escaping_dead_map, extend_lifetimes=False
    ):
        """
        Insert delete in each block.

        Parameters
        ----------
        internal_dead_map
            Maps block offset to the names that die inside that block.
        escaping_dead_map
            Maps block offset to the names deleted at the start of that block.
        extend_lifetimes
            When true, the internal dels are gathered and emitted just ahead
            of the block terminator rather than after each last use.
        """
        for offset, ir_block in self.func_ir.blocks.items():
            # for each internal var, insert delete after the last use
            internal_dead_set = internal_dead_map[offset].copy()
            delete_pts = []
            # for each statement in reverse order (terminator excluded)
            for stmt in reversed(ir_block.body[:-1]):
                # internal vars that are used here
                live_set = {v.name for v in stmt.list_vars()}
                dead_set = live_set & internal_dead_set
                for T, def_func in ir_extension_insert_dels.items():
                    if isinstance(stmt, T):
                        done_dels = def_func(stmt, dead_set)
                        dead_set -= done_dels
                        internal_dead_set -= done_dels
                # used here but not afterwards
                delete_pts.append((stmt, dead_set))
                internal_dead_set -= dead_set

            # rewrite body and insert dels
            body = []
            del_store = []
            for stmt, delete_set in reversed(delete_pts):
                # If using extended lifetimes then the Dels are all put at the
                # block end just ahead of the terminator, so associate their
                # location with the terminator.
                if extend_lifetimes:
                    lastloc = ir_block.body[-1].loc
                else:
                    lastloc = stmt.loc
                # Ignore dels (assuming no user inserted deletes)
                if not isinstance(stmt, ir.Del):
                    body.append(stmt)
                # note: the reverse sort is not necessary for correctness
                # it is just to minimize changes to test for now
                for var_name in sorted(delete_set, reverse=True):
                    delnode = ir.Del(var_name, loc=lastloc)
                    if extend_lifetimes:
                        del_store.append(delnode)
                    else:
                        body.append(delnode)
            if extend_lifetimes:
                body.extend(del_store)
            body.append(ir_block.body[-1])  # terminator
            ir_block.body = body

            # vars to delete at the start
            escape_dead_set = escaping_dead_map[offset]
            for var_name in sorted(escape_dead_set):
                ir_block.prepend(ir.Del(var_name, loc=ir_block.body[0].loc))

    def remove_dels(self):
        """
        Strips the IR of Del nodes
        """
        ir_utils.remove_dels(self.func_ir.blocks)