diff --git a/numba_cuda/numba/cuda/core/caching.py b/numba_cuda/numba/cuda/core/caching.py index 5bdf70d2b..ed3ea1eb7 100644 --- a/numba_cuda/numba/cuda/core/caching.py +++ b/numba_cuda/numba/cuda/core/caching.py @@ -14,7 +14,7 @@ import tempfile import sys -from numba.misc.appdirs import AppDirs +from numba.cuda.misc.appdirs import AppDirs from pathlib import Path from numba.cuda.core import config diff --git a/numba_cuda/numba/cuda/core/codegen.py b/numba_cuda/numba/cuda/core/codegen.py index 025b93625..b96210af6 100644 --- a/numba_cuda/numba/cuda/core/codegen.py +++ b/numba_cuda/numba/cuda/core/codegen.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-2-Clause from abc import abstractmethod, ABCMeta -from numba.misc.llvm_pass_timings import PassTimingsCollection +from numba.cuda.misc.llvm_pass_timings import PassTimingsCollection class CodeLibrary(metaclass=ABCMeta): diff --git a/numba_cuda/numba/cuda/core/inline_closurecall.py b/numba_cuda/numba/cuda/core/inline_closurecall.py index 215396198..8005b0d39 100644 --- a/numba_cuda/numba/cuda/core/inline_closurecall.py +++ b/numba_cuda/numba/cuda/core/inline_closurecall.py @@ -41,7 +41,7 @@ from numba.np.unsafe.ndarray import empty_inferred as unsafe_empty_inferred import numpy as np import operator -import numba.misc.special +from numba.cuda.misc.special import prange """ Variable enable_inline_arraycall is only used for testing purpose. 
@@ -1054,10 +1054,7 @@ def _find_iter_range(func_ir, range_iter_var, swapped): debug_print("func_var = ", func_var, " func_def = ", func_def) require( isinstance(func_def, ir.Global) - and ( - func_def.value is range - or func_def.value == numba.misc.special.prange - ) + and (func_def.value is range or func_def.value == prange) ) nargs = len(range_def.args) swapping = [('"array comprehension"', "closure of"), range_def.func.loc] diff --git a/numba_cuda/numba/cuda/core/typeinfer.py b/numba_cuda/numba/cuda/core/typeinfer.py index f35b75598..38aea47e8 100644 --- a/numba_cuda/numba/cuda/core/typeinfer.py +++ b/numba_cuda/numba/cuda/core/typeinfer.py @@ -1692,7 +1692,7 @@ def typeof_global(self, inst, target, gvar): # as a global variable typ = types.Dispatcher(_temporary_dispatcher_map[gvar.name]) else: - from numba.misc import special + from numba.cuda.misc import special nm = gvar.name # check if the problem is actually a name error diff --git a/numba_cuda/numba/cuda/core/untyped_passes.py b/numba_cuda/numba/cuda/core/untyped_passes.py index 53fe7d1fb..e4dc35ff6 100644 --- a/numba_cuda/numba/cuda/core/untyped_passes.py +++ b/numba_cuda/numba/cuda/core/untyped_passes.py @@ -22,7 +22,7 @@ from numba.cuda.core.interpreter import Interpreter -from numba.misc.special import literal_unroll +from numba.cuda.misc.special import literal_unroll from numba.cuda.core.analysis import dead_branch_prune from numba.core.analysis import ( rewrite_semantic_constants, diff --git a/numba_cuda/numba/cuda/lowering.py b/numba_cuda/numba/cuda/lowering.py index e70bd38d2..7edbcbafe 100644 --- a/numba_cuda/numba/cuda/lowering.py +++ b/numba_cuda/numba/cuda/lowering.py @@ -29,13 +29,13 @@ from numba.cuda.core.funcdesc import default_mangler from numba.cuda.core.environment import Environment from numba.core.analysis import compute_use_defs, must_use_alloca -from numba.misc.firstlinefinder import get_func_body_first_lineno +from numba.cuda.misc.firstlinefinder import get_func_body_first_lineno 
from numba import version_info numba_version = version_info.short del version_info if numba_version > (0, 60): - from numba.misc.coverage_support import get_registered_loc_notify + from numba.cuda.misc.coverage_support import get_registered_loc_notify _VarArgItem = namedtuple("_VarArgItem", ("vararg", "index")) diff --git a/numba_cuda/numba/cuda/misc/appdirs.py b/numba_cuda/numba/cuda/misc/appdirs.py new file mode 100644 index 000000000..78a660702 --- /dev/null +++ b/numba_cuda/numba/cuda/misc/appdirs.py @@ -0,0 +1,594 @@ +# SPDX-FileCopyrightText: Copyright (c) 2005-2010 ActiveState Software Inc. +# SPDX-FileCopyrightText: Copyright (c) 2013 Eddy Petrișor +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: The MIT License (MIT) + +"""Utilities for determining application-specific dirs. + +See for details and usage. +""" +# Dev Notes: +# - MSDN on where to store app data files: +# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120 +# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html +# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html + +__version_info__ = (1, 4, 1) +__version__ = ".".join(map(str, __version_info__)) + + +import sys +import os + +unicode = str + +if sys.platform.startswith("java"): + import platform + + os_name = platform.java_ver()[3][0] + if os_name.startswith("Windows"): # "Windows XP", "Windows 7", etc. + system = "win32" + elif os_name.startswith("Mac"): # "Mac OS X", etc. + system = "darwin" + else: # "Linux", "SunOS", "FreeBSD", etc. + # Setting this to "linux2" is not ideal, but only Windows or Mac + # are actually checked for and the rest of the module expects + # *sys.platform* style strings. 
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return the full path of the per-user data directory for an application.

    "appname" is the application name. When it is falsy only the
        platform base directory is returned and "version" is ignored.
    "appauthor" (Windows only) is the author / company directory placed
        above "appname". ``None`` falls back to "appname"; ``False``
        drops that directory level altogether.
    "version" is an optional extra path element appended after
        "appname", typically "<major>.<minor>".
    "roaming" (Windows only) selects ``CSIDL_APPDATA`` instead of
        ``CSIDL_LOCAL_APPDATA`` as the base folder.

    Base directory per platform (as selected by the module-level ``system``):
        win32:   the folder returned by ``_get_win_folder`` (normalised)
        darwin:  ~/Library/Application Support/
        other:   $XDG_DATA_HOME if set, otherwise ~/.local/share
    """
    if system == "win32":
        csidl = "CSIDL_APPDATA" if roaming else "CSIDL_LOCAL_APPDATA"
        base = os.path.normpath(_get_win_folder(csidl))
        if not appname:
            parts = []
        elif appauthor is False:
            # Author level explicitly disabled by the caller.
            parts = [appname]
        else:
            # A missing author falls back to the application name.
            parts = [appname if appauthor is None else appauthor, appname]
    elif system == "darwin":
        base = os.path.expanduser("~/Library/Application Support/")
        parts = [appname] if appname else []
    else:
        base = os.getenv("XDG_DATA_HOME", os.path.expanduser("~/.local/share"))
        parts = [appname] if appname else []

    # The version element is only meaningful below an application directory.
    if appname and version:
        parts.append(version)
    return os.path.join(base, *parts)
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
    r"""Return the full path of the per-user config directory for an application.

    "appname" is the application name. When it is falsy only the
        platform base directory is returned and "version" is ignored.
    "appauthor" (Windows only) is forwarded to ``user_data_dir``.
    "version" is an optional extra path element appended after
        "appname", typically "<major>.<minor>".
    "roaming" (Windows only) is forwarded to ``user_data_dir``.

    On win32 and darwin this is ``user_data_dir`` (called without a version;
    the version is appended here instead). Elsewhere the base is
    $XDG_CONFIG_HOME if set, otherwise ~/.config.
    """
    if system == "win32" or system == "darwin":
        config_root = user_data_dir(appname, appauthor, None, roaming)
    else:
        xdg_default = os.path.expanduser("~/.config")
        config_root = os.getenv("XDG_CONFIG_HOME", xdg_default)
        if appname:
            config_root = os.path.join(config_root, appname)

    # The version element is only meaningful below an application directory.
    if not (appname and version):
        return config_root
    return os.path.join(config_root, version)
