-
Notifications
You must be signed in to change notification settings - Fork 995
Vendor CCCL v3.3.2 from GitHub instead of relying on CTK-bundled copy #3091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c2cf070
9fda09f
ac91694
ddd77d1
91fbbb5
3690253
1970b17
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,43 +18,23 @@ limitations under the License. | |
|
|
||
| #include <cuda.h> | ||
|
|
||
| #include <cmath> | ||
| #include <cub/device/device_transform.cuh> | ||
|
|
||
| #include "../../math.cuh" | ||
| #include "../../utils.cuh" | ||
|
|
||
| namespace flashinfer { | ||
|
|
||
| __global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int n) { | ||
| int elem_idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| if (elem_idx >= n) return; | ||
| #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) | ||
| asm volatile("griddepcontrol.wait;"); | ||
| #endif | ||
| float2 md_elem = md[elem_idx]; | ||
| float m = md_elem.x; | ||
| float d = md_elem.y; | ||
| lse[elem_idx] = math::log2e * m + math::ptx_log2(d); | ||
| #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) | ||
| asm volatile("griddepcontrol.launch_dependents;"); | ||
| #endif | ||
| } | ||
| struct MDToLSE { | ||
| __host__ __device__ float operator()(float2 md_elem) const { | ||
| return math::log2e * md_elem.x + log2f(md_elem.y); | ||
| } | ||
| }; | ||
|
|
||
| inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl, | ||
| inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool /*launch_with_pdl*/, | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note for reviewers: launch_with_pdl is unused β DeviceTransform enables PDL unconditionally on SM90+ via its internal launcher. On pre-Hopper GPUs, the PDL instructions compile to no-ops. This means callers that pass false will still get PDL when the GPU supports it, which is probably harmless (PDL is a performance hint, not semantic). |
||
| cudaStream_t stream) { | ||
| int num_threads = std::min(1024, UpPowerOfTwo(n)); | ||
| int num_blocks = ceil_div(n, num_threads); | ||
| cudaLaunchConfig_t config; | ||
| config.gridDim = num_blocks; | ||
| config.blockDim = num_threads; | ||
| config.dynamicSmemBytes = 0; | ||
| config.stream = stream; | ||
| cudaLaunchAttribute attrs[1]; | ||
| attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; | ||
| attrs[0].val.programmaticStreamSerializationAllowed = launch_with_pdl; | ||
| config.numAttrs = 1; | ||
| config.attrs = attrs; | ||
|
|
||
| FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, n)); | ||
| return cudaSuccess; | ||
| return cub::DeviceTransform::Transform(md, lse, n, MDToLSE{}, stream); | ||
| } | ||
|
|
||
| }; // namespace flashinfer | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.