Conversation

@gmarkall (Contributor)

The original linking implementation for linkable code in device declarations did not consider calls inside callees; this change recurses through the typing to find all calls requiring linkable code.
@gmarkall added the "2 - In Progress" label on Feb 26, 2025
Comment on lines 51 to 68
# The typemap of the function includes calls, so we can traverse it to find
# the references we need.
for name, v in cres.fndesc.typemap.items():

    # CUDADispatchers represent a call to a device function, so we need to
    # look up the linkable code for those recursively.
    if isinstance(v, cuda_types.CUDADispatcher):
        # We need to locate the signature of the call so we can find the
        # correct overload.
        for call, sig in cres.fndesc.calltypes.items():
            if isinstance(call, ir.Expr) and call.op == 'call':
                # There will likely be multiple calls in the typemap; we
                # can uniquely identify the relevant one using its SSA
                # name.
                if call.func.name == name:
                    called_cres = v.dispatcher.overloads[sig.args]
                    called_link_objects = get_cres_link_objects(called_cres)
                    link_objects.update(called_link_objects)
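
For context, this loop appears to live inside the recursive helper get_cres_link_objects; a minimal sketch of how it might be wrapped, assuming the set initialisation and return shape, neither of which is shown in the quoted diff:

def get_cres_link_objects(cres):
    # Accumulate the linkable code objects required by this compile
    # result and, via the recursion below, by everything it calls.
    link_objects = set()

    # (Any linkable code attached directly to `cres` itself would be
    # gathered here; that part is not shown in the quoted diff.)

    # ... the loop over cres.fndesc.typemap.items() quoted above ...

    return link_objects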
@isVoid (Contributor) commented on Mar 2, 2025
This is cool. I learnt a few things by reading through this section. Do you think the version below simplifies the code and reduces its complexity a little?

I made a PR here:
gmarkall#4

I think this reduces the size of the list for both of the nested for-loops, making the work proportional to O(num_calls^2) rather than O(num_typings^2).
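
A minimal sketch of the idea, starting from the call sites rather than the whole typemap (an illustration only, not the actual diff in gmarkall#4):

for call, sig in cres.fndesc.calltypes.items():
    # Only call expressions can refer to device functions.
    if isinstance(call, ir.Expr) and call.op == 'call':
        # Look up the callee's type directly via its SSA name.
        v = cres.fndesc.typemap[call.func.name]
        if isinstance(v, cuda_types.CUDADispatcher):
            called_cres = v.dispatcher.overloads[sig.args]
            link_objects.update(get_cres_link_objects(called_cres))

This flattened form removes the nested loop entirely; the O(num_calls^2) estimate above suggests the actual change instead kept both loops but restricted them to call entries.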

@gmarkall (Contributor, Author)

Thanks - I've incorporated your changes. I don't think there's much of a performance impact (the number of calls and typings won't be that large), but I think your modifications improved the readability of the code.

@gmarkall added the "4 - Waiting on reviewer" and "2 - In Progress" labels and removed the "2 - In Progress" and "4 - Waiting on reviewer" labels on Mar 5, 2025
@isVoid (Contributor) left a comment

lgtm

@gmarkall merged commit 9f0d154 into NVIDIA:main on Mar 6, 2025
31 checks passed
gmarkall added a commit to gmarkall/numba-cuda that referenced this pull request Mar 6, 2025
- Fix linking of external code from callees (NVIDIA#137)
- Try using a newer branch workflow (NVIDIA#148)
- Move publish step out of `wheels-build.yaml` (NVIDIA#147)
- Upload wheels to PyPI from GitHub-hosted runner (NVIDIA#142)
- Add paddle to interoperability chapter (NVIDIA#144)
- Fix the debug info of GridGroup type (NVIDIA#131)
- Remove dead `prepare_cuda_kernel()` (NVIDIA#130)
- Add a CUDA DI Builder (NVIDIA#104)
- dont launch extra kernels when stats counting is disabled (NVIDIA#127)
- Fixup debug metadata in kernel fixup (NVIDIA#97)
- Implement debuginfo bool name fix (numba/numba#9888) in numba-cuda (NVIDIA#106)
@gmarkall mentioned this pull request on Mar 6, 2025
@gmarkall added a commit that referenced this pull request on Mar 6, 2025, with the same changelog as above.