+ + For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False + + WARNING: Do not use this on Windows. See the Vista-Fail note above for why. + """ + if system in ["win32", "darwin"]: + path = site_data_dir(appname, appauthor) + if appname and version: + path = os.path.join(path, version) + else: + # XDG default for $XDG_CONFIG_DIRS + # only first, if multipath is False + path = os.getenv("XDG_CONFIG_DIRS", "/etc/xdg") + pathlist = [ + os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep) + ] + if appname: + if version: + appname = os.path.join(appname, version) + pathlist = [os.sep.join([x, appname]) for x in pathlist] + + if multipath: + path = os.pathsep.join(pathlist) + else: + path = pathlist[0] + return path + + +def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True): + r"""Return full path to the user-specific cache dir for this application. + + "appname" is the name of application. + If None, just the system directory is returned. + "appauthor" (only used on Windows) is the name of the + appauthor or distributing body for this application. Typically + it is the owning company name. This falls back to appname. You may + pass False to disable it. + "version" is an optional version path element to append to the + path. You might want to use this if you want multiple versions + of your app to be able to run independently. If used, this + would typically be ".". + Only applied when appname is present. + "opinion" (boolean) can be False to disable the appending of + "Cache" to the base app data dir for Windows. See + discussion below. + + Typical user cache directories are: + Mac OS X: ~/Library/Caches/ + Unix: ~/.cache/ (XDG default) + Win XP: C:\Documents and Settings\\Local Settings\Application Data\\\Cache + Vista: C:\Users\\AppData\Local\\\Cache + + On Windows the only suggestion in the MSDN docs is that local settings go in + the `CSIDL_LOCAL_APPDATA` directory. 
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return full path to the user-specific log dir for this application.

    "appname" is the name of application.
        If None, just the system directory is returned.
    "appauthor" (only used on Windows) is the name of the
        appauthor or distributing body for this application. Typically
        it is the owning company name. This falls back to appname. You may
        pass False to disable it.
    "version" is an optional version path element to append to the
        path. You might want to use this if you want multiple versions
        of your app to be able to run independently. If used, this
        would typically be "<major>.<minor>".
        Only applied when appname is present.
    "opinion" (boolean) can be False to disable the appending of
        "Logs" to the base app data dir for Windows, and "log" to the
        base cache dir for Unix.

    Typical user log directories are:
        Mac OS X:   ~/Library/Logs/<AppName>
        Unix:       ~/.cache/<AppName>/log  # or under $XDG_CACHE_HOME if defined
        Windows:    <user_data_dir(appname, appauthor, version)>\Logs

    OPINION: This function appends "Logs" to the `user_data_dir` value on
    Windows and "log" to the `user_cache_dir` value on Unix. This can be
    disabled with the `opinion=False` option.
    """
    if system == "darwin":
        path = os.path.expanduser("~/Library/Logs")
        # Guard the join: appname defaults to None, and joining None raises
        # TypeError. Like the sibling functions, return the bare system
        # directory when no application name is given.
        if appname:
            path = os.path.join(path, appname)
    elif system == "win32":
        path = user_data_dir(appname, appauthor, version)
        # The delegate already appended the version; do not append it again.
        version = False
        if opinion:
            path = os.path.join(path, "Logs")
    else:
        path = user_cache_dir(appname, appauthor, version)
        # The delegate already appended the version; do not append it again.
        version = False
        if opinion:
            path = os.path.join(path, "log")
    if appname and version:
        path = os.path.join(path, version)
    return path
def _get_win_folder_from_registry(csidl_name):
    """Look up a Windows shell folder path in the registry.

    This is a fallback technique at best. I'm not sure if using the
    registry for this guarantees us the correct answer for all CSIDL_*
    names.

    "csidl_name" must be one of "CSIDL_APPDATA", "CSIDL_COMMON_APPDATA"
    or "CSIDL_LOCAL_APPDATA"; any other value raises KeyError.

    Returns the value stored under the matching name in the current
    user's "Shell Folders" key.
    """
    import winreg as _winreg

    shell_folder_name = {
        "CSIDL_APPDATA": "AppData",
        "CSIDL_COMMON_APPDATA": "Common AppData",
        "CSIDL_LOCAL_APPDATA": "Local AppData",
    }[csidl_name]

    # Registry handles support the context-manager protocol; using it closes
    # the key deterministically instead of leaking it until collection.
    with _winreg.OpenKey(
        _winreg.HKEY_CURRENT_USER,
        r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
    ) as key:
        folder, _value_type = _winreg.QueryValueEx(key, shell_folder_name)
    return folder
def _get_win_folder_with_jna(csidl_name):
    """Resolve a CSIDL folder through the JNA ``Shell32`` binding.

    "csidl_name" is the name of an attribute on ``win32.ShlObj``.

    If the resulting path contains a character above code point 255, the
    short path name from ``Kernel32.GetShortPathName`` is returned instead
    when that call reports success.

    NOTE(review): ``array.zeros`` and ``tostring()`` are not part of
    CPython 3's ``array`` module; presumably this path only runs under
    Jython -- confirm.
    """
    import array

    buf_size = win32.WinDef.MAX_PATH * 2

    out_buf = array.zeros("c", buf_size)
    win32.Shell32.INSTANCE.SHGetFolderPath(
        None,
        getattr(win32.ShlObj, csidl_name),
        None,
        win32.ShlObj.SHGFP_TYPE_CURRENT,
        out_buf,
    )
    folder = jna.Native.toString(out_buf.tostring()).rstrip("\0")

    # Downgrade to the short path name if the path has high-bit chars.
    if any(ord(ch) > 255 for ch in folder):
        short_buf = array.zeros("c", buf_size)
        if win32.Kernel32.INSTANCE.GetShortPathName(folder, short_buf, buf_size):
            folder = jna.Native.toString(short_buf.tostring()).rstrip("\0")

    return folder
@registry.lower("ffi.from_buffer", types.Buffer)
def from_buffer(context, builder, sig, args):
    """Lower ``ffi.from_buffer`` for a single buffer argument.

    Builds the array structure for the argument and returns its ``data``
    member. Exactly one argument is expected, and its dtype must equal
    the dtype of the signature's return type.
    """
    assert len(sig.args) == 1
    assert len(args) == 1
    (buffer_type,) = sig.args
    (buffer_value,) = args
    # Type inference should have prevented passing a buffer from an
    # array to a pointer of the wrong type
    assert buffer_type.dtype == sig.return_type.dtype
    array_cls = arrayobj.make_array(buffer_type)
    array_struct = array_cls(context, builder, buffer_value)
    return array_struct.data
# Registered notifier factories. Each entry is called with no arguments and
# returns a ``NotifyLocBase`` instance, or ``None``, which is skipped.
# The registry is a *list* of such factories, so it is annotated as one; the
# annotation is quoted so it is not evaluated at import time.
_the_registry: "list[Callable[[], Optional[NotifyLocBase]]]" = []


def get_registered_loc_notify() -> Sequence["NotifyLocBase"]:
    """
    Returns a list of the registered NotifyLocBase instances.

    Every factory in the registry is called once per invocation; factories
    that return ``None`` contribute nothing. An empty list is returned
    without calling any factory when ``config.JIT_COVERAGE`` is falsy.
    """
    if not config.JIT_COVERAGE:
        # Coverage disabled.
        return []
    candidates = (factory() for factory in _the_registry)
    return [notifier for notifier in candidates if notifier is not None]
class FindDefFirstLine(ast.NodeVisitor):
    """
    AST visitor that searches for the definition matching a code object and
    records the line number of its first body statement.

    Attributes
    ----------
    first_stmt_line : int or None
        This stores the first statement line number if the definition is found.
        Or, ``None`` if the definition is not found.
    """

    def __init__(self, code):
        """
        Parameters
        ----------
        code :
            The function's code object.
        """
        self._co_name = code.co_name
        self._co_firstlineno = code.co_firstlineno
        self.first_stmt_line = None

    def _visit_children(self, node):
        """Dispatch the visitor on every direct child of *node*."""
        for child in ast.iter_child_nodes(node):
            super().visit(child)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """
        Record the first body line if *node* is the searched definition;
        otherwise keep searching inside it.
        """
        if node.name == self._co_name:
            # Name of function matches.

            # The `def` line may match co_firstlineno.
            possible_start_lines = {node.lineno}
            if node.decorator_list:
                # Has decorators.
                # The first decorator line may match co_firstlineno.
                possible_start_lines.add(node.decorator_list[0].lineno)
            # Does the first lineno match? A function body cannot be bare,
            # so the `node.body` check is only defensive.
            if self._co_firstlineno in possible_start_lines and node.body:
                # Yes, we found the function.
                # So, use the first statement line as the first line.
                first_stmt = node.body[0]
                if _is_docstring(first_stmt) and len(node.body) > 1:
                    # Skip the docstring, but only when a statement follows
                    # it. A docstring-only body has no second statement, so
                    # the docstring line itself is reported.
                    first_stmt = node.body[1]
                self.first_stmt_line = first_stmt.lineno
                return
        self._visit_children(node)


def _is_docstring(node):
    """Return True if *node* is an expression statement holding a str constant."""
    return (
        isinstance(node, ast.Expr)
        and isinstance(node.value, ast.Constant)
        and isinstance(node.value.value, str)
    )


def get_func_body_first_lineno(pyfunc):
    """
    Look up the first line of function body using the file in
    ``pyfunc.__code__.co_filename``.

    The file is opened with ``tokenize.open`` so that it is decoded with
    the encoding declared by the source file itself (UTF-8 by default)
    instead of the locale's preferred encoding.

    Returns
    -------
    lineno : int; or None
        The first line number of the function body; or ``None`` if the first
        line cannot be determined (for example, the file cannot be opened or
        the definition is not found in it).
    """
    import tokenize

    co = pyfunc.__code__
    try:
        with tokenize.open(co.co_filename) as fin:
            file_content = fin.read()
    except OSError:
        # Covers FileNotFoundError as well, which is a subclass of OSError.
        return None
    tree = ast.parse(file_content)
    finder = FindDefFirstLine(co)
    finder.visit(tree)
    return finder.first_stmt_line
The gdb" + " binary location can be set using Numba configuration, see: " + "https://numba.readthedocs.io/en/stable/reference/envvars.html" # noqa: E501 + ) + raise RuntimeError(msg % config.GDB_BINARY) + # Is Yama being used as a kernel security module and if so is ptrace_scope + # limited? In this case ptracing non-child processes requires special + # permission so raise an exception. + ptrace_scope_file = os.path.join( + os.sep, "proc", "sys", "kernel", "yama", "ptrace_scope" + ) + has_ptrace_scope = os.path.exists(ptrace_scope_file) + if has_ptrace_scope: + with open(ptrace_scope_file, "rt") as f: + value = f.readline().strip() + if need_ptrace_attach and value not in ("0", "1"): + msg = ( + "gdb can launch but cannot attach to the executing program" + " because ptrace permissions have been restricted at the " + "system level by the Linux security module 'Yama'.\n\n" + "Documentation for this module and the security " + "implications of making changes to its behaviour can be " + "found in the Linux Kernel documentation " + "https://www.kernel.org/doc/Documentation/admin-guide/LSM/Yama.rst" # noqa: E501 + "\n\nDocumentation on how to adjust the behaviour of Yama " + "on Ubuntu Linux with regards to 'ptrace_scope' can be " + "found here " + "https://wiki.ubuntu.com/Security/Features#ptrace." 
def init_gdb_codegen(
    cgctx, builder, signature, args, const_args, do_break=False
):
    """Emit the LLVM IR that launches gdb and attaches it to this process.

    The generated code formats the current PID into a stack buffer with
    ``snprintf`` and then calls ``fork``. The child calls ``execl`` on
    ``config.GDB_BINARY`` with ``attach <pid> -x <cmdlang.gdb> -ex c``
    followed by the user supplied arguments. The parent calls ``sleep(10)``
    and, if requested, ``numba_gdb_breakpoint``.

    Parameters
    ----------
    cgctx :
        Code generation context; used for ``insert_const_string`` and for
        ``call_conv.return_user_exc``.
    builder :
        IR builder into which the instructions are emitted.
    signature, args :
        Not used by this function.
    const_args :
        Extra gdb command line arguments. Every item must be a
        ``types.StringLiteral``.
    do_break : bool
        Only when this is exactly ``True`` is a call to
        ``numba_gdb_breakpoint`` emitted in the parent branch.

    Raises
    ------
    errors.RequireLiteralValue
        If any item of ``const_args`` is not a ``types.StringLiteral``.
    """
    int8_t = ir.IntType(8)
    int32_t = ir.IntType(32)
    intp_t = ir.IntType(utils.MACHINE_BITS)
    char_ptr = ir.PointerType(ir.IntType(8))
    zero_i32t = int32_t(0)

    mod = builder.module
    pid = cgutils.alloca_once(builder, int32_t, size=1)

    # 32bit pid, 11 char max + terminator
    pidstr_len = 12
    pidstr = cgutils.alloca_once(builder, int8_t, size=pidstr_len)

    # str consts
    intfmt = cgctx.insert_const_string(mod, "%d")
    gdb_str = cgctx.insert_const_string(mod, config.GDB_BINARY)
    attach_str = cgctx.insert_const_string(mod, "attach")

    new_args = []
    # add break point command to known location
    # this command file thing is due to commands attached to a breakpoint
    # requiring an interactive prompt
    # https://sourceware.org/bugzilla/show_bug.cgi?id=10079
    new_args.extend(["-x", os.path.join(_path, "cmdlang.gdb")])
    # issue command to continue execution from sleep function
    new_args.extend(["-ex", "c"])
    # then run the user defined args if any
    if any([not isinstance(x, types.StringLiteral) for x in const_args]):
        raise errors.RequireLiteralValue(const_args)
    new_args.extend([x.literal_value for x in const_args])
    cmdlang = [cgctx.insert_const_string(mod, x) for x in new_args]

    # insert getpid, getpid is always successful, call without concern!
    fnty = ir.FunctionType(int32_t, tuple())
    getpid = cgutils.get_or_insert_function(mod, fnty, "getpid")

    # insert snprintf
    # int snprintf(char *str, size_t size, const char *format, ...);
    fnty = ir.FunctionType(int32_t, (char_ptr, intp_t, char_ptr), var_arg=True)
    snprintf = cgutils.get_or_insert_function(mod, fnty, "snprintf")

    # insert fork
    fnty = ir.FunctionType(int32_t, tuple())
    fork = cgutils.get_or_insert_function(mod, fnty, "fork")

    # insert execl
    fnty = ir.FunctionType(int32_t, (char_ptr, char_ptr), var_arg=True)
    execl = cgutils.get_or_insert_function(mod, fnty, "execl")

    # insert sleep
    fnty = ir.FunctionType(int32_t, (int32_t,))
    sleep = cgutils.get_or_insert_function(mod, fnty, "sleep")

    # insert break point
    fnty = ir.FunctionType(ir.VoidType(), tuple())
    breakpoint = cgutils.get_or_insert_function(
        mod, fnty, "numba_gdb_breakpoint"
    )

    # do the work
    parent_pid = builder.call(getpid, tuple())
    builder.store(parent_pid, pid)
    pidstr_ptr = builder.gep(pidstr, [zero_i32t], inbounds=True)
    pid_val = builder.load(pid)

    # call snprintf to write the pid into a char *
    stat = builder.call(
        snprintf, (pidstr_ptr, intp_t(pidstr_len), intfmt, pid_val)
    )
    # snprintf returns the length the output would have had, excluding the
    # terminator, so the output was truncated whenever that value is greater
    # than OR EQUAL to the buffer size.
    invalid_write = builder.icmp_signed(">=", stat, int32_t(pidstr_len))
    with builder.if_then(invalid_write, likely=False):
        msg = "Internal error: `snprintf` buffer would have overflowed."
        cgctx.call_conv.return_user_exc(builder, RuntimeError, (msg,))

    # fork, check pids etc
    child_pid = builder.call(fork, tuple())
    fork_failed = builder.icmp_signed("==", child_pid, int32_t(-1))
    with builder.if_then(fork_failed, likely=False):
        msg = "Internal error: `fork` failed."
        cgctx.call_conv.return_user_exc(builder, RuntimeError, (msg,))

    is_child = builder.icmp_signed("==", child_pid, zero_i32t)
    with builder.if_else(is_child) as (then, orelse):
        with then:
            # is child
            nullptr = ir.Constant(char_ptr, None)
            gdb_str_ptr = builder.gep(gdb_str, [zero_i32t], inbounds=True)
            attach_str_ptr = builder.gep(attach_str, [zero_i32t], inbounds=True)
            cgutils.printf(builder, "Attaching to PID: %s\n", pidstr)
            # execl(path, argv0, "attach", <pid>, <command args>..., NULL)
            buf = (gdb_str_ptr, gdb_str_ptr, attach_str_ptr, pidstr_ptr)
            buf = buf + tuple(cmdlang) + (nullptr,)
            builder.call(execl, buf)
        with orelse:
            # is parent
            builder.call(sleep, (int32_t(10),))
            # if breaking is desired, break now
            if do_break is True:
                builder.call(breakpoint, tuple())
# [diff] b/numba_cuda/numba/cuda/misc/literal.py
# [diff] @@ -0,0 +1,28 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

from numba.core.extending import overload
from numba.core import types
from numba.cuda.misc.special import literally, literal_unroll
from numba.core.errors import TypingError


@overload(literally, target="cuda")
def _ov_literally(obj):
    """CUDA-target overload of ``literally``.

    Accepts only ``types.Literal`` or ``types.InitialValue`` arguments and
    implements them as an identity function; any other type raises
    ``TypingError``.
    """
    if not isinstance(obj, (types.Literal, types.InitialValue)):
        raise TypingError(
            "Invalid use of non-Literal type in literally({})".format(obj)
        )

    def impl(obj):
        return obj

    return impl


@overload(literal_unroll, target="cuda")
def literal_unroll_impl(container):
    """CUDA-target overload of ``literal_unroll``.

    Rejects ``types.Poison`` with ``TypingError``; every other container type
    is implemented as an identity function.
    """
    if isinstance(container, types.Poison):
        raise TypingError(
            f"Invalid use of non-Literal type in literal_unroll({container})"
        )

    return lambda container: container


# [diff] diff --git a/numba_cuda/numba/cuda/misc/llvm_pass_timings.py b/numba_cuda/numba/cuda/misc/llvm_pass_timings.py
# [diff] new file mode 100644
# [diff] index 000000000..77d9ef3d7
# [diff] --- /dev/null
# [diff] +++ b/numba_cuda/numba/cuda/misc/llvm_pass_timings.py
# [diff] @@ -0,0 +1,412 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

import re
import operator
import heapq
from collections import namedtuple
from collections.abc import Sequence
from contextlib import contextmanager
from functools import cached_property

from numba.core import config

import llvmlite.binding as llvm


class RecordLLVMPassTimings:
    """A helper context manager to track LLVM pass timings."""

    __slots__ = ["_data"]

    def __enter__(self):
        """Enables the pass timing in LLVM."""
        llvm.set_time_passes(True)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Reset timings and save report internally."""
        self._data = llvm.report_and_reset_timings()
        llvm.set_time_passes(False)
        return

    def get(self):
        """Retrieve timing data for processing.

        The report is stored by ``__exit__``, so this is only usable after
        the ``with`` block has finished.

        Returns
        -------
        timings: ProcessedPassTimings
        """
        return ProcessedPassTimings(self._data)


