
Conversation

@ashermancinelli (Contributor) commented Jul 22, 2025

Adding pattern-based checking infrastructure allows us to describe the IR and assembly expected under different conditions right next to the code that produces it. This is closer to the LLVM testing standards, and will help us iterate more quickly and add test coverage while we are in the process of vendoring in the components from numba core that we need to modify in a CUDA-specific way.

This PR changes a couple of tests to use the filecheck patterns. For example, in numba-cuda today there are tests like self.assertIn('function_name', llvm_ir). Now we can express the same check in the docstring of the kernel, and all the checks we would like to perform can live inside the kernel they apply to.
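To make this concrete, here is a minimal sketch of the idea rather than the helper this PR actually adds: the CHECK lines live in the kernel's docstring, and the inspected LLVM IR is piped through the pure-Python filecheck command line. It assumes the filecheck entry point behaves like LLVM's FileCheck (check file as its argument, input on stdin) and that the dispatcher exposes py_func and inspect_llvm(); the infrastructure in this PR wraps this differently.

import subprocess
import tempfile

from numba import cuda, float32


@cuda.jit((float32[:], float32[:]))
def add_arrays(r, x):
    """
    CHECK-LABEL: add_arrays
    CHECK: fadd
    """
    i = cuda.grid(1)
    if i < r.size:
        r[i] = r[i] + x[i]


def check_kernel_ir(kernel):
    # Join the LLVM IR for every compiled signature and run the docstring's
    # CHECK lines over it with the filecheck command-line tool.
    ir = "\n".join(kernel.inspect_llvm().values())
    with tempfile.NamedTemporaryFile("w", suffix=".txt") as checks:
        checks.write(kernel.py_func.__doc__)
        checks.flush()
        subprocess.run(["filecheck", checks.name], input=ir,
                       text=True, check=True)


check_kernel_ir(add_arrays)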

This is a request-for-comment, as filecheck tests will be meaningfully different from existing numba-cuda tests.

Note also that the filecheck added to the list of dependencies is not the FileCheck compiled as part of the LLVM project, but an attempt to replicate its functionality in a pure Python package.
I ran some tests to see how well the usual FileCheck features work with the Python version, and it seems okay.
If we are uncomfortable with this development dependency, we can reevaluate.

TODO:

  • CUDATestCase should have filecheck available by default
  • Review type-hint inconsistencies
  • Skip filecheck-based tests on Python 3.9
  • Add dummy class for CUDADispatcher when running with the simulator

Another solution I considered would be to structure our tests much more like LLVM lit tests, not just using the filecheck pattern like I did in this patch.

In this alternative, I imagine test kernels would be written in their own files:

# RUN: numba-cuda-to-llvm-ir %s | filecheck %s
# ^^ the above line converts the JIT'ed kernels in this Python module to
#    LLVM IR and checks the check-comments in the module against the IR

from numba import float32
from numba.cuda import jit

# CHECK-LABEL: foo
# CHECK: fadd
@jit((float32, float32))
def test_foo(a, b):
    c = a + b  # the float add should show up as an fadd in the IR

copy-pr-bot (bot) commented Jul 22, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@isVoid (Contributor) left a comment

I'm trying to understand the use of this tool. Also a broader question: the results of the compiler are in fact legal NVVM IR, so should we use NVVM as the prefix instead?

@ashermancinelli (Contributor, Author)

should we use NVVM [instead of LLVM] as the prefix?

Sure! Whatever will most precisely differentiate the checks from one another.
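For example, a single kernel docstring could hypothetically carry checks against both the NVVM IR and the PTX under distinct prefixes; the prefix names below are only illustrative, not the ones this PR settles on:

# Illustrative check lines only; the NVVM/ASM prefix names are placeholders.
"""
NVVM-LABEL: define
NVVM: fadd
ASM-LABEL: .visible .entry
ASM: add.f32
"""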

@ashermancinelli (Contributor, Author)

@isVoid I tried answering all your questions, but let me speak about the motivation more generally: FileCheck allows us to easily create meaningful tests (in a semi-automated way) to increase our test coverage so that we feel more secure making larger changes. Consider this test in LLVM Flang. There are many ways to compile the same test, some of which result in the same IR. The same source code can be reused to create six or so test cases, some of which share the same checks. There are scripts in the LLVM project to generate and update these tests automatically.
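For instance, the same module could hypothetically be compiled several ways and checked under different prefixes, in the spirit of the lit-style sketch in the PR description above; the tool names, flags, and --check-prefix support here are assumptions, nothing that exists today:

# RUN: numba-cuda-to-llvm-ir %s     | filecheck %s --check-prefix=NVVM
# RUN: numba-cuda-to-llvm-ir %s -O3 | filecheck %s --check-prefix=NVVM-OPT
# RUN: numba-cuda-to-ptx %s         | filecheck %s --check-prefix=ASM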

When we vendor in components for the purpose of CUDA-specific extensions, instead of vendoring the unit tests from numba core into numba-cuda directly, we can vendor in the kernels, use cuda.jit instead of the regular numba jit, and ensure the IR is the same. Then, as we change the components, we can be sure we still get the same IR as with the non-vendored version (or, if the IR has changed, that it changed in the way we expect).

This is also just a suggestion/RFC. We will discuss this further before submitting any changes. Thank you for the review and please ask more questions if I left anything unclear!

@gmarkall (Contributor) left a comment

I really like the idea here! This looks like a nice way to check the generated code.

I do think we could simplify the implementation and API, and make it a bit closer to the idioms of the rest of the test suite - see comments on the diff for these suggestions.

I also wonder - does FileCheck identify the failing line and provide error messages on stderr? I notice the Matcher class has an stderr member but I haven't dived deeper - I am looking to understand whether the framework can provide more precise error messages when failures occur.

@gmarkall gmarkall added the 4 - Waiting on author Waiting for author to respond to review label Jul 23, 2025
@ashermancinelli (Contributor, Author)

I also wonder - does FileCheck identify the failing line and provide error messages on stderr?

It does - but please note that this is a Python re-implementation of the FileCheck executable that is part of the LLVM project, so I am also still getting familiar with the API of this Python module. I will investigate further and see how we can make this more ergonomic. Thank you for the review!

@ashermancinelli ashermancinelli self-assigned this Jul 23, 2025
Adding pattern-based checking infrastructure allows us to describe
the IR and assembly expected under different conditions, expressed
right next to the code that produces it. This is closer to the LLVM
testing standards, and will help us iterate more quickly and add test
coverage while we are in the process of vendoring in components
from numba core.

This is a request for comment, as filecheck tests will be meaningfully
different from existing numba-cuda tests.
@ashermancinelli (Contributor, Author)

/ok to test c45924d

@ashermancinelli ashermancinelli added 3 - Ready for Review Ready for review by team and removed 4 - Waiting on author Waiting for author to respond to review labels Jul 23, 2025
@ashermancinelli (Contributor, Author)

/ok to test df00dc7

@ashermancinelli (Contributor, Author)

/ok to test 0552773

@isVoid (Contributor) commented Jul 23, 2025

@ashermancinelli thanks for the explanation. When we write tests, we are limited by the tools we have, so we can't test more complicated patterns. This opens up a whole range of options for us. A simple grep through the code base with self\.assertIn\(.*, (asm|ptx|sass|ir|llvm|llvmirs|llvmirs.*)\) turns up these tests that could potentially reuse this test infra.

Details
/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py
  107,9:         self.assertIn(".target sm_{0}{1}".format(*arch), ptx)
  108,9:         self.assertIn("simple", ptx)
  109,9:         self.assertIn("ave", ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_atomics.py
  691,17:                 self.assertIn("atom.shared.cas.b64", asm)
  693,17:                 self.assertIn("atom.cas.b64", asm)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_casting.py
  138,13:             self.assertIn(f"cvt.rni.s{size}.f16", ptx)
  160,13:             self.assertIn(f"cvt.rni.u{size}.f16", ptx)
  192,13:             self.assertIn(f"cvt.rn.f16.s{size}", ptx)
  201,13:             self.assertIn(f"cvt.rn.f16.u{size}", ptx)
  229,13:             self.assertIn(f"cvt.{postfix}.f16", ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_compiler.py
  35,9:         self.assertIn(".visible .entry", ptx)
  48,9:         self.assertIn("func_retval", ptx)
  50,9:         self.assertIn(".visible .func", ptx)
  78,9:         self.assertIn("fma.rn.f32", ptx)
  79,9:         self.assertIn("div.rn.f32", ptx)
  80,9:         self.assertIn("sqrt.rn.f32", ptx)
  85,9:         self.assertIn("fma.rn.ftz.f32", ptx)
  86,9:         self.assertIn("div.approx.ftz.f32", ptx)
  87,9:         self.assertIn("sqrt.approx.ftz.f32", ptx)
  270,9:         self.assertIn(target, ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_device_func.py
  129,9:         self.assertIn(fname, llvm)
  146,9:         self.assertIn(fname, ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_inspect.py
  36,9:         self.assertIn("foo", llvm)
  39,9:         self.assertIn("define void @", llvm)
  44,9:         self.assertIn("foo", asm)
  46,9:         self.assertIn("Generated by NVIDIA NVVM Compiler", asm)
  70,9:         self.assertIn((intp, intp), llvmirs)
  71,9:         self.assertIn((float64, float64), llvmirs)
  74,9:         self.assertIn("foo", llvmirs[intp, intp])
  75,9:         self.assertIn("foo", llvmirs[float64, float64])
  78,9:         self.assertIn("define void @", llvmirs[intp, intp])
  79,9:         self.assertIn("define void @", llvmirs[float64, float64])
  107,9:         self.assertIn("S2R", sass)  # Special register to register
  108,9:         self.assertIn("BRA", sass)  # Branch
  109,9:         self.assertIn("EXIT", sass)  # Exit program

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py
  645,9:         self.assertIn("add.f16", ptx)
  672,9:         self.assertIn("fma.rn.f16", ptx)
  697,9:         self.assertIn("sub.f16", ptx)
  722,9:         self.assertIn("mul.f16", ptx)
  767,9:         self.assertIn("neg.f16", ptx)
  790,9:         self.assertIn("abs.f16", ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py
  165,9:         self.assertIn("ld.param", ptx)
  169,9:         self.assertIn("st.global", ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py
  216,9:         self.assertIn(target, ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_minmax.py
  46,9:         self.assertIn(ptx_instruction, ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_operator.py
  182,17:                 self.assertIn(instr, ptx)
  216,17:                 self.assertIn(instr, ptx)
  259,9:         self.assertIn("neg.f16", ptx)
  266,9:         self.assertIn("abs.f16", ptx)
  400,17:                 self.assertIn(s, ptx)
  435,17:                 self.assertIn(opstring[op], ptx)
  481,17:                 self.assertIn(ops, ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/cudapy/test_optimization.py
  70,17:                 self.assertIn(fragment, ptx)
  77,9:         self.assertIn("fma.rn.f64", ptx)

/workspace/numba-cuda/numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py
  52,9:         self.assertIn(myconstant, ptx)

self,
ir_producer: CUDADispatcher,
signature: tuple[type, ...] | None = None,
check_prefixes: list[str] = ["ASM"],
A contributor commented on the diff lines above:

This use of a list as a default argument is probably safe as it's not mutated, but it always sets off alarm bells to see a mutable default in a Python function definition. I think you could either use None as in https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments, or have a tuple for the default instead (I assume filecheck will accept a tuple as well).
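A minimal sketch of the None-default idiom, assuming the surrounding method looks roughly like the diff context above; the class and method names here are placeholders rather than the API added by this PR:

from typing import Optional, Sequence


class _FileCheckMixin:
    def _check_ir(
        self,
        ir_producer,
        signature: Optional[tuple] = None,
        check_prefixes: Optional[Sequence[str]] = None,
    ):
        # The default is created on each call, so no mutable object is shared
        # between invocations; a tuple default such as ("ASM",) would work too.
        if check_prefixes is None:
            check_prefixes = ("ASM",)
        ...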

@ashermancinelli (Contributor, Author) replied:
A tuple or frozenset would be okay with me, let me try that out. Thanks!

@ashermancinelli (Contributor, Author)

/ok to test 8d6dcb8

@ashermancinelli (Contributor, Author)

The filecheck module version that is compatible with Python 3.9 is too old to be used here.
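One possible guard, assuming a usable filecheck is only available on Python 3.10 and newer; the actual skip mechanism used in the PR may differ:

import sys
import unittest

requires_filecheck = unittest.skipIf(
    sys.version_info < (3, 10),
    "the pure-Python filecheck package is too old on Python 3.9",
)


@requires_filecheck
class TestKernelIRPatterns(unittest.TestCase):
    ...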

@ashermancinelli ashermancinelli added 2 - In Progress Currently a work in progress and removed 3 - Ready for Review Ready for review by team labels Jul 23, 2025
@ashermancinelli (Contributor, Author)

/ok to test 8fb7920

@ashermancinelli ashermancinelli added 3 - Ready for Review Ready for review by team and removed 2 - In Progress Currently a work in progress labels Jul 23, 2025
@ashermancinelli (Contributor, Author)

/ok to test dd6eb32

@gmarkall (Contributor)

/ok to test

@gmarkall (Contributor)

I think the PR is good. I am just running another check with a merge of main to make sure that all is well with pytest being the test framework before merging this PR.

@gmarkall gmarkall added 5 - Ready to merge Testing and reviews complete, ready to merge and removed 3 - Ready for Review Ready for review by team labels Jul 24, 2025
@gmarkall gmarkall merged commit 74439bd into NVIDIA:main Jul 24, 2025
39 checks passed
gmarkall added a commit to gmarkall/numba-cuda that referenced this pull request Jul 31, 2025
- [NFC] FileCheck tests check all overloads (NVIDIA#354)
- [REVIEW][NFC] Vendor in serialize to allow for future CUDA-specific refactoring and changes (NVIDIA#349)
- Vendor in usecases used in testing (NVIDIA#359)
- Add thirdparty tests of numba extensions (NVIDIA#348)
- Support running tests in parallel (NVIDIA#350)
- Add more debuginfo tests (NVIDIA#358)
- [REVIEW][NFC] Vendor in the Cache, CacheImpl used by CUDACache and CUDACacheImpl to allow for future CUDA-specific refactoring and changes (NVIDIA#334)
- [NFC] Vendor in Dispatcher as CUDADispatcher to allow for future CUDA-specific customization (NVIDIA#338)
- Vendor in BaseNativeLowering and BaseLower for CUDA-specific customizations (NVIDIA#329)
- [REVIEW] Vendor in the CompilerBase used by CUDACompiler to allow for future CUDA-specific refactoring and changes (NVIDIA#322)
- Vendor in Codegen and CodeLibrary for CUDA-specific customization (NVIDIA#327)
- Disable tests that deadlock due to NVIDIA#317 (NVIDIA#356)
- FIX: Add type check for shape elements in DeviceNDArrayBase constructor (NVIDIA#352)
- Merge pull request NVIDIA#265 from lakshayg/fp16-support
- Add performance warning
- Fix tests
- Create and register low++ bindings for float16
- Create typing/target registries for float16
- Replace Numbast generated lower_casts
- Replace Numbast generated operators
- Alias __half to numba.core.types.float16
- Generate fp16 bindings using numbast
- Remove existing fp16 logic
- [REVIEW][NFC] Vendor in the utils and cgutils to allow for future CUDA-specific refactoring and changes (NVIDIA#340)
- [RFC,TESTING] Add filecheck test infrastructure (NVIDIA#342)
- Migrate test infra to pytest (NVIDIA#347)
- Add .vscode to gitignore (NVIDIA#344)
- [NFC] Add dev dependencies to project config (NVIDIA#341)
- Allow Inspection of Link-Time Optimized PTX (NVIDIA#326)
- [NFC] Vendor in DIBuilder used by CUDADIBuilder (NVIDIA#332)
- Add guidance on setting up pre-commit (NVIDIA#339)
- [Refactor][NFC] Vendor in MinimalCallConv (NVIDIA#333)
- [Refactor][NFC] Vendor in BaseCallConv (NVIDIA#324)
- [REVIEW] Vendor in CompileResult as CUDACompileResult to allow for future CUDA-specific customizations (NVIDIA#325)
@gmarkall gmarkall mentioned this pull request Jul 31, 2025
gmarkall added a commit that referenced this pull request Jul 31, 2025