diff --git a/numba_cuda/numba/cuda/core/compiler.py b/numba_cuda/numba/cuda/core/compiler.py
index 0700cf97c..8f502e926 100644
--- a/numba_cuda/numba/cuda/core/compiler.py
+++ b/numba_cuda/numba/cuda/core/compiler.py
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: BSD-2-Clause
 
-from numba.core.tracing import event
+from numba.cuda.core.tracing import event
 
 from numba.cuda.core import bytecode
 from numba.core import callconv, config, errors
diff --git a/numba_cuda/numba/cuda/core/compiler_machinery.py b/numba_cuda/numba/cuda/core/compiler_machinery.py
index f16067c31..ca0fa3d32 100644
--- a/numba_cuda/numba/cuda/core/compiler_machinery.py
+++ b/numba_cuda/numba/cuda/core/compiler_machinery.py
@@ -10,7 +10,7 @@
 from numba.core.compiler_lock import global_compiler_lock
 from numba.core import errors, config, transforms
 from numba.cuda import utils
-from numba.core.tracing import event
+from numba.cuda.core.tracing import event
 from numba.cuda.core.postproc import PostProcessor
 from numba.cuda.core.ir_utils import enforce_no_dels, legalize_single_scope
 import numba.core.event as ev
diff --git a/numba_cuda/numba/cuda/core/tracing.py b/numba_cuda/numba/cuda/core/tracing.py
new file mode 100644
index 000000000..8736535ed
--- /dev/null
+++ b/numba_cuda/numba/cuda/core/tracing.py
@@ -0,0 +1,265 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+import inspect
+import logging
+import sys
+import threading
+from functools import wraps
+from itertools import chain
+
+from numba.core import config
+
+
+class TLS(threading.local):
+    """Use a subclass to properly initialize the TLS variables in all threads."""  # noqa: E501
+
+    def __init__(self):
+        # True while the tracer itself is running; traced functions called
+        # in that state are executed without being logged.
+        self.tracing = False
+        # Nesting depth of traced calls, rendered as leading spaces.
+        self.indent = 0
+
+
+tls = TLS()
+
+
+def find_function_info(func, spec, args):
+    """Return function meta-data in a tuple ``(name, type)``.
+
+    ``name`` is the dotted label used in the trace output: the module
+    (omitted for ``__main__``), the owning class if one can be deduced,
+    and the function name. ``type`` is always ``None``.
+
+    The owner is taken from ``func.__self__`` when present; otherwise it
+    is guessed from the first positional argument when the first
+    parameter is named ``self`` or ``cls``. Callables without a
+    ``__name__`` are labelled with their ``repr``.
+    """
+    module = getattr(func, "__module__", None)
+    name = getattr(func, "__name__", None)
+    if not name:
+        # e.g. functools.partial objects: still return a printable label
+        # so that the caller can join it into the log record.
+        return repr(func), None
+
+    owner = getattr(func, "__self__", None)
+    cname = None
+    if owner is not None:
+        # Classes and modules have a __name__; a bound instance does not,
+        # so fall back to the name of its class.
+        cname = getattr(owner, "__name__", None) or type(owner).__name__
+    elif spec.args and args:
+        # 'args' is empty when 'self'/'cls' was passed by keyword.
+        if spec.args[0] == "self":
+            cname = args[0].__class__.__name__
+        elif spec.args[0] == "cls":
+            cname = getattr(args[0], "__name__", None)
+
+    qname = []
+    if module and module != "__main__":
+        qname += [module, "."]
+    if cname:
+        qname += [cname, "."]
+    qname.append(name)
+    return "".join(qname), None
+
+
+def chop(value):
+    """Return ``repr(value)``, truncated if longer than ``MAX_SIZE``.
+
+    A truncated result keeps the first ``MAX_SIZE`` characters, followed by
+    ``...`` and the last character of the full representation.
+    """
+    MAX_SIZE = 320
+    s = repr(value)
+    if len(s) > MAX_SIZE:
+        return s[:MAX_SIZE] + "..." + s[-1]
+    return s
+
+
+def create_events(fname, spec, args, kwds):
+    """Build the log records for entering and leaving a traced call.
+
+    Returns ``(enter, leave)``, two lists of string fragments. ``enter``
+    shows the call with its arguments; ``leave`` only names the function
+    so that the caller can append the outcome before logging it.
+    """
+    values = {}
+    if spec.defaults:
+        values.update(zip(spec.args[-len(spec.defaults) :], spec.defaults))
+    values.update(kwds)
+    values.update(zip(spec.args, args))
+    # Skip parameters that received no value: the call itself then raises
+    # the proper TypeError instead of the tracer failing with KeyError.
+    positional = [
+        "%s=%r" % (a, values.pop(a)) for a in spec.args if a in values
+    ]
+    anonymous = [str(a) for a in args[len(spec.args) :]]
+    keywords = ["%s=%r" % (k, values[k]) for k in sorted(values)]
+    params = ", ".join(f for f in chain(positional, anonymous, keywords) if f)
+
+    indent = tls.indent * " "
+    enter = [">> ", indent, fname, "(", params, ")"]
+    leave = ["<< ", indent, fname]
+    return enter, leave
+
+
+def dotrace(*args, **kwds):
+    """Function decorator to trace a function's entry and exit.
+
+    *args: categories in which to trace this function. Example usage:
+
+    @trace
+    def function(...):...
+
+    @trace('mycategory')
+    def function(...):...
+
+    ``classmethod``, ``staticmethod`` and ``property`` objects are
+    accepted as well. ``recursive=True`` is not supported yet and raises
+    ``NotImplementedError``.
+    """
+    recursive = kwds.get("recursive", False)
+
+    def decorator(func):
+        spec = None
+        logger = logging.getLogger("trace")
+
+        def wrapper(*args, **kwds):
+            if not logger.isEnabledFor(logging.INFO) or tls.tracing:
+                return func(*args, **kwds)
+
+            fname, _ = find_function_info(func, spec, args)
+
+            try:
+                # Anything called while the records are built or logged
+                # must not be traced itself.
+                tls.tracing = True
+                enter, leave = create_events(fname, spec, args, kwds)
+                logger.info("".join(enter))
+                # Only indent once the enter record is out, so that the
+                # decrement below always matches this increment.
+                tls.indent += 1
+                try:
+                    try:
+                        # Re-enable tracing for the traced call itself.
+                        tls.tracing = False
+                        result = func(*args, **kwds)
+                    finally:
+                        tls.tracing = True
+                except BaseException:
+                    # Describe the exception in the leave record, then
+                    # let it propagate unchanged.
+                    exc_type, exc_value = sys.exc_info()[:2]
+                    leave.append(" => exception thrown\n\traise ")
+                    mname = exc_type.__module__
+                    if mname != "__main__":
+                        leave.append(mname)
+                        leave.append(".")
+                    leave.append(exc_type.__name__)
+                    if exc_value.args:
+                        leave.append("(")
+                        leave.append(
+                            ", ".join(chop(v) for v in exc_value.args)
+                        )
+                        leave.append(")")
+                    else:
+                        leave.append("()")
+                    raise
+                else:
+                    if result is not None:
+                        leave.append(" -> ")
+                        leave.append(chop(result))
+                finally:
+                    tls.indent -= 1
+                    logger.info("".join(leave))
+            finally:
+                tls.tracing = False
+            return result
+
+        # Unwrap classmethod/staticmethod objects so that the plain
+        # function is traced, and wrap the result the same way again.
+        rewrap = None
+        if isinstance(func, (classmethod, staticmethod)):
+            rewrap = type(func)
+            func = func.__func__
+        elif isinstance(func, property):
+            raise NotImplementedError
+
+        spec = inspect.getfullargspec(func)
+        traced = wraps(func)(wrapper)
+        return traced if rewrap is None else rewrap(traced)
+
+    arg0 = args[0] if args else None
+    # not supported yet...
+    if recursive:
+        raise NotImplementedError
+    if inspect.ismodule(arg0):
+        for n, f in inspect.getmembers(arg0, inspect.isfunction):
+            setattr(arg0, n, decorator(f))
+        for n, c in inspect.getmembers(arg0, inspect.isclass):
+            dotrace(c, *args, recursive=recursive)
+    elif inspect.isclass(arg0):
+        for n, f in inspect.getmembers(
+            arg0, lambda x: (inspect.isfunction(x) or inspect.ismethod(x))
+        ):
+            setattr(arg0, n, decorator(f))
+
+    if callable(arg0) or isinstance(arg0, (classmethod, staticmethod)):
+        return decorator(arg0)
+    elif isinstance(arg0, property):
+        # properties combine up to three functions: 'get', 'set', 'del',
+        # so let's wrap them all.
+        pget, pset, pdel = None, None, None
+        if arg0.fget:
+            pget = decorator(arg0.fget)
+        if arg0.fset:
+            pset = decorator(arg0.fset)
+        if arg0.fdel:
+            pdel = decorator(arg0.fdel)
+        # Carry the docstring over explicitly; it need not come from fget.
+        return property(pget, pset, pdel, arg0.__doc__)
+
+    else:
+        return decorator
+
+
+def notrace(*args, **kwds):
+    """Just a no-op in case tracing is disabled."""
+
+    def decorator(func):
+        return func
+
+    arg0 = args[0] if args else None
+
+    # Mirror dotrace(): anything it would decorate directly is handed back
+    # untouched, including property objects.
+    if callable(arg0) or isinstance(
+        arg0, (classmethod, staticmethod, property)
+    ):
+        return arg0
+    return decorator
+
+
+def doevent(msg):
+    """Log *msg* as a trace event at the current indentation level."""
+    logger = logging.getLogger("trace")
+    logger.info("".join(["== ", tls.indent * " ", msg]))
+
+
+def noevent(msg):
+    """Just a no-op in case tracing is disabled."""
+
+
+# Select the tracing or the no-op implementations once, at import time.
+if config.TRACE:
+    logger = logging.getLogger("trace")
+    logger.setLevel(logging.INFO)
+    logger.handlers = [logging.StreamHandler()]
+    trace = dotrace
+    event = doevent
+else:
+    trace = notrace
+    event = noevent
diff --git a/numba_cuda/numba/cuda/tests/test_tracing.py b/numba_cuda/numba/cuda/tests/test_tracing.py
new file mode 100644
index 000000000..d2234f150
--- /dev/null
+++ b/numba_cuda/numba/cuda/tests/test_tracing.py
@@ -0,0 +1,200 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-2-Clause
+
+from io import StringIO
+import logging
+
+import unittest
+from numba.cuda.core import tracing
+
+logger = logging.getLogger("trace")
+logger.setLevel(logging.INFO)
+
+# Make sure tracing is enabled
+orig_trace = tracing.trace
+tracing.trace = tracing.dotrace
+
+
+class CapturedTrace:
+    """Capture the trace temporarily for validation."""
+
+    def __init__(self):
+        self.buffer = StringIO()
+
+    def __enter__(self):
+        self._handlers = logger.handlers
+        self.buffer = StringIO()
+        logger.handlers = [logging.StreamHandler(self.buffer)]
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        logger.handlers = self._handlers
+
+    def getvalue(self):
+        # Depending on how the tests are run, object names may be
+        # qualified by their containing module.
+        # Remove that to make the trace output independent from the testing mode.
+        log = self.buffer.getvalue()
+        log = log.replace(__name__ + ".", "")
+        return log
+
+
+class Class(object):
+    @tracing.trace
+    @classmethod
+    def class_method(cls):
+        pass
+
+    @tracing.trace
+    @staticmethod
+    def static_method():
+        pass
+
+    __test = None
+
+    def _test_get(self):
+        return self.__test
+
+    def _test_set(self, value):
+        self.__test = value
+
+    test = tracing.trace(property(_test_get, _test_set))
+
+    @tracing.trace
+    def method(self, some, other="value", *args, **kwds):
+        pass
+
+    def __repr__(self):
+        """Generate a deterministic string for testing."""
+        return ""
+
+
+class Class2(object):
+    @classmethod
+    def class_method(cls):
+        pass
+
+    @staticmethod
+    def static_method():
+        pass
+
+    __test = None
+
+    @property
+    def test(self):
+        return self.__test
+
+    @test.setter
+    def test(self, value):
+        self.__test = value
+
+    def method(self):
+        pass
+
+    def __str__(self):
+        return "Test(" + str(self.test) + ")"
+
+    def __repr__(self):
+        """Generate a deterministic string for testing."""
+        return ""
+
+
+@tracing.trace
+def test_traced_function():
+    # Test the tracing functionality with fixed values
+    x, y = 5, 5
+    z = True
+
+    a = x + y
+    b = x * y
+    if z:
+        result = a
+    else:
+        result = b
+
+    # With z set, result is the sum: 5 + 5 == 10
+    assert result == 10
+
+
+class TestTracing(unittest.TestCase):
+    """Check the records logged for each kind of traced callable."""
+
+    def setUp(self):
+        self.capture = CapturedTrace()
+
+    def tearDown(self):
+        del self.capture
+
+    def test_method(self):
+        with self.capture:
+            Class().method("foo", bar="baz")
+        self.assertEqual(
+            self.capture.getvalue(),
+            ">> Class.method(self=, some='foo', other='value', bar='baz')\n"
+            + "<< Class.method\n",
+        )
+
+    def test_class_method(self):
+        with self.capture:
+            Class.class_method()
+        # 'cls' is rendered with %r, i.e. as the repr of the class object.
+        self.assertEqual(
+            self.capture.getvalue(),
+            ">> Class.class_method(cls=<class 'Class'>)\n"
+            + "<< Class.class_method\n",
+        )
+
+    def test_static_method(self):
+        with self.capture:
+            Class.static_method()
+        self.assertEqual(
+            self.capture.getvalue(),
+            ">> static_method()\n" + "<< static_method\n",
+        )
+
+    def test_property(self):
+        with self.capture:
+            test = Class()
+            test.test = 1
+            self.assertEqual(1, test.test)
+        self.assertEqual(
+            self.capture.getvalue(),
+            ">> Class._test_set(self=, value=1)\n"
+            + "<< Class._test_set\n"
+            + ">> Class._test_get(self=)\n"
+            + "<< Class._test_get -> 1\n",
+        )
+
+    def test_function(self):
+        with self.capture:
+            test_traced_function()
+        # The test function should be traced when called
+        trace_output = self.capture.getvalue()
+        self.assertIn(">> test_traced_function()", trace_output)
+        self.assertIn("<< test_traced_function", trace_output)
+
+    @unittest.skip("recursive decoration not yet implemented")
+    def test_injected(self):
+        with self.capture:
+            tracing.trace(Class2, recursive=True)
+            Class2.class_method()
+            Class2.static_method()
+            test = Class2()
+            test.test = 1
+            self.assertEqual(1, test.test)
+            test.method()
+
+        self.assertEqual(
+            self.capture.getvalue(),
+            ">> Class2.class_method(cls=<class 'Class2'>)\n"
+            + "<< Class2.class_method\n"
+            + ">> static_method()\n"
+            + "<< static_method\n",
+        )
+
+
+# Reset tracing to its original value
+tracing.trace = orig_trace
+
+if __name__ == "__main__":
+    unittest.main()