Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions numba_cuda/numba/cuda/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
funcdesc,
config,
bytecode,
postproc,
cpu,
)
from numba.core.compiler_lock import global_compiler_lock
Expand All @@ -22,7 +21,7 @@
from numba.cuda import cgutils, typing, lowering, nvvmutils, utils
from numba.cuda.api import get_current_device
from numba.cuda.codegen import ExternalCodeLibrary
from numba.cuda.core import sigutils
from numba.cuda.core import sigutils, postproc
from numba.cuda.cudadrv import nvvm, nvrtc
from numba.cuda.descriptor import cuda_target
from numba.cuda.flags import CUDAFlags
Expand Down
4 changes: 2 additions & 2 deletions numba_cuda/numba/cuda/core/compiler_machinery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
from numba.core import errors, config, transforms
from numba.cuda import utils
from numba.core.tracing import event
from numba.core.postproc import PostProcessor
from numba.cuda.core.postproc import PostProcessor
from numba.cuda.core.ir_utils import enforce_no_dels, legalize_single_scope
import numba.core.event as ev

import numba.core.compiler_machinery as nccm
import numba.cuda.core.compiler_machinery as nccm

# terminal color markup
_termcolor = errors.termcolor()
Expand Down
3 changes: 2 additions & 1 deletion numba_cuda/numba/cuda/core/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import numba
from numba.core.extending import _Intrinsic
from numba.core import types, typing, ir, analysis, postproc, rewrites, config
from numba.core import types, typing, ir, analysis, rewrites, config
from numba.cuda.core import postproc
from numba.core.typing.templates import signature
from numba.core.analysis import (
compute_live_map,
Expand Down
249 changes: 249 additions & 0 deletions numba_cuda/numba/cuda/core/postproc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause
from functools import cached_property
from numba.core import ir, transforms, analysis
from numba.cuda.core import ir_utils


class YieldPoint(object):
def __init__(self, block, inst):
assert isinstance(block, ir.Block)
assert isinstance(inst, ir.Yield)
self.block = block
self.inst = inst
self.live_vars = None
self.weak_live_vars = None


class GeneratorInfo(object):
def __init__(self):
# { index: YieldPoint }
self.yield_points = {}
# Ordered list of variable names
self.state_vars = []

def get_yield_points(self):
"""
Return an iterable of YieldPoint instances.
"""
return self.yield_points.values()


class VariableLifetime(object):
"""
For lazily building information of variable lifetime
"""

def __init__(self, blocks):
self._blocks = blocks

@cached_property
def cfg(self):
return analysis.compute_cfg_from_blocks(self._blocks)

@cached_property
def usedefs(self):
return analysis.compute_use_defs(self._blocks)

@cached_property
def livemap(self):
return analysis.compute_live_map(
self.cfg, self._blocks, self.usedefs.usemap, self.usedefs.defmap
)

@cached_property
def deadmaps(self):
return analysis.compute_dead_maps(
self.cfg, self._blocks, self.livemap, self.usedefs.defmap
)


# other packages that define new nodes add calls for inserting dels
# format: {type:function}
ir_extension_insert_dels = {}


class PostProcessor(object):
"""
A post-processor for Numba IR.
"""

def __init__(self, func_ir):
self.func_ir = func_ir

def run(self, emit_dels: bool = False, extend_lifetimes: bool = False):
"""
Run the following passes over Numba IR:
- canonicalize the CFG
- emit explicit `del` instructions for variables
- compute lifetime of variables
- compute generator info (if function is a generator function)
"""
self.func_ir.blocks = transforms.canonicalize_cfg(self.func_ir.blocks)
vlt = VariableLifetime(self.func_ir.blocks)
self.func_ir.variable_lifetime = vlt

bev = analysis.compute_live_variables(
vlt.cfg,
self.func_ir.blocks,
vlt.usedefs.defmap,
vlt.deadmaps.combined,
)
for offset, ir_block in self.func_ir.blocks.items():
self.func_ir.block_entry_vars[ir_block] = bev[offset]

if self.func_ir.is_generator:
self.func_ir.generator_info = GeneratorInfo()
self._compute_generator_info()
else:
self.func_ir.generator_info = None

# Emit del nodes, do this last as the generator info parsing generates
# and then strips dels as part of its analysis.
if emit_dels:
self._insert_var_dels(extend_lifetimes=extend_lifetimes)

def _populate_generator_info(self):
"""
Fill `index` for the Yield instruction and create YieldPoints.
"""
dct = self.func_ir.generator_info.yield_points
assert not dct, "rerunning _populate_generator_info"
for block in self.func_ir.blocks.values():
for inst in block.body:
if isinstance(inst, ir.Assign):
yieldinst = inst.value
if isinstance(yieldinst, ir.Yield):
index = len(dct) + 1
yieldinst.index = index
yp = YieldPoint(block, yieldinst)
dct[yieldinst.index] = yp

def _compute_generator_info(self):
"""
Compute the generator's state variables as the union of live variables
at all yield points.
"""
# generate del info, it's used in analysis here, strip it out at the end
self._insert_var_dels()
self._populate_generator_info()
gi = self.func_ir.generator_info
for yp in gi.get_yield_points():
live_vars = set(self.func_ir.get_block_entry_vars(yp.block))
weak_live_vars = set()
stmts = iter(yp.block.body)
for stmt in stmts:
if isinstance(stmt, ir.Assign):
if stmt.value is yp.inst:
break
live_vars.add(stmt.target.name)
elif isinstance(stmt, ir.Del):
live_vars.remove(stmt.value)
else:
assert 0, "couldn't find yield point"
# Try to optimize out any live vars that are deleted immediately
# after the yield point.
for stmt in stmts:
if isinstance(stmt, ir.Del):
name = stmt.value
if name in live_vars:
live_vars.remove(name)
weak_live_vars.add(name)
else:
break
yp.live_vars = live_vars
yp.weak_live_vars = weak_live_vars

st = set()
for yp in gi.get_yield_points():
st |= yp.live_vars
st |= yp.weak_live_vars
gi.state_vars = sorted(st)
self.remove_dels()

def _insert_var_dels(self, extend_lifetimes=False):
"""
Insert del statements for each variable.
Returns a 2-tuple of (variable definition map, variable deletion map)
which indicates variables defined and deleted in each block.

The algorithm avoids relying on explicit knowledge on loops and
distinguish between variables that are defined locally vs variables that
come from incoming blocks.
We start with simple usage (variable reference) and definition (variable
creation) maps on each block. Propagate the liveness info to predecessor
blocks until it stabilize, at which point we know which variables must
exist before entering each block. Then, we compute the end of variable
lives and insert del statements accordingly. Variables are deleted after
the last use. Variable referenced by terminators (e.g. conditional
branch and return) are deleted by the successors or the caller.
"""
vlt = self.func_ir.variable_lifetime
self._patch_var_dels(
vlt.deadmaps.internal,
vlt.deadmaps.escaping,
extend_lifetimes=extend_lifetimes,
)

def _patch_var_dels(
self, internal_dead_map, escaping_dead_map, extend_lifetimes=False
):
"""
Insert delete in each block
"""
for offset, ir_block in self.func_ir.blocks.items():
# for each internal var, insert delete after the last use
internal_dead_set = internal_dead_map[offset].copy()
delete_pts = []
# for each statement in reverse order
for stmt in reversed(ir_block.body[:-1]):
# internal vars that are used here
live_set = set(v.name for v in stmt.list_vars())
dead_set = live_set & internal_dead_set
for T, def_func in ir_extension_insert_dels.items():
if isinstance(stmt, T):
done_dels = def_func(stmt, dead_set)
dead_set -= done_dels
internal_dead_set -= done_dels
# used here but not afterwards
delete_pts.append((stmt, dead_set))
internal_dead_set -= dead_set

# rewrite body and insert dels
body = []
lastloc = ir_block.loc
del_store = []
for stmt, delete_set in reversed(delete_pts):
# If using extended lifetimes then the Dels are all put at the
# block end just ahead of the terminator, so associate their
# location with the terminator.
if extend_lifetimes:
lastloc = ir_block.body[-1].loc
else:
lastloc = stmt.loc
# Ignore dels (assuming no user inserted deletes)
if not isinstance(stmt, ir.Del):
body.append(stmt)
# note: the reverse sort is not necessary for correctness
# it is just to minimize changes to test for now
for var_name in sorted(delete_set, reverse=True):
delnode = ir.Del(var_name, loc=lastloc)
if extend_lifetimes:
del_store.append(delnode)
else:
body.append(delnode)
if extend_lifetimes:
body.extend(del_store)
body.append(ir_block.body[-1]) # terminator
ir_block.body = body

# vars to delete at the start
escape_dead_set = escaping_dead_map[offset]
for var_name in sorted(escape_dead_set):
ir_block.prepend(ir.Del(var_name, loc=ir_block.body[0].loc))

def remove_dels(self):
"""
Strips the IR of Del nodes
"""
ir_utils.remove_dels(self.func_ir.blocks)
2 changes: 1 addition & 1 deletion numba_cuda/numba/cuda/core/typed_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
is_operator_or_getitem,
replace_vars,
)
from numba.core import postproc
from numba.cuda.core import postproc
from llvmlite import binding as llvm


Expand Down
3 changes: 1 addition & 2 deletions numba_cuda/numba/cuda/core/untyped_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
SSACompliantMixin,
register_pass,
)
from numba.cuda.core import postproc, bytecode
from numba.core import (
errors,
types,
ir,
bytecode,
postproc,
rewrites,
config,
transforms,
Expand Down
Loading