
Conversation

@isVoid
Contributor

@isVoid isVoid commented Sep 20, 2025

This PR adds the compile_all API. Instead of returning only the compiled kernel, it also compiles the CUDA source files among the kernel's external linking files. Where an external file is not compilable (such as cubin, PTX, LTO-IR, etc.), it is passed through into the return list.

Testing compile_all exhaustively is tough due to the combinatorial complexity of its arguments. Since it is mainly a passthrough of compile with an additional NVRTC pass on the external linking files, I reused all the existing tests in TestCompile to make sure the API is consistent with the previous APIs.

The default arguments of compile_all are the same as those of compile.

When lineinfo or debug is set to True, the corresponding NVRTC flag is set for the external files as well.
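The compile-or-pass-through behaviour described above can be sketched in pure Python. This is an illustration only, not the actual numba-cuda implementation: compile_cuda_source is a stand-in for the NVRTC step, and the suffix check is an assumption about how CUDA sources are recognised.

```python
from pathlib import Path

def compile_cuda_source(path):
    # Stand-in for an NVRTC compilation step.
    return f"compiled:{path}"

def compile_all_sketch(kernel_code, linking_files):
    """Return the compiled kernel plus one entry per external file:
    CUDA sources are compiled, everything else is passed through."""
    results = [kernel_code]
    for f in linking_files:
        if Path(f).suffix == ".cu":
            results.append(compile_cuda_source(f))
        else:
            results.append(f)  # cubin/PTX/LTO-IR etc. pass straight through
    return results

print(compile_all_sketch("kernel", ["helper.cu", "lib.cubin"]))
```

The invariant this models is the one stated in the review below: everything in the returned list should be feedable to the linker without further processing.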

@copy-pr-bot

copy-pr-bot bot commented Sep 20, 2025

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.


@isVoid
Contributor Author

isVoid commented Sep 20, 2025

/ok to test

@isVoid
Contributor Author

isVoid commented Sep 23, 2025

/ok to test

@isVoid
Contributor Author

isVoid commented Sep 23, 2025

/ok to test

@skip_on_cudasim(reason="Simulator does not support linkable code")
def test_linkable_code_from_path_or_obj(self):
    files_kind = [
        (test_device_functions_a, cuda.Archive),
Contributor

Does this need to be guarded by TEST_BIN_DIR?

Contributor Author

It shouldn't be run if TEST_BIN_DIR doesn't exist.

    raise NotImplementedError(f"Unsupported output type: {output}")

if forceinline and output != "ltoir":
    raise ValueError("Can only designate forced inlining in LTO-IR")
Contributor

Does this need to be a hard error? Is forceinline a guarantee or can the compiler still ignore it?

Contributor Author

I imagine this is because there's no way to preserve the inlining information when you specify the output as PTX; the information can only persist via LTO-IR.

Contributor

That's correct - you can't inline PTX. This looks like new code on the diff, but this PR is only really moving it from the original compile() implementation.

codes = [code]

# linking_files
lto = output == "ltoir"
Contributor

Nit: I'd consider renaming lto to do_lto or is_lto to more clearly indicate that it's a flag controlling behaviour.

Contributor Author

The name here means "whether the output should be LTO-IR", so is_ltoir is the best fit here IMO.

# linking_files
lto = output == "ltoir"
for path_or_obj in lib._linking_files:
    obj = LinkableCode.from_path_or_obj(path_or_obj)
Contributor

@rparolin rparolin Sep 24, 2025

The variable is called path_or_obj but the code appears to handle only the cu and obj use cases. Should the variable be renamed, or are you missing the path-handling use case?

Contributor Author

@isVoid isVoid Sep 26, 2025

The objects in lib._linking_files can be arbitrary linkable code objects or paths, of which only CUDA source files require compilation before they are feedable to the linker. (That's the implied contract of the return objects of compile_all: everything returned from the API should be passable to the linker without additional processing.) Therefore we special-case them below, compiling them with NVRTC before returning them to the user.
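To illustrate the factory dispatch being described, a from_path_or_obj-style helper might look roughly like this. This is a hypothetical sketch: the concrete class names and the suffix-to-class mapping are assumptions for illustration, not the PR's actual code.

```python
from pathlib import Path

class LinkableCode:
    """Base class for anything that can be handed to the linker."""
    def __init__(self, path):
        self.path = path

# Illustrative subclasses, one per code kind (names assumed).
class CUSource(LinkableCode): pass
class Cubin(LinkableCode): pass
class PTXSource(LinkableCode): pass

_SUFFIX_TO_KIND = {".cu": CUSource, ".cubin": Cubin, ".ptx": PTXSource}

def from_path_or_obj(path_or_obj):
    # Objects that are already LinkableCode are returned unchanged;
    # bare paths are wrapped in the class matching their suffix.
    if isinstance(path_or_obj, LinkableCode):
        return path_or_obj
    kind = _SUFFIX_TO_KIND[Path(path_or_obj).suffix]
    return kind(path_or_obj)

print(type(from_path_or_obj("helper.cu")).__name__)
```

Normalising to LinkableCode up front is what lets the later loop ask a single question (is this a CUDA source?) instead of re-deriving the kind from raw paths.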

rparolin
rparolin previously approved these changes Sep 24, 2025
Contributor

@rparolin rparolin left a comment

I left a couple comments for you to consider. Looks good overall.

@isVoid isVoid requested a review from rparolin September 26, 2025 22:01
rparolin
rparolin previously approved these changes Sep 26, 2025
@isVoid
Contributor Author

isVoid commented Sep 26, 2025

/ok to test

@isVoid
Contributor Author

isVoid commented Sep 26, 2025

/ok to test

@isVoid
Contributor Author

isVoid commented Sep 29, 2025

/ok to test

@isVoid
Contributor Author

isVoid commented Sep 30, 2025

/ok to test

@isVoid
Contributor Author

isVoid commented Sep 30, 2025

/ok to test



def compile(src, name, cc, ltoir=False):
def compile(src, name, cc, ltoir=False, lineinfo=False, debug=False):
Contributor

It's a bit of an oversight that we never passed these flags to NVRTC before - have you noticed if this now enables debugging / profiling of NVRTC-compiled code used in Numba kernels with these changes?

Contributor Author

There was a discussion offline on this. I think this will serve as part of the solution.
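For reference, the flag mapping under discussion is roughly as follows. `-lineinfo` (`--generate-line-info`) and `-G` (`--device-debug`) are documented NVRTC options, but how numba-cuda actually assembles its option list is an assumption in this sketch.

```python
def nvrtc_flags(lineinfo=False, debug=False):
    """Build the extra NVRTC options implied by lineinfo/debug."""
    flags = []
    if lineinfo:
        flags.append("-lineinfo")  # embed source line info for profilers
    if debug:
        flags.append("-G")         # full device debug info (disables most opts)
    return flags

print(nvrtc_flags(lineinfo=True, debug=True))
```

Note that `-G` typically disables optimisations, so debug builds of the external files will not be representative for performance work.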

Comment on lines 996 to 1000
"""Return a list of PTX/LTO-IR for kernel as well as external functions depended by the kernel.
If external functions are cuda c/c++ source, they will be compiled with NVRTC. The output code
kind is the same as the `output` parameter.
Otherwise, they will be passed through to the return list.
"""
Contributor

Because this docstring will appear in the docs, I think it needs some modification to fit into the context of the documentation on the compilation APIs:

Suggested change
"""Return a list of PTX/LTO-IR for kernel as well as external functions depended by the kernel.
If external functions are cuda c/c++ source, they will be compiled with NVRTC. The output code
kind is the same as the `output` parameter.
Otherwise, they will be passed through to the return list.
"""
"""Similar to ``compile()``, but returns a list of PTX codes/LTO-IRs for
the compiled function and the external functions it depends on.
If external functions are CUDA C++ source, they will be compiled with
NVRTC. Other kinds of external function code (e.g. cubins, fatbins, etc.)
will be added directly to the return list. The output code kind is
determined by the ``output`` parameter, and defaults to ``"ltoir"``.
"""

This change is aimed at pointing out the differences between compile() and compile_all().

The docs change to include this is also required - this could be done by adding

.. autofunction:: numba.cuda.compile_all

around line 126 of docs/source/reference/host.rst.

        self._linking_files.add(path_or_obj)

    @property
    def linking_files_as_obj(self):
Contributor

It seems a bit strange that the property isn't just called linking_files, but maybe I'm not understanding some nuance here.

Contributor Author

Right. I believe at one point I wanted to use this property to normalize objects and paths into LinkableCode, but that logic now lives in the LinkableCode factory functions.

    return x + y

args = (float32, float32)
ptx, resty = compile_ptx(add, args, device=True)
Contributor

The changes here work - I think I would have been inclined to use a pattern like:

    def _test_device_function(self, compile_function):
        def add(x, y):
            return x + y

        args = (float32, float32)

        ptx, resty = compile_function(add, args, device=True)

        # Device functions take a func_retval parameter for storing the
        # returned value in by reference
        self.assertIn("func_retval", ptx)
        # .visible .func is used to denote a device function
        self.assertIn(".visible .func", ptx)
        # .visible .entry would denote the presence of a global function
        self.assertNotIn(".visible .entry", ptx)
        # Inferred return type as expected?
        self.assertEqual(resty, float32)

        # Check that the function's output matches the signature
        sig_int32 = int32(int32, int32)
        ptx, resty = compile_function(add, sig_int32, device=True)
        self.assertEqual(resty, int32)

        sig_int16 = int16(int16, int16)
        ptx, resty = compile_function(add, sig_int16, device=True)
        self.assertEqual(resty, int16)

        # Using a string as the signature
        sig_string = "uint32(uint32, uint32)"
        ptx, resty = compile_function(add, sig_string, device=True)
        self.assertEqual(resty, uint32)

    def test_device_function(self):
        self._test_device_function(compile_ptx)

    def test_device_function_all(self):
        def compile_all_wrapper(*args, **kwargs):
            kwargs["abi"] = "c"
            kwargs["output"] = "ptx"
            ptx_list, resty = compile_all(*args, **kwargs)
            self.assertEqual(len(ptx_list), 1)
            return ptx_list[0], resty

        self._test_device_function(compile_all_wrapper)

to test compile_all() without duplicating all the test code, but I think this needn't necessarily be changed.

(To make the changes from the original main code upstream a bit clearer, they are:

  • Rename test_device_function() to _test_device_function() and add the compile_function parameter
  • Replace the compile_ptx() call with a call to compile_function()
  • Add the test_device_function() and test_device_function_all() methods, where test_device_function_all() emulates the API of compile_ptx() with a wrapper around compile_all().

)

Contributor

@gmarkall gmarkall left a comment

Thanks for the PR. I think this looks good in general. We need some changes for documentation (which are noted on the diff):

  • To add the compile_all() function to the reference documentation
  • To modify compile_all()'s docstring so it reads well in the context of the rest of the documentation.

I made a comment about the duplication introduced in the tests for compile*() and compile_all(), but I think that needn't necessarily be changed.

@gmarkall gmarkall added the 4 - Waiting on author Waiting for author to respond to review label Oct 2, 2025
@gmarkall
Contributor

gmarkall commented Oct 2, 2025

/ok to test

@gmarkall gmarkall added 5 - Ready to merge Testing and reviews complete, ready to merge and removed 4 - Waiting on author Waiting for author to respond to review labels Oct 2, 2025
@gmarkall gmarkall merged commit ff7637e into NVIDIA:main Oct 2, 2025
56 checks passed
gmarkall added a commit to gmarkall/numba-cuda that referenced this pull request Nov 20, 2025
- Add support for cache-hinted load and store operations (NVIDIA#587)
- Add more thirdparty tests (NVIDIA#586)
- Add sphinx-lint to pre-commit and fix errors (NVIDIA#597)
- Add DWARF variant part support for polymorphic variables in CUDA debug info (NVIDIA#544)
- chore: clean up dead workaround for unavailable `lru_cache` (NVIDIA#598)
- chore(docs): format types docs (NVIDIA#596)
- refactor: decouple `Context` from `Stream` and `Event` objects (NVIDIA#579)
- Fix freezing in of constant arrays with negative strides (NVIDIA#589)
- Update tests to accept variants of generated PTX (NVIDIA#585)
- refactor: replace device functionality with `cuda.core` APIs (NVIDIA#581)
- Move frontend tests to `cudapy` namespace (NVIDIA#558)
- Generalize the concurrency group for main merges (NVIDIA#582)
- ci: move pre-commit checks to pre commit action (NVIDIA#577)
- chore(pixi): set up doc builds; remove most `build-conda` dependencies (NVIDIA#574)
- ci: ensure that python version in ci matches matrix (NVIDIA#575)
- Fix the `cuda.is_supported_version()` API (NVIDIA#571)
- Fix checks on main (NVIDIA#576)
- feat: add `math.nextafter` (NVIDIA#543)
- ci: replace conda testing with pixi (NVIDIA#554)
- [CI] Run PR workflow on merge to main (NVIDIA#572)
- Propose Alternative Module Path for `ext_types` and Maintain `numba.cuda.types.bfloat16` Import API (NVIDIA#569)
- test: enable fail-on-warn and clean up resulting failures (NVIDIA#529)
- [Refactor][NFC] Vendor-in compiler_lock for future CUDA-specific changes (NVIDIA#565)
- Fix registration with Numba, vendor MakeFunctionToJITFunction tests (NVIDIA#566)
- [Refactor][NFC][Cleanups] Update imports to upstream numba to use the numba.cuda modules (NVIDIA#561)
- test: refactor process-based tests to use concurrent futures in order to simplify tests (NVIDIA#550)
- test: revert back to ipc futures that await each iteration (NVIDIA#564)
- chore(deps): move to self-contained pixi.toml to avoid mixed-pypi-pixi environments (NVIDIA#551)
- [Refactor][NFC] Vendor-in errors for future CUDA-specific changes (NVIDIA#534)
- Remove dependencies on target_extension for CUDA target (NVIDIA#555)
- Relax the pinning to `cuda-core` to allow it floating across minor releases (NVIDIA#559)
- [WIP] Port numpy reduction tests to CUDA (NVIDIA#523)
- ci: add timeout to avoid blocking the job queue (NVIDIA#556)
- Handle `cuda.core.Stream` in driver operations (NVIDIA#401)
- feat: add support for `math.exp2` (NVIDIA#541)
- Vendor in types and datamodel for CUDA-specific changes (NVIDIA#533)
- refactor: cleanup device constructor (NVIDIA#548)
- bench: add cupy to array constructor kernel launch benchmarks (NVIDIA#547)
- perf: cache dimension computations (NVIDIA#542)
- perf: remove duplicated size computation (NVIDIA#537)
- chore(perf): add torch to benchmark (NVIDIA#539)
- test: speed up ipc tests by ~6.5x (NVIDIA#527)
- perf: speed up kernel launch (NVIDIA#510)
- perf: remove context threading in various pointer abstractions (NVIDIA#536)
- perf: reduce the number of `__cuda_array_interface__` accesses (NVIDIA#538)
- refactor: remove unnecessary custom map and set implementations (NVIDIA#530)
- [Refactor][NFC] Vendor-in vectorize decorators for future CUDA-specific changes (NVIDIA#513)
- test: add benchmarks for kernel launch for reproducibility (NVIDIA#528)
- test(pixi): update pixi testing command to work with the new `testing` directory (NVIDIA#522)
- refactor: fully remove `USE_NV_BINDING` (NVIDIA#525)
- Draft: Vendor in the IR module (NVIDIA#439)
- pyproject.toml: add search path for Pyrefly (NVIDIA#524)
- Vendor in numba.core.typing for CUDA-specific changes (NVIDIA#473)
- Use numba.config when available, otherwise use numba.cuda.config (NVIDIA#497)
- [MNT] Drop NUMBA_CUDA_USE_NVIDIA_BINDING; always use cuda.core and cuda.bindings as fallback (NVIDIA#479)
- Vendor in dispatcher, entrypoints, pretty_annotate for CUDA-specific changes (NVIDIA#502)
- build: allow parallelization of nvcc testing builds (NVIDIA#521)
- chore(dev-deps): add pixi (NVIDIA#505)
- Vendor the imputils module for CUDA refactoring (NVIDIA#448)
- Don't use `MemoryLeakMixin` for tests that don't use NRT (NVIDIA#519)
- Switch back to stable cuDF release in thirdparty tests (NVIDIA#518)
- Updating .gitignore with binaries in the `testing` folder (NVIDIA#516)
- Remove some unnecessary uses of ContextResettingTestCase (NVIDIA#507)
- Vendor in _helperlib cext for CUDA-specific changes (NVIDIA#512)
- Vendor in typeconv for future CUDA-specific changes (NVIDIA#499)
- [Refactor][NFC] Vendor-in numba.cpython modules for future CUDA-specific changes (NVIDIA#493)
- [Refactor][NFC] Vendor-in numba.np modules for future CUDA-specific changes (NVIDIA#494)
- Make the CUDA target the default for CUDA overload decorators (NVIDIA#511)
- Remove C extension loading hacks (NVIDIA#506)
- Ensure NUMBA can manipulate memory from CUDA graphs before the graph is launched (NVIDIA#437)
- [Refactor][NFC] Vendor-in core Numba analysis utils for CUDA-specific changes (NVIDIA#433)
- Fix Bf16 Test OB Error (NVIDIA#509)
- Vendor in components from numba.core.runtime for CUDA-specific changes (NVIDIA#498)
- [Refactor] Vendor in _dispatcher, _devicearray, mviewbuf C extension for CUDA-specific customization (NVIDIA#373)
- [MNT] Managed UM memset fallback and skip CUDA IPC tests on WSL2 (NVIDIA#488)
- Improve debug value range coverage (NVIDIA#461)
- Add `compile_all` API (NVIDIA#484)
- Vendor in core.registry for CUDA-specific changes (NVIDIA#485)
- [Refactor][NFC] Vendor in numba.misc for CUDA-specific changes (NVIDIA#457)
- Vendor in optional, boxing for CUDA-specific changes, fix dangling imports (NVIDIA#476)
- [test] Remove dependency on cpu_target (NVIDIA#490)
- Change dangling imports of numba.core.lowering to numba.cuda.lowering (NVIDIA#475)
- [test] Use numpy's tolerance for float16 (NVIDIA#491)
- [Refactor][NFC] Vendor-in numba.extending for future CUDA-specific changes (NVIDIA#466)
- [Refactor][NFC] Vendor-in more cpython registries for future CUDA-specific changes (NVIDIA#478)
@gmarkall gmarkall mentioned this pull request Nov 20, 2025
gmarkall added a commit that referenced this pull request Nov 20, 2025