# One row of the LLVM timing report. Times and percentages are floats;
# ``pass_name`` is the trailing text of the row.
PassTimingRecord = namedtuple(
    "PassTimingRecord",
    [
        "user_time",
        "user_percent",
        "system_time",
        "system_percent",
        "user_system_time",
        "user_system_percent",
        "wall_time",
        "wall_percent",
        "pass_name",
        "instruction",
    ],
)


def _adjust_timings(records):
    """Adjust timing records because of truncated information.

    Details: The percent information can be used to improve the timing
    information. Each ``*_time`` field is recomputed as
    ``total_time * percent * 0.01`` using the last record, which must be the
    "Total" row.

    Returns
    -------
    res: List[PassTimingRecord]
    """
    total_rec = records[-1]
    assert total_rec.pass_name == "Total"  # guard for implementation error

    def make_adjuster(attr):
        time_attr = f"{attr}_time"
        percent_attr = f"{attr}_percent"
        time_getter = operator.attrgetter(time_attr)

        def adjust(d):
            """Compute percent x total_time = adjusted"""
            total = time_getter(total_rec)
            adjusted = total * d[percent_attr] * 0.01
            d[time_attr] = adjusted
            return d

        return adjust

    # Make adjustment functions for each field
    adj_fns = [
        make_adjuster(x) for x in ["user", "system", "user_system", "wall"]
    ]

    # Extract dictionaries from the namedtuples
    dicts = map(lambda x: x._asdict(), records)

    def chained(d):
        # Chain the adjustment functions
        for fn in adj_fns:
            d = fn(d)
        # Reconstruct the namedtuple
        return PassTimingRecord(**d)

    return list(map(chained, dicts))


