diff --git a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu index a5f22858ac6..faa1f2d9fca 100644 --- a/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu +++ b/cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu @@ -4,7 +4,7 @@ * and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu * Copyright (c) 2024, Tri Dao. * - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream) }); } +template +void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream) +{ + bool const isVarlen = params.query_start_loc_ptr != nullptr; + constexpr int kNarrowThreads = 64; + constexpr int kWideThreads = 128; + constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + constexpr int kShortSeqThreshold = kNarrowThreads * kNElts; + // Varlen prefill launches one block per sequence/channel pair, so the per-sequence + // work is usually much smaller than params.seqlen suggests. That path also disables + // the wide vector-load specialization, so the 128-thread kernel tends to overprovision + // threads for many short chunks. Prefer the narrower launch for varlen and for short + // fixed-length inputs; keep the wider launch for long dense sequences. + bool const preferNarrowKernel = isVarlen || params.seqlen <= kShortSeqThreshold; + + if (preferNarrowKernel) + { + causal_conv1d_fwd_launch(params, stream); + } + else + { + causal_conv1d_fwd_launch(params, stream); + } +} + template void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream) { if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream); } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream); } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream); } }