
Fix launch bounds in spatial transformer #13188

Merged

Conversation

ptrendx
Member

@ptrendx ptrendx commented Nov 8, 2018

Description

Without launch_bounds, the compiler is not required to use a small enough number of registers to fit 1024 threads per block. Our internal CI build with CUDA 10 was failing on V100 because of this.
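
For illustration, a minimal sketch of the failure mode (the kernel and sizes are hypothetical, not MXNet source; a trivial body like this would not actually exceed the register budget, but a register-heavy kernel such as the spatial transformer's can). Note that launch-configuration errors surface only when the error status is checked after the launch:

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for a register-heavy kernel.
__global__ void HeavyKernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<float>(i);
}

int main() {
  const int n = 1 << 20;
  const int threads = 1024;  // cuda::kMaxThreadsPerBlock (assumed 1024)
  float* out = nullptr;
  cudaMalloc(&out, n * sizeof(float));
  HeavyKernel<<<(n + threads - 1) / threads, threads>>>(out, n);
  // The <<<>>> syntax itself does not report launch errors; check here.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    // e.g. "too many resources requested for launch" when the compiled
    // kernel uses too many registers per thread
    std::printf("launch failed: %s\n", cudaGetErrorString(err));
  }
  cudaFree(out);
  return 0;
}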

Checklist

Essentials

Please feel free to remove inapplicable items for your PR:

  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • Added launch_bounds guards around BilinearSampling[Forward,Backward]Kernel to ensure that the compiled operator works on each supported GPU.
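
A hedged sketch of the shape of the change (the template and parameter list below are illustrative, not the exact MXNet signature; it assumes mshadow's cuda::kMaxThreadsPerBlock, which equals 1024, is in scope):

// Illustrative declaration only -- the parameter list is hypothetical.
// __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor)
// caps per-thread register usage so that a full block of
// cuda::kMaxThreadsPerBlock threads is always launchable.
template <typename DType>
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingForwardKernel(const int count, const DType* data,
                              const DType* grid, DType* out) {
  // ... kernel body unchanged; only the attribute above is new ...
}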

Comments

@harshp8l
Contributor

harshp8l commented Nov 9, 2018

@mxnet-label-bot add [pr-awaiting-review]

@marcoabreu marcoabreu added the pr-awaiting-review label Nov 9, 2018
@stu1130
Contributor

stu1130 commented Nov 20, 2018

@samskalicky @access2rohit @anirudh2290 could you please review it? Thanks!

@vandanavk
Contributor

@mxnet-label-bot add [Operator]

@apeforest for review

@ptrendx ptrendx force-pushed the pr_bilinear_backward_launch_bounds branch from f322d69 to c184b07 on December 10, 2018 18:22
@Roshrini
Member

@eric-haibin-lin @anirudh2290 @samskalicky can you please review this PR?

@samskalicky
Contributor

@ptrendx was there a corresponding issue filed for the V100 CI failure?

@ptrendx
Member Author

ptrendx commented Dec 21, 2018

It was our internal (NVIDIA) CI that failed, not MXNet's external CI, so no, there was no issue filed for it.

Contributor

@samskalicky samskalicky left a comment


@ptrendx Changes look straightforward, no issues, but I don't quite understand the fix.

Is 1 thread the correct value? Why choose that value (i.e., why not 2)?

Please add a comment explaining the reasoning for the value of 1.

Are there unit tests covering this change? How can we validate its correctness?

@ptrendx
Member Author

ptrendx commented Dec 21, 2018

It is 1 block, not 1 thread. Basically, what this change does is tell the compiler: I am going to run this kernel with cuda::kMaxThreadsPerBlock threads, so make sure that at least 1 block fits on an SM. It does not limit the kernel to running only 1 block; this is a minimum value.
Without this change, the compiler does not know how many threads the kernel will be launched with, so it is free to generate code that uses as many registers as it wants. However, the register file is not infinite: the more registers the kernel uses, the fewer threads can run at the same time, and trying to launch the kernel with more threads than fit fails with a "too many resources requested for launch" error. Since the kernel is launched with cuda::kMaxThreadsPerBlock threads, we need to make sure it can be launched on every supported architecture.
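
A sketch of one way to verify this on a given build (SomeKernel is a stand-in for the real kernel; cudaFuncGetAttributes is the standard runtime API for querying a compiled kernel). attr.maxThreadsPerBlock is the largest block size the kernel can actually be launched with, and with __launch_bounds__(1024, 1) it comes out as 1024:

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in kernel; the attribute guarantees a 1024-thread launch fits.
__global__ void __launch_bounds__(1024, 1)
SomeKernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 0.f;
}

int main() {
  cudaFuncAttributes attr;
  cudaFuncGetAttributes(&attr, SomeKernel);
  std::printf("registers per thread: %d, max threads per block: %d\n",
              attr.numRegs, attr.maxThreadsPerBlock);
  return 0;
}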

Actually, I just found an issue about the spatial transformer giving exactly this error (it was "fixed" by changing the kernel, which luckily dropped register usage below the threshold again): #11568
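
For concreteness, the register arithmetic behind that threshold, using V100 numbers (65,536 32-bit registers per SM on Volta):

register file per SM (V100)       : 65,536 registers
block size (kMaxThreadsPerBlock)  : 1,024 threads
maximum registers per thread      : 65,536 / 1,024 = 64

If the compiler assigns 65 or more registers per thread, a single
1024-thread block no longer fits on an SM and the launch fails with
"too many resources requested for launch". __launch_bounds__(1024, 1)
tells the compiler to stay at or below 64.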

Member

@anirudh2290 anirudh2290 left a comment


@ptrendx makes sense. #11568 had two issues; the assert issue was fixed, but the "too many resources requested for launch" error was merely hidden.

@Roshrini
Member

Roshrini commented Jan 2, 2019

@samskalicky Can you check if your comment is addressed? Thanks

@samskalicky
Contributor

samskalicky commented Jan 2, 2019

@ptrendx Can you add a comment in the code to make it clear that these launch bounds are required to be set this way to avoid the "too many resources requested for launch" error?

Maybe something like:

/*
 * __launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
 * guarantees that at least 1 block of kMaxThreadsPerBlock
 * threads can be resident on an SM, which caps the number
 * of registers the compiler may use per thread. Without it,
 * running the kernel with this many threads can fail to
 * launch ("too many resources requested for launch" error).
 */

@ptrendx
Member Author

ptrendx commented Jan 2, 2019

I added a comment with an explanation.

Contributor

@samskalicky samskalicky left a comment


Thanks @ptrendx for your help with this issue!

@ptrendx
Member Author

ptrendx commented Jan 4, 2019

Is there anything else needed for this PR?

@Roshrini
Member

Roshrini commented Jan 4, 2019

@anirudh2290 Can you merge this PR if it looks good?
@mxnet-label-bot Update [Operator, pr-awaiting-merge]

@marcoabreu marcoabreu added the pr-awaiting-merge label and removed the pr-awaiting-review label Jan 4, 2019
@KellenSunderland
Contributor

Looks similar to a few fixes we've made in the past when some kernels used too many registers to run on a TX1. LGTM.

@KellenSunderland KellenSunderland merged commit 0faa5b7 into apache:master Jan 14, 2019
KellenSunderland pushed a commit to KellenSunderland/incubator-mxnet that referenced this pull request Jan 17, 2019
* Fix launch bounds in spatial transformer

* Adding explanation in comment.
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* Fix launch bounds in spatial transformer

* Adding explanation in comment.