Merged
17 commits
aae9688
refactor: Minor cleanup in CreateNewDecoderRequests
Funatiq Aug 1, 2025
7f10367
refactor: remove BufferManager from CreateNewDecoderRequests parameters
Funatiq Aug 1, 2025
92aabea
refactor: Move embedding bias initialization to a separate function i…
Funatiq Aug 1, 2025
a70dec7
refactor: Extract setupWords function in CreateNewDecoderRequests
Funatiq Aug 1, 2025
8fc612d
refactor: Move newRequestSpeculativeDecoding out of newRequest
Funatiq Aug 1, 2025
f846977
refactor: Consolidate speculative decoding logic in createDecoderRequ…
Funatiq Aug 1, 2025
60f266f
refactor: Enhance embedding bias initialization in CreateNewDecoderRe…
Funatiq Aug 1, 2025
c87332c
refactor: Update stopWords and badWords handling in CreateNewDecoderR…
Funatiq Aug 1, 2025
7a3e5e6
refactor: Introduce initializeRequestIds and initializeBeamSearch fun…
Funatiq Aug 1, 2025
c8c94a8
refactor: Remove ids and endId from decoder_batch::Request
Funatiq Aug 1, 2025
2000700
refactor: Introduce initializeLogProbs function in CreateNewDecoderRe…
Funatiq Aug 1, 2025
163a93c
refactor: Use decoder stream for request initialization
Funatiq Aug 1, 2025
efe0aab
refactor: Move beam width check to CreateNewDecoderRequests
Funatiq Aug 1, 2025
4b2df1e
refactor: Introduce initializeInputLengths function in CreateNewDecod…
Funatiq Aug 1, 2025
a4f9c97
refactor: Initialize request lengths in CreateNewDecoderRequests
Funatiq Aug 1, 2025
f34b7e5
refactor: Rename newRequest to initializeOutputs
Funatiq Aug 1, 2025
b4d3a69
refactor: Add setBeamWidth method to DecoderState
Funatiq Aug 2, 2025
16 changes: 4 additions & 12 deletions cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
@@ -75,27 +75,19 @@ class CreateNewDecoderRequests : Algorithm
std::vector<executor::LookaheadDecodingConfig>>
operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
SizeType32 maxSequenceLength, SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
CudaStream const& runtimeStream, CudaStream const& decoderStream, SizeType32 maxSequenceLength,
SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;

[[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>,
std::vector<executor::LookaheadDecodingConfig>>
createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream,
SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;

private:
//! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via
//! GptDecoderBatched::newBatch()
static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
SizeType32 maxSequenceLength);

//! @brief Setups decoder internal tensors for new speculative decoding request
static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
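
With the BufferManager argument removed, the updated operator() takes only the logits type, the decoder input buffers, the decoder state, and the two CUDA streams. A minimal call-site sketch under the new signature (all variable names are placeholders owned by a hypothetical caller, not part of this PR):

    // Sketch only: assumes createNewDecoderRequests and all arguments already exist in the caller.
    auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
        = createNewDecoderRequests(modelConfig, worldConfig, decodingConfig, contextRequests,
            logitsType, inputBuffers, decoderState, runtimeStream, decoderStream,
            maxSequenceLength, beamWidth, /*medusaBuffers=*/std::nullopt);
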
5 changes: 5 additions & 0 deletions cpp/include/tensorrt_llm/runtime/decoderState.h
@@ -173,6 +173,11 @@ class DecoderState
//! @brief Workspace for beam search in streaming mode.
[[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const;

//! @brief Set the beam width for a specific request in the batch.
//! @param batchIdx The index of the request in the batch.
//! @param beamWidth The beam width for the specified request.
void setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth);

//! @brief Cache indirection input for beam search.
[[nodiscard]] TensorPtr getCacheIndirectionInput() const;

14 changes: 2 additions & 12 deletions cpp/include/tensorrt_llm/runtime/request.h
@@ -31,26 +31,16 @@ class Request
using TensorPtr = ITensor::SharedPtr;
using BufferPtr = IBuffer::SharedPtr;

explicit Request(TensorConstPtr ids, SizeType32 inputLen, std::optional<SizeType32> maxNewTokens = std::nullopt,
std::optional<SizeType32> endId = std::nullopt)
: ids{std::move(ids)}
, inputLen(inputLen)
, maxNewTokens{maxNewTokens}
, endId{endId}
explicit Request(SizeType32 inputLen)
: inputLen(inputLen)
{
}

//! Mandatory parameters
TensorConstPtr ids; // The input sequence of token ids, [inputSeqLen], on gpu
SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps

// optional parameters
std::optional<SizeType32> maxNewTokens; // maximum number of tokens to generate for this request
std::optional<SizeType32> endId; // end token id
SizeType32 generatedTokensPerEngineStep{1}; //
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength] on gpu
TensorPtr stopWordsList; // [2, stopWordsLength] on gpu

//! Optional parameters for speculative decoding
BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu
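
With ids and endId removed from decoder_batch::Request, the constructor now takes only the input length; the remaining members are plain fields set after construction. A minimal sketch (promptLen and the tensor variables are hypothetical caller-side values, not names introduced by this PR):

    // Sketch only: assumes request.h is included and the tensors already live on the GPU.
    runtime::decoder_batch::Request request{promptLen}; // input length without draft tokens
    request.generatedTokensPerEngineStep = 1;           // default: one token per engine step
    request.embeddingBias = embeddingBiasTensor;        // [vocabSizePadded], on gpu
    request.badWordsList = badWordsTensor;              // [2, badWordsLength], on gpu
    request.stopWordsList = stopWordsTensor;            // [2, stopWordsLength], on gpu
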
386 changes: 210 additions & 176 deletions cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Large diffs are not rendered by default.

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -1866,9 +1866,9 @@ void TrtGptModelInflightBatching::setupDecoderStep(
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");

auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests,
mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(),
*mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers);
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType,
inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(),
mOperatingBeamWidth, buffers.mMedusaBuffers);

auto const localBatchSize = batchSlots->getSize();
if (localBatchSize > 0)
16 changes: 7 additions & 9 deletions cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp
@@ -103,23 +103,21 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
"__call__",
[](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
tensorrt_llm::runtime::CudaStream const& runtimeStream,
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream,
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
SizeType32 beamWidth)
{
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
= self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);

return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
},
nb::arg("model_config"), nb::arg("world_config"), nb::arg("decoding_config"), nb::arg("context_requests"),
nb::arg("buffer_manager"), nb::arg("logits_type"), nb::arg("decoder_input_buffers"),
nb::arg("decoder_state"), nb::arg("runtime_stream"), nb::arg("decoder_stream"),
nb::arg("max_sequence_length"), nb::arg("beam_width"))
nb::arg("logits_type"), nb::arg("decoder_input_buffers"), nb::arg("decoder_state"),
nb::arg("runtime_stream"), nb::arg("decoder_stream"), nb::arg("max_sequence_length"), nb::arg("beam_width"))
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
}
16 changes: 7 additions & 9 deletions cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp
@@ -105,23 +105,21 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
"__call__",
[](CreateNewDecoderRequests& self, tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
tr::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
tensorrt_llm::runtime::CudaStream const& runtimeStream,
nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
runtime::decoder::DecoderState& decoderState, tensorrt_llm::runtime::CudaStream const& runtimeStream,
tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
SizeType32 beamWidth)
{
OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt;
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
= self(modelConfig, worldConfig, decodingConfig, contextRequests, logitsType, inputBuffers,
decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);

return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
},
py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"),
py::arg("decoder_state"), py::arg("runtime_stream"), py::arg("decoder_stream"),
py::arg("max_sequence_length"), py::arg("beam_width"))
py::arg("logits_type"), py::arg("decoder_input_buffers"), py::arg("decoder_state"),
py::arg("runtime_stream"), py::arg("decoder_stream"), py::arg("max_sequence_length"), py::arg("beam_width"))
.def("name", [](CreateNewDecoderRequests const&) { return CreateNewDecoderRequests::name; });
}
5 changes: 5 additions & 0 deletions cpp/tensorrt_llm/runtime/decoderState.cpp
@@ -644,6 +644,11 @@ void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationS
mJointDecodingInput->generationSteps = generationSteps;
}

void DecoderState::setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth)
{
mJointDecodingInput->beamWidths.at(batchIdx) = beamWidth;
}

DecodingInput& DecoderState::getJointDecodingInput() const
{
return *mJointDecodingInput;
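
The new setter writes directly into the per-request beam width vector of the joint decoding input, so batchIdx must refer to a slot that was sized during setup (std::vector::at throws otherwise). A minimal usage sketch (the slot and sampling config are hypothetical values, not taken from the PR):

    // Sketch only: decoderState and samplingConfig are assumed to exist in the caller.
    SizeType32 const batchSlot = 3; // slot assigned to the newly admitted request
    decoderState.setBeamWidth(batchSlot, samplingConfig.beamWidth); // per-request beam width
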
7 changes: 3 additions & 4 deletions cpp/tests/runtime/gptDecoderBatchedTest.cpp
@@ -104,15 +104,14 @@ void newRequests(std::vector<std::shared_ptr<tb::LlmRequest>> const& requests, T
SizeType32 maxSequenceLength, tb::DecoderInputBuffers& inputBuffers, decoder::DecoderState& decoderState)
{
auto const& decoderStream = *decoder.getDecoderStream();
auto const bufferManager = BufferManager{std::make_shared<CudaStream>(runtimeStream.get())};

auto batchSlotsRange = BufferRange<SizeType32>(*batchSlots);
auto const localBatchSize = batchSlots->getSize();

tb::CreateNewDecoderRequests createNewDecoderRequests(false, false, false);
auto [lookaheadPrompt, lookaheadAlgoConfigs] = createNewDecoderRequests.createDecoderRequests(requests,
inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig,
runtimeStream, decoderStream, maxSequenceLength, std::nullopt);
auto [lookaheadPrompt, lookaheadAlgoConfigs]
= createNewDecoderRequests.createDecoderRequests(requests, inputBuffers.inputsIds, decodingConfig, decoderState,
logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, std::nullopt);

std::vector<SamplingConfig> samplingConfigs;
samplingConfigs.reserve(requests.size());
3 changes: 1 addition & 2 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -749,8 +749,7 @@ def _instantiate_algorithms(self):
def setup_sampler_step(self, requests):
batch_slots, sampling_configs, lookahead_prompt, lookahead_algo_configs = self.algs.create_new_decoder_requests(
self.model_config, self.world_config, self.decoding_config,
requests.context_requests, self.store["buffer_manager"],
self.logits_datatype,
requests.context_requests, self.logits_datatype,
self.store["decoder_input_buffers"][self.micro_batch_idx],
self.store["decoder_state"], self.store["cuda_stream"],
self.algs.decoder.decoder_stream, self.executor_config.max_seq_len,