diff --git a/numba_cuda/numba/cuda/printimpl.py b/numba_cuda/numba/cuda/printimpl.py index c034c9edb..9141b7113 100644 --- a/numba_cuda/numba/cuda/printimpl.py +++ b/numba_cuda/numba/cuda/printimpl.py @@ -32,17 +32,23 @@ def print_item(ty, context, builder, val): ) +@print_item.register(types.Tuple) @print_item.register(types.UniTuple) def tuple_print_impl(ty, context, builder, val): - if ty.dtype != types.int64: - raise NotImplementedError( - "printing unimplemented for tuples with elements of type %s" % (ty,) - ) + formats = [] + values = [] - nelements = val.type.count - argsfmt = ", ".join(["%lld"] * nelements) - rawfmt = f"({argsfmt})" - values = [builder.extract_value(val, i) for i in range(nelements)] + for i, argtyp in enumerate(ty.types): + argval = builder.extract_value(val, i) + argfmt, argvals = print_item(argtyp, context, builder, argval) + formats.append(argfmt) + values.extend(argvals) + + if len(formats) == 1: + base = "({},)" + else: + base = "({})" + rawfmt = base.format(", ".join(formats)) return rawfmt, values diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_print.py b/numba_cuda/numba/cuda/tests/cudapy/test_print.py index e96f3ad30..696c377d2 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_print.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_print.py @@ -128,6 +128,28 @@ def print_tuple(tup): cuda.synchronize() """ +print_nested_mixed_type_tuple_usecase = """\ +from numba import cuda + +@cuda.jit +def print_tuple(tup): + print(tup) + +print_tuple[1, 1]((1, ((2, 4), 3.0), (4,), 5)) +cuda.synchronize() +""" + +print_single_element_tuple_usecase = """\ +from numba import cuda + +@cuda.jit +def print_tuple(tup): + print(tup) + +print_tuple[1, 1]((1,)) +cuda.synchronize() +""" + class TestPrint(CUDATestCase): # Note that in these tests we generally strip the output to avoid dealing @@ -180,6 +202,18 @@ def test_tuple(self): expected = ["(1, 2, 3, 4, 5)"] self.assertEqual(lines, expected) + def test_nested_mixed_type_tuple(self): + output, _ = self.run_code(print_nested_mixed_type_tuple_usecase) + (line,) = (line.strip() for line in output.splitlines(True)) + expected = r"^\(1, \(\(2, 4\), 3\.0+\), \(4,\), 5\)$" + self.assertRegex(line, expected) + + def test_single_element_tuple(self): + output, _ = self.run_code(print_single_element_tuple_usecase) + lines = [line.strip() for line in output.splitlines(True)] + expected = ["(1,)"] + self.assertEqual(lines, expected) + @skip_on_cudasim("bfloat16 on host is not yet supported.") def test_bfloat16(self): output, _ = self.run_code(print_bfloat16_usecase)