Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions cpp/tensorrt_llm/kernels/causalConv1d/causalConv1d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
* Copyright (c) 2024, Tri Dao.
*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -349,20 +349,45 @@ void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream)
});
}

template <int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_dispatch(ConvParamsBase& params, cudaStream_t stream)
{
bool const isVarlen = params.query_start_loc_ptr != nullptr;
constexpr int kNarrowThreads = 64;
constexpr int kWideThreads = 128;
constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
constexpr int kShortSeqThreshold = kNarrowThreads * kNElts;
// Varlen prefill launches one block per sequence/channel pair, so the per-sequence
// work is usually much smaller than params.seqlen suggests. That path also disables
// the wide vector-load specialization, so the 128-thread kernel tends to overprovision
// threads for many short chunks. Prefer the narrower launch for varlen and for short
// fixed-length inputs; keep the wider launch for long dense sequences.
bool const preferNarrowKernel = isVarlen || params.seqlen <= kShortSeqThreshold;

if (preferNarrowKernel)
{
causal_conv1d_fwd_launch<kNarrowThreads, kWidth, input_t, weight_t>(params, stream);
}
else
{
causal_conv1d_fwd_launch<kWideThreads, kWidth, input_t, weight_t>(params, stream);
}
}

template <typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream)
{
if (params.width == 2)
{
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
causal_conv1d_fwd_dispatch<2, input_t, weight_t>(params, stream);
}
else if (params.width == 3)
{
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
causal_conv1d_fwd_dispatch<3, input_t, weight_t>(params, stream);
}
else if (params.width == 4)
{
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
causal_conv1d_fwd_dispatch<4, input_t, weight_t>(params, stream);
}
}

Expand Down
Loading