38 changes: 22 additions & 16 deletions cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
@@ -81,38 +81,45 @@ class CreateNewDecoderRequests : Algorithm

//! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via
//! GptDecoderBatched::newBatch()
static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
static void newRequest(LlmRequest const& llmReq, SharedConstPtr inputIds, SizeType32 batchSlot,
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
SizeType32 maxSequenceLength);
SizeType32 maxSequenceLength, nvinfer1::DataType logitsType, executor::DecodingConfig const& decodingConfig,
bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, OptionalRef<MedusaBuffers const> medusaBuffers);
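For orientation, here is a minimal, hedged sketch of how the static-batching entry point (per the comment above, GptDecoderBatched::newBatch()) might call the widened signature. Every local below (llmReq, inputIds, batchSlot, the configs, streams, buffers, and flags) is assumed to already exist at the call site; this is illustrative, not code from the PR.

```cpp
// Sketch only: all locals are assumed to be prepared by the caller.
// The argument order mirrors the declaration above.
CreateNewDecoderRequests::newRequest(llmReq, inputIds, batchSlot, samplingConfig, modelConfig, worldConfig,
    bufferManager, decoderState, runtimeStream, decoderStream, maxSequenceLength, logitsType, decodingConfig,
    speculativeDecodingFastLogits, isLeaderInOrchMode, medusaBuffers);
```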

private:
//! @brief Sets up decoder internal tensors for a new speculative decoding request
static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
static void newRequestSpeculativeDecoding(SizeType32 batchIdx, LlmRequest const& request,
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
SizeType32 maxDecodingEngineTokens);
SizeType32 maxDecodingEngineTokens, executor::DecodingConfig const& decodingConfig,
bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, OptionalRef<MedusaBuffers const> medusaBuffers);
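To make the relationship between this helper and the mode-specific helpers declared below concrete, here is a hedged sketch of the dispatch it is expected to perform. The predicate names on SpeculativeDecodingMode and the guard around medusaBuffers are assumptions; the argument lists follow the declarations in this header.

```cpp
// Sketch of the expected dispatch (not the actual implementation).
if (speculativeDecodingMode.isDraftTokensExternal())
{
    newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream,
        modelConfig, worldConfig, bufferManager, speculativeDecodingFastLogits, isLeaderInOrchMode);
}
else if (speculativeDecodingMode.isMedusa())
{
    // medusaBuffers is assumed to be present whenever Medusa is enabled.
    newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens, *medusaBuffers);
}
else if (speculativeDecodingMode.isLookaheadDecoding())
{
    newRequestLookahead(batchIdx, jointDecodingInput, jointDecodingOutput, runtimeStream);
}
else if (speculativeDecodingMode.isExplicitDraftTokens())
{
    newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream);
}
else if (speculativeDecodingMode.isEagle())
{
    newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream, decodingConfig);
}
```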

//! @brief Sets up decoder internal tensors for a new request in Draft model SpS mode
static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream);
static void newRequestDraftTokensExternal(SizeType32 batchIdx, LlmRequest const& llmReq,
SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream,
runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
runtime::BufferManager const& bufferManager, bool speculativeDecodingFastLogits, bool isLeaderInOrchMode);

//! @brief Sets up decoder internal tensors for a new Medusa request
static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens);
static void newRequestMedusa(SizeType32 batchIdx, LlmRequest const& llmReq, DecodingInput& jointDecodingInput,
CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers);

//! @brief Sets up decoder internal tensors for a new Lookahead request
static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
static void newRequestLookahead(SizeType32 batchIdx, DecodingInput& jointDecodingInput,
DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

//! @brief Sets up decoder internal tensors for a new Explicit draft tokens request
static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
static void newRequestExplicitDraftTokens(SizeType32 batchIdx, LlmRequest const& llmReq,
DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

//! @brief Sets up decoder internal tensors for a new Eagle request
static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
static void newRequestEagle(SizeType32 batchIdx, LlmRequest const& llmReq, runtime::ModelConfig const& modelConfig,
DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
executor::DecodingConfig const& decodingConfig);

[[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>,
std::vector<executor::LookaheadDecodingConfig>>
@@ -123,9 +130,8 @@ class CreateNewDecoderRequests : Algorithm
runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream,
SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;

[[nodiscard]] std::shared_ptr<runtime::ITensor> retrieveDraftLogits(runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
runtime::BufferManager const& bufferManager) const;
static void retrieveDraftLogits(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
LlmRequest const& llmReq, bool speculativeDecodingFastLogits, bool isLeaderInOrchMode);
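retrieveDraftLogits is now a static helper that takes the request and the two flags directly, instead of returning a tensor and reading the flags from the mSpeculativeDecodingFastLogits / mIsLeaderInOrchMode members kept below. A hedged before/after sketch of the call site; the llmReq-based "before" argument and the surrounding locals are assumptions, not code from the PR.

```cpp
// Before: non-static member returned a tensor and read the flags from instance state.
// auto draftLogits = retrieveDraftLogits(modelConfig, worldConfig, llmReq.getDraftLogits(), bufferManager);

// After: static call; the (now static) caller passes the flags explicitly.
retrieveDraftLogits(modelConfig, worldConfig, llmReq, speculativeDecodingFastLogits, isLeaderInOrchMode);
```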

bool mSpeculativeDecodingFastLogits;
bool mIsLeaderInOrchMode;
16 changes: 7 additions & 9 deletions cpp/include/tensorrt_llm/runtime/decodingInput.h
@@ -83,15 +83,13 @@ class DecodingInput
TensorConstPtr finishReasons;
//! The maximum sequence length for each sequence in the batch, [batchSize] on gpu
TensorConstPtr sequenceLimitLength;
TensorConstPtr embeddingBias; // [batchSize, vocabSizePadded] on gpu
TensorConstPtr lengths; // [batchSize, beamWidth] on gpu
std::vector<TensorPtr> badWordsLists; // [batchSize][2, badWordsLength] on gpu
TensorConstPtr badWordsPtrs; // [batchSize][2, badWordsLength] on pinned
TensorConstPtr badWordsLens; // [batchSize] on gpu
std::vector<TensorPtr> stopWordsLists; // [batchSize][2, stopWordsLength] on gpu
TensorConstPtr stopWordsPtrs; // [batchSize][2, stopWordsLength] on pinned
TensorConstPtr stopWordsLens; // [batchSize] on pinned
TensorConstPtr noRepeatNgramSize; // [batchSize] on gpu
TensorConstPtr embeddingBias; // [batchSize, vocabSizePadded] on gpu
TensorConstPtr lengths; // [batchSize, beamWidth] on gpu
TensorConstPtr badWordsPtrs; // [batchSize][2, badWordsLength] on pinned
TensorConstPtr badWordsLens; // [batchSize] on gpu
TensorConstPtr stopWordsPtrs; // [batchSize][2, stopWordsLength] on pinned
TensorConstPtr stopWordsLens; // [batchSize] on pinned
TensorConstPtr noRepeatNgramSize; // [batchSize] on gpu
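With badWordsLists and stopWordsLists removed, DecodingInput no longer owns the per-request word-list tensors; it only references them through the pinned pointer arrays and the length tensors kept above. A hedged sketch of the ownership pattern this implies for callers; all names are illustrative.

```cpp
// Sketch only: the caller now keeps the [2, wordsLength] device tensors alive for as
// long as decoding runs, since DecodingInput merely points at them.
std::vector<runtime::ITensor::SharedPtr> ownedBadWordsLists;   // e.g. one entry per active request
ownedBadWordsLists.push_back(badWordsForThisRequest);          // assumed [2, badWordsLength] device tensor
// badWordsPtrs[batchSlot] and badWordsLens[batchSlot] are then filled to reference
// this tensor and its badWordsLength, respectively.
```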

//! Parameters for beam search
//! KV cache index for beam search, [batchSize, beamWidth, maxSeqLen] on gpu