class ProcessedPassTimings:
    """A class for processing raw timing report from LLVM.

    The processing is done lazily so we don't waste time processing unused
    timing information.
    """

    def __init__(self, raw_data):
        self._raw_data = raw_data

    def __bool__(self):
        return bool(self._raw_data)

    def get_raw_data(self):
        """Returns the raw string data.

        Returns
        -------
        res: str
        """
        return self._raw_data

    def get_total_time(self):
        """Compute the total time spent in all passes.

        This is the wall time of the last record (the "Total" row).

        Returns
        -------
        res: float
        """
        return self.list_records()[-1].wall_time

    def list_records(self):
        """Get the processed data for the timing report.

        Returns
        -------
        res: List[PassTimingRecord]
        """
        return self._processed

    def list_top(self, n):
        """Returns the top(n) most time-consuming (by wall-time) passes.

        The final ("Total") record is excluded from the ranking.

        Parameters
        ----------
        n: int
            This limits the maximum number of items to show.
            This function will show the ``n`` most time-consuming passes.

        Returns
        -------
        res: List[PassTimingRecord]
            Returns the top(n) most time-consuming passes in descending order.
        """
        records = self.list_records()
        key = operator.attrgetter("wall_time")
        return heapq.nlargest(n, records[:-1], key)

    def summary(self, topn=5, indent=0):
        """Return a string summarizing the timing information.

        Parameters
        ----------
        topn: int; optional
            This limits the maximum number of items to show.
            This function will show the ``topn`` most time-consuming passes.
        indent: int; optional
            Set the indentation level. Defaults to 0 for no indentation.

        Returns
        -------
        res: str
        """
        buf = []
        prefix = " " * indent

        def ap(arg):
            buf.append(f"{prefix}{arg}")

        ap(f"Total {self.get_total_time():.4f}s")
        ap("Top timings:")
        for p in self.list_top(topn):
            ap(f" {p.wall_time:.4f}s ({p.wall_percent:5}%) {p.pass_name}")
        return "\n".join(buf)

    @cached_property
    def _processed(self):
        """A cached property for lazily processing the data and returning it.

        See ``_process()`` for details.
        """
        return self._process()

    def _process(self):
        """Parses the raw string data from LLVM timing report and attempts
        to improve the data by recomputing the times
        (See ``_adjust_timings()``).

        Raises
        ------
        ValueError
            If the report has no column-header line, or if text remains
            after the "Total" row.
        """

        def parse(raw_data):
            """A generator that parses the raw_data line-by-line to extract
            timing information for each pass.
            """
            lines = raw_data.splitlines()
            colheader = r"[a-zA-Z+ ]+"
            # Take at least one column header.
            multicolheaders = rf"(?:\s*-+{colheader}-+)+"

            line_iter = iter(lines)
            # find column headers
            header_map = {
                "User Time": "user",
                "System Time": "system",
                "User+System": "user_system",
                "Wall Time": "wall",
                "Instr": "instruction",
                "Name": "pass_name",
            }
            headers = None
            for ln in line_iter:
                m = re.match(multicolheaders, ln)
                if m:
                    # Get all the column headers
                    raw_headers = re.findall(r"[a-zA-Z][a-zA-Z+ ]+", ln)
                    headers = [header_map[k.strip()] for k in raw_headers]
                    break

            if headers is None:
                # Without this guard the code below would fail on an unbound
                # local instead of reporting the malformed input.
                raise ValueError(
                    "no column-header line found in LLVM timing report"
                )

            assert headers[-1] == "pass_name"
            # compute the list of available attributes from the column headers
            attrs = []
            n = r"\s*((?:[0-9]+\.)?[0-9]+)"
            pat = ""
            for k in headers[:-1]:
                if k == "instruction":
                    pat += n
                else:
                    attrs.append(f"{k}_time")
                    attrs.append(f"{k}_percent")
                    pat += rf"\s+(?:{n}\s*\({n}%\)|-+)"

            # put default value 0.0 to all missing attributes
            missing = {}
            for k in PassTimingRecord._fields:
                if k not in attrs and k != "pass_name":
                    missing[k] = 0.0
            # parse timings
            pat += r"\s*(.*)"
            for ln in line_iter:
                m = re.match(pat, ln)
                if m is not None:
                    groups = list(m.groups())
                    data = {
                        k: float(v) if v is not None else 0.0
                        for k, v in zip(attrs, groups)
                    }
                    data.update(missing)
                    pass_name = groups[-1]
                    rec = PassTimingRecord(
                        pass_name=pass_name,
                        **data,
                    )
                    yield rec
                    if rec.pass_name == "Total":
                        # "Total" means the report has ended
                        break
            # Check that we have reached the end of the report
            remaining = "\n".join(line_iter)
            if remaining:
                raise ValueError(
                    f"unexpected text after parser finished:\n{remaining}"
                )

        # Parse raw data
        records = list(parse(self._raw_data))
        return _adjust_timings(records)


