Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ci/test_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ set +u
conda activate test
set -u

pip install filecheck

rapids-mamba-retry install -c `pwd`/conda-repo numba-cuda

RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"}/
Expand Down
2 changes: 2 additions & 0 deletions ci/test_conda_ctypes_binding.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ set +u
conda activate test
set -u

pip install filecheck

rapids-mamba-retry install -c `pwd`/conda-repo numba-cuda

RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"}/
Expand Down
2 changes: 2 additions & 0 deletions ci/test_simulator.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ set +u
conda activate test
set -u

pip install filecheck

rapids-mamba-retry install -c `pwd`/conda-repo numba-cuda

RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"}/
Expand Down
2 changes: 2 additions & 0 deletions numba_cuda/numba/cuda/simulator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
if config.ENABLE_CUDASIM:
import sys
from numba.cuda.simulator import cudadrv
from . import dispatcher

sys.modules["numba.cuda.cudadrv"] = cudadrv
sys.modules["numba.cuda.cudadrv.devicearray"] = cudadrv.devicearray
Expand All @@ -43,6 +44,7 @@
sys.modules["numba.cuda.cudadrv.drvapi"] = cudadrv.drvapi
sys.modules["numba.cuda.cudadrv.error"] = cudadrv.error
sys.modules["numba.cuda.cudadrv.nvvm"] = cudadrv.nvvm
sys.modules["numba.cuda.dispatcher"] = dispatcher

from . import bf16, compiler, _internal

Expand Down
7 changes: 7 additions & 0 deletions numba_cuda/numba/cuda/simulator/dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class CUDADispatcher:
"""
Dummy class so that consumers that try to import the real CUDADispatcher
do not get an import failure when running with the simulator.
"""

...
104 changes: 103 additions & 1 deletion numba_cuda/numba/cuda/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,121 @@
import platform
import shutil

from numba.core.utils import PYVERSION
from numba.tests.support import SerialMixin
from numba.cuda.cuda_paths import get_conda_ctk
from numba.cuda.cudadrv import driver, devices, libs
from numba.cuda.dispatcher import CUDADispatcher
from numba.core import config
from numba.tests.support import TestCase
from pathlib import Path
from typing import Union
from io import StringIO
import unittest

if PYVERSION >= (3, 10):
from filecheck.matcher import Matcher, Options
from filecheck.parser import Parser, pattern_for_opts
from filecheck.finput import FInput

numba_cuda_dir = Path(__file__).parent
test_data_dir = numba_cuda_dir / "tests" / "data"


class CUDATestCase(SerialMixin, TestCase):
class FileCheckTestCaseMixin:
"""
Mixin for tests that use FileCheck.

Methods assertFileCheckAsm and assertFileCheckLLVM will inspect a
CUDADispatcher and assert that the compilation artifacts match the
FileCheck checks given in the kernel's docstring.

Method assertFileCheckMatches can be used to assert that a given string
matches FileCheck checks, and is not specific to CUDADispatcher.
"""

def assertFileCheckAsm(
self,
ir_producer: CUDADispatcher,
signature: Union[tuple[type, ...], None] = None,
check_prefixes: list[str] = ("ASM",),
**extra_filecheck_options: dict[str, Union[str, int]],
) -> None:
"""
Assert that the assembly output of the given CUDADispatcher matches
the FileCheck checks given in the kernel's docstring.
"""
ir_content = ir_producer.inspect_asm()
if signature:
ir_content = ir_content[signature]
check_patterns = ir_producer.__doc__
self.assertFileCheckMatches(
ir_content,
check_patterns=check_patterns,
check_prefixes=check_prefixes,
**extra_filecheck_options,
)

def assertFileCheckLLVM(
self,
ir_producer: CUDADispatcher,
signature: Union[tuple[type, ...], None] = None,
check_prefixes: list[str] = ("LLVM",),
**extra_filecheck_options: dict[str, Union[str, int]],
) -> None:
"""
Assert that the LLVM IR output of the given CUDADispatcher matches
the FileCheck checks given in the kernel's docstring.
"""
ir_content = ir_producer.inspect_llvm()
if signature:
ir_content = ir_content[signature]
check_patterns = ir_producer.__doc__
self.assertFileCheckMatches(
ir_content,
check_patterns=check_patterns,
check_prefixes=check_prefixes,
**extra_filecheck_options,
)

