Merged
9 changes: 9 additions & 0 deletions third_party/nvidia/backend/driver.py
@@ -212,6 +212,15 @@ def format_of(ty):
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
void *params[] = {{ {', '.join(params)} }};
if (gridX*gridY*gridZ > 0) {{
CUcontext pctx;
CUDA_CHECK(cuCtxGetCurrent(&pctx));
Contributor

How long does it take to get a context? Would it add significant latency?

Collaborator Author

Fair question. I ran the launch latency benchmark I wrote a few months back (now part of TritonBench), and it looks like the change adds roughly 0.1us, which is probably below the noise threshold of my box:

Before:

$ python run.py --op launch_latency --only nop_triton_compiled_kernel_run
  x_val    nop_triton_compiled_kernel_run-walltime
-------  -----------------------------------------
      0                                 0.00292743
     19                                 0.00431097

After:

% python run.py --op launch_latency --only nop_triton_compiled_kernel_run
  x_val    nop_triton_compiled_kernel_run-walltime
-------  -----------------------------------------
      0                                 0.00302734
     19                                 0.00438724
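
For reference, the per-row difference between the two runs can be worked out directly from the tables above. This is only a small sketch: the walltime values are copied from the output, while the assumption that the column is in milliseconds is mine (the output does not state units) and the helper name `latency_delta_us` is hypothetical.

```python
# Walltime values copied from the "Before" and "After" tables above.
# ASSUMPTION: walltime is reported in milliseconds; the output does not say.
before_ms = {0: 0.00292743, 19: 0.00431097}
after_ms = {0: 0.00302734, 19: 0.00438724}


def latency_delta_us(before, after):
    """Return the per-x_val increase in walltime, converted from ms to us."""
    return {x: (after[x] - before[x]) * 1000.0 for x in before}


deltas_us = latency_delta_us(before_ms, after_ms)
for x_val, delta in deltas_us.items():
    print(f"x_val={x_val}: +{delta:.2f} us")
```

Under that unit assumption, both rows come out at or just below 0.1us, which is consistent with the estimate given above.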

if (!pctx) {{
// Ensure device context.
CUdevice device;
CUDA_CHECK(cuDeviceGet(&device, 0));
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
CUDA_CHECK(cuCtxSetCurrent(pctx));
}}
if (num_ctas == 1) {{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}} else {{