NamedTimings = namedtuple("NamedTimings", ["name", "timings"])


class PassTimingsCollection(Sequence):
    """A collection of pass timings.

    This class implements the ``Sequence`` protocol for accessing the
    individual timing records.
    """

    def __init__(self, name):
        self._name = name
        self._records = []

    @contextmanager
    def record(self, name):
        """Record new timings and append to this collection.

        Note: this is mainly for internal use inside the compiler pipeline.

        See also ``RecordLLVMPassTimings``

        Parameters
        ----------
        name: str
            Name for the records.
        """
        if config.LLVM_PASS_TIMINGS:
            # Recording of pass timings is enabled
            with RecordLLVMPassTimings() as timings:
                yield
            # The report only exists once the context manager has exited.
            rec = timings.get()
            # Only keep non-empty records
            if rec:
                self._append(name, rec)
        else:
            # Do nothing. Recording of pass timings is disabled.
            yield

    def _append(self, name, timings):
        """Append timing records

        Parameters
        ----------
        name: str
            Name for the records.
        timings: ProcessedPassTimings
            the timing records.
        """
        self._records.append(NamedTimings(name, timings))

    def get_total_time(self):
        """Computes the sum of the total time across all contained timings.

        Returns
        -------
        res: float or None
            Returns the total number of seconds or None if no timings were
            recorded
        """
        if self._records:
            return sum(r.timings.get_total_time() for r in self._records)
        else:
            return None

    def list_longest_first(self):
        """Returns the timings in descending order of total time duration.

        Returns
        -------
        res: List[NamedTimings]
        """
        return sorted(
            self._records,
            key=lambda x: x.timings.get_total_time(),
            reverse=True,
        )

    @property
    def is_empty(self):
        """True when no timings have been recorded."""
        return not self._records

    def summary(self, topn=5):
        """Return a string representing the summary of the timings.

        Parameters
        ----------
        topn: int; optional, default=5.
            This limits the maximum number of items to show.
            This function will show the ``topn`` most time-consuming passes.

        Returns
        -------
        res: str

        See also ``ProcessedPassTimings.summary()``
        """
        if self.is_empty:
            return "No pass timings were recorded"
        else:
            buf = []
            ap = buf.append
            ap(f"Printing pass timings for {self._name}")
            overall_time = self.get_total_time()
            ap(f"Total time: {overall_time:.4f}")
            for i, r in enumerate(self._records):
                ap(f"== #{i} {r.name}")
                # A report whose totals are all zero must not make the
                # summary (and therefore ``str(self)``) raise.
                if overall_time:
                    percent = r.timings.get_total_time() / overall_time * 100
                else:
                    percent = 0.0
                ap(f" Percent: {percent:.1f}%")
                ap(r.timings.summary(topn=topn, indent=1))
            return "\n".join(buf)

    def __getitem__(self, i):
        """Get the i-th timing record.

        Returns
        -------
        res: (name, timings)
            A named tuple with two fields:

            - name: str
            - timings: ProcessedPassTimings
        """
        return self._records[i]

    def __len__(self):
        """Length of this collection."""
        return len(self._records)

    def __str__(self):
        return self.summary()


