diff --git a/numba_cuda/numba/cuda/printimpl.py b/numba_cuda/numba/cuda/printimpl.py index c953286cf..c034c9edb 100644 --- a/numba_cuda/numba/cuda/printimpl.py +++ b/numba_cuda/numba/cuda/printimpl.py @@ -32,6 +32,20 @@ def print_item(ty, context, builder, val): ) +@print_item.register(types.UniTuple) +def tuple_print_impl(ty, context, builder, val): + if ty.dtype != types.int64: + raise NotImplementedError( + "printing unimplemented for tuples with elements of type %s" % (ty,) + ) + + nelements = val.type.count + argsfmt = ", ".join(["%lld"] * nelements) + rawfmt = f"({argsfmt})" + values = [builder.extract_value(val, i) for i in range(nelements)] + return rawfmt, values + + @print_item.register(types.Integer) @print_item.register(types.IntegerLiteral) def int_print_impl(ty, context, builder, val): diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_print.py b/numba_cuda/numba/cuda/tests/cudapy/test_print.py index beac03d34..e96f3ad30 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_print.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_print.py @@ -117,6 +117,17 @@ def print_bfloat16(): cuda.synchronize() """ +print_int64_tuple_usecase = """\ +from numba import cuda + +@cuda.jit +def print_tuple(tup): + print(tup) + +print_tuple[1, 1]((1, 2, 3, 4, 5)) +cuda.synchronize() +""" + class TestPrint(CUDATestCase): # Note that in these tests we generally strip the output to avoid dealing @@ -163,6 +174,12 @@ def test_dim3(self): expected = [str(i) for i in np.ndindex(2, 2, 2)] self.assertEqual(sorted(lines), expected) + def test_tuple(self): + output, _ = self.run_code(print_int64_tuple_usecase) + lines = [line.strip() for line in output.splitlines(True)] + expected = ["(1, 2, 3, 4, 5)"] + self.assertEqual(lines, expected) + @skip_on_cudasim("bfloat16 on host is not yet supported.") def test_bfloat16(self): output, _ = self.run_code(print_bfloat16_usecase)