def assertFileCheckMatches(
self,
ir_content: str,
check_patterns: str,
check_prefixes: list[str] = ("CHECK",),
**extra_filecheck_options: dict[str, Union[str, int]],
) -> None:
"""
Assert that the given string matches the passed FileCheck checks.

Args:
ir_content: The string to check against.
check_patterns: The FileCheck checks to use.
check_prefixes: The prefixes to use for the FileCheck checks.
extra_filecheck_options: Extra options to pass to FileCheck.
"""
if PYVERSION < (3, 10):
self.skipTest("FileCheck requires Python 3.10 or later")
opts = Options(
match_filename="-",
check_prefixes=check_prefixes,
**extra_filecheck_options,
)
input_file = FInput(fname="-", content=ir_content)
parser = Parser(opts, StringIO(check_patterns), *pattern_for_opts(opts))
matcher = Matcher(opts, input_file, parser)
matcher.stderr = StringIO()
result = matcher.run()
if result != 0:
self.fail(
f"FileCheck failed:\n{matcher.stderr.getvalue()}\n\n"
f"Check prefixes:\n{check_prefixes}\n\n"
f"Check patterns:\n{check_patterns}\n"
f"IR:\n{ir_content}\n\n"
)


class CUDATestCase(SerialMixin, FileCheckTestCaseMixin, TestCase):
"""
For tests that use a CUDA device. Test methods in a CUDATestCase must not
be run out of module order, because the ContextResettingTestCase may reset
Expand Down
62 changes: 38 additions & 24 deletions numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def test_monotyped(self):

@cuda.jit(sig)
def foo(x, y):
"""
// LLVM: define void
// LLVM-SAME: foo
// LLVM-LABEL: entry:
// LLVM-NEXT: br label %"[[VAL_0:.*]]"
// LLVM-NEXT: [[VAL_0]]:
// LLVM-NEXT: ret void

// ASM: Generated by NVIDIA NVVM Compiler
// ASM: foo
"""
pass

file = StringIO()
Expand All @@ -37,28 +48,43 @@ def foo(x, y):
# Signature in annotation
self.assertIn("(float32, int32)", typeanno)
file.close()
# Function name in LLVM
llvm = foo.inspect_llvm(sig)
self.assertIn("foo", llvm)

# Kernel in LLVM
self.assertIn("define void @", llvm)

asm = foo.inspect_asm(sig)

# Function name in PTX
self.assertIn("foo", asm)
# NVVM inserted comments in PTX
self.assertIn("Generated by NVIDIA NVVM Compiler", asm)
self.assertFileCheckLLVM(foo, sig)
self.assertFileCheckAsm(foo, sig)

def test_polytyped(self):
@cuda.jit
def foo(x, y):
"""
// LLVM: define void
// LLVM-SAME: foo
// LLVM_INT-SAME: i64
// LLVM_INT-SAME: i64
// LLVM_FLOAT-SAME: double
// LLVM_FLOAT-SAME: double

// ASM: Generated by NVIDIA NVVM Compiler
// ASM: .visible
// ASM-SAME: .entry
// ASM-SAME: foo
"""
pass

foo[1, 1](1, 1)
foo[1, 1](1.2, 2.4)

int_sig = (intp, intp)
float_sig = (float64, float64)

self.assertFileCheckLLVM(
foo, int_sig, check_prefixes=["LLVM", "LLVM_INT"]
)
self.assertFileCheckAsm(foo, int_sig, check_prefixes=["ASM"])
self.assertFileCheckLLVM(
foo, float_sig, check_prefixes=["LLVM", "LLVM_FLOAT"]
)
self.assertFileCheckAsm(foo, float_sig, check_prefixes=["ASM"])

file = StringIO()
foo.inspect_types(file=file)
typeanno = file.getvalue()
Expand All @@ -76,14 +102,6 @@ def foo(x, y):
self.assertIn((intp, intp), llvmirs)
self.assertIn((float64, float64), llvmirs)

# Function name in LLVM
self.assertIn("foo", llvmirs[intp, intp])
self.assertIn("foo", llvmirs[float64, float64])

# Kernels in LLVM
self.assertIn("define void @", llvmirs[intp, intp])
self.assertIn("define void @", llvmirs[float64, float64])

asmdict = foo.inspect_asm()

# Signature in assembly dict
Expand All @@ -94,10 +112,6 @@ def foo(x, y):
self.assertIn((intp, intp), asmdict)
self.assertIn((float64, float64), asmdict)

# NVVM inserted in PTX
self.assertIn("foo", asmdict[intp, intp])
self.assertIn("foo", asmdict[float64, float64])

def _test_inspect_sass(self, kernel, name, sass):
# Ensure function appears in output
seen_function = False
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ test = [
"psutil",
"cffi",
"pytest",
"filecheck",
]
test-cu11 = [
"numba-cuda[test]",
Expand Down