# [diff] diff --git a/numba_cuda/numba/cuda/misc/special.py b/numba_cuda/numba/cuda/misc/special.py
# [diff] new file mode 100644
# [diff] index 000000000..2bf1a2342
# [diff] --- /dev/null
# [diff] +++ b/numba_cuda/numba/cuda/misc/special.py
# [diff] @@ -0,0 +1,94 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause


class prange(object):
    """Provides a 1D parallel iterator that generates a sequence of integers.
    In non-parallel contexts, prange is identical to range.
    """

    def __new__(cls, *args):
        return range(*args)


def _gdb_python_call_gen(func_name, *args):
    """Build a jitted zero-argument function that calls ``numba.<func_name>``.

    Generates a call to a function containing a compiled-in gdb command;
    this is to make ``numba.gdb*`` work in the interpreter. ``args`` are
    embedded in the generated source as string literals.
    """
    import numba

    fn = getattr(numba, func_name)
    # ``repr(str(a))`` yields a correctly escaped literal, so arguments that
    # contain quotes or backslashes reach gdb unchanged instead of corrupting
    # the generated source.
    argstr = ",".join(repr(str(a)) for a in args)
    defn = "def _gdb_func_injection():\n\t%s(%s)\n" % (func_name, argstr)
    namespace = {}
    exec(defn, {func_name: fn}, namespace)
    return numba.njit(namespace["_gdb_func_injection"])


def gdb(*args):
    """
    Calling this function will invoke gdb and attach it to the current process
    at the call site. Arguments are strings in the gdb command language syntax
    which will be executed by gdb once initialisation has occurred.
    """
    _gdb_python_call_gen("gdb", *args)()


def gdb_breakpoint():
    """
    Calling this function will inject a breakpoint at the call site that is
    recognised by both `gdb` and `gdb_init`, this is to allow breaking at
    multiple points. gdb will stop in the user defined code just after the frame
    employed by the breakpoint returns.
    """
    _gdb_python_call_gen("gdb_breakpoint")()


def gdb_init(*args):
    """
    Calling this function will invoke gdb and attach it to the current process
    at the call site, then continue executing the process under gdb's control.
    Arguments are strings in the gdb command language syntax which will be
    executed by gdb once initialisation has occurred.
    """
    _gdb_python_call_gen("gdb_init", *args)()


def literally(obj):
    """Forces Numba to interpret *obj* as a Literal value.

    *obj* must be either a literal or an argument of the caller function, where
    the argument must be bound to a literal. The literal requirement
    propagates up the call stack.

    This function is intercepted by the compiler to alter the compilation
    behavior to wrap the corresponding function parameters as ``Literal``.
    It has **no effect** outside of nopython-mode (interpreter, and objectmode).

    The current implementation detects literal arguments in two ways:

    1. Scans for uses of ``literally`` via a compiler pass.
    2. ``literally`` is overloaded to raise ``numba.errors.ForceLiteralArg``
       to signal the dispatcher to treat the corresponding parameter
       differently. This mode is to support indirect use (via a function call).

    The execution semantic of this function is equivalent to an identity
    function.

    See :ghfile:`numba/tests/test_literal_dispatch.py` for examples.
    """
    return obj


def literal_unroll(container):
    """Return *container* unchanged.

    When called from the interpreter this is an identity function.
    """
    return container


__all__ = [
    "prange",
    "gdb",
    "gdb_breakpoint",
    "gdb_init",
    "literally",
    "literal_unroll",
]


# [diff] diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py
# [diff] index 0fe359d98..201e4beb6 100644
# [diff] --- a/numba_cuda/numba/cuda/target.py
# [diff] +++ b/numba_cuda/numba/cuda/target.py
# [diff] @@ -168,7 +168,7 @@ def load_additional_registries(self):
# [diff]  )
# [diff]  from numba.cpython import rangeobj, enumimpl # noqa: F401
# [diff]  from numba.cuda.core import optional # noqa: F401
# [diff] - from numba.misc import cffiimpl
# [diff] + from numba.cuda.misc import cffiimpl
# [diff]  from numba.np import arrayobj # noqa: F401
# [diff]  from numba.np import npdatetime # noqa: F401
# [diff]  from .
import ( diff --git a/numba_cuda/numba/cuda/tests/nocuda/test_import.py b/numba_cuda/numba/cuda/tests/nocuda/test_import.py index e8dce925f..c38c4a9ce 100644 --- a/numba_cuda/numba/cuda/tests/nocuda/test_import.py +++ b/numba_cuda/numba/cuda/tests/nocuda/test_import.py @@ -35,9 +35,9 @@ def test_no_impl_import(self): "numba.cuda.cpython.unicode", "numba.cuda.cpython.charseq", "numba.cuda.core.optional", - "numba.misc.gdb_hook", - "numba.misc.literal", - "numba.misc.cffiimpl", + "numba.cuda.misc.gdb_hook", + "numba.cuda.misc.literal", + "numba.cuda.misc.cffiimpl", "numba.np.linalg", "numba.np.polynomial", "numba.np.arraymath", diff --git a/numba_cuda/numba/cuda/utils.py b/numba_cuda/numba/cuda/utils.py index 4e5484482..76cdd1af1 100644 --- a/numba_cuda/numba/cuda/utils.py +++ b/numba_cuda/numba/cuda/utils.py @@ -653,7 +653,7 @@ def dump_llvm(fndesc, module): from pygments import highlight from pygments.lexers import LlvmLexer as lexer from pygments.formatters import Terminal256Formatter - from numba.misc.dump_style import by_colorscheme + from numba.cuda.misc.dump_style import by_colorscheme print( highlight(