diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
index 43194db37a3..76eab40de1f 100644
--- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
+++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h
@@ -81,38 +81,45 @@ class CreateNewDecoderRequests : Algorithm
     //! @brief Initialize the decoder at `batchSlot` with a new `request`. Exposed only for static batching via
     //! GptDecoderBatched::newBatch()
-    static void newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
+    static void newRequest(LlmRequest const& llmReq, SharedConstPtr inputIds, SizeType32 batchSlot,
         SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
+        runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
         runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
-        SizeType32 maxSequenceLength);
+        SizeType32 maxSequenceLength, nvinfer1::DataType logitsType, executor::DecodingConfig const& decodingConfig,
+        bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, OptionalRef<MedusaBuffers const> medusaBuffers);

 private:
     //! @brief Setups decoder internal tensors for new speculative decoding request
-    static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
+    static void newRequestSpeculativeDecoding(SizeType32 batchIdx, LlmRequest const& request,
         SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
+        runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
         DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
         CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
-        SizeType32 maxDecodingEngineTokens);
+        SizeType32 maxDecodingEngineTokens, executor::DecodingConfig const& decodingConfig,
+        bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, OptionalRef<MedusaBuffers const> medusaBuffers);

     //! @brief Setups decoder internal tensors for new request in Draft model Sps mode
-    static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream);
+    static void newRequestDraftTokensExternal(SizeType32 batchIdx, LlmRequest const& llmReq,
+        SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream,
+        runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        runtime::BufferManager const& bufferManager, bool speculativeDecodingFastLogits, bool isLeaderInOrchMode);

     //! @brief Setups decoder internal tensors for new Medusa request
-    static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens);
+    static void newRequestMedusa(SizeType32 batchIdx, LlmRequest const& llmReq, DecodingInput& jointDecodingInput,
+        CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers);

     //! @brief Setups decoder internal tensors for new Lookahead request
-    static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
+    static void newRequestLookahead(SizeType32 batchIdx, DecodingInput& jointDecodingInput,
+        DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

     //! @brief Setups decoder internal tensors for new Explicit draft tokens request
-    static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
+    static void newRequestExplicitDraftTokens(SizeType32 batchIdx, LlmRequest const& llmReq,
         DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

     //! @brief Setups decoder internal tensors for new Eagle request
-    static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-        runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
+    static void newRequestEagle(SizeType32 batchIdx, LlmRequest const& llmReq, runtime::ModelConfig const& modelConfig,
+        DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
+        executor::DecodingConfig const& decodingConfig);

     [[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
@@ -123,9 +130,8 @@ class CreateNewDecoderRequests : Algorithm
         runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream,
         SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;

-    [[nodiscard]] std::shared_ptr<runtime::ITensor> retrieveDraftLogits(runtime::ModelConfig const& modelConfig,
-        runtime::WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
-        runtime::BufferManager const& bufferManager) const;
+    static void retrieveDraftLogits(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+        LlmRequest const& llmReq, bool speculativeDecodingFastLogits, bool isLeaderInOrchMode);

     bool mSpeculativeDecodingFastLogits;
     bool mIsLeaderInOrchMode;
diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h
index c22855bd604..57a14db6682 100644
--- a/cpp/include/tensorrt_llm/runtime/decodingInput.h
+++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h
@@ -83,15 +83,13 @@ class DecodingInput
     TensorConstPtr finishReasons;
     //! The maximum sequence length for each sequence in the batch, [batchSize] on gpu
     TensorConstPtr sequenceLimitLength;
-    TensorConstPtr embeddingBias;               // [batchSize, vocabSizePadded] on gpu
-    TensorConstPtr lengths;                     // [batchSize, beamWidth] on gpu
-    std::vector<TensorConstPtr> badWordsLists;  // [batchSize][2, badWordsLength] on gpu
-    TensorConstPtr badWordsPtrs;                // [batchSize][2, badWordsLength] on pinned
-    TensorConstPtr badWordsLens;                // [batchSize] on gpu
-    std::vector<TensorConstPtr> stopWordsLists; // [batchSize][2, stopWordsLength] on gpu
-    TensorConstPtr stopWordsPtrs;               // [batchSize][2, stopWordsLength] on pinned
-    TensorConstPtr stopWordsLens;               // [batchSize] on pinned
-    TensorConstPtr noRepeatNgramSize;           // [batchSize] on gpu
+    TensorConstPtr embeddingBias;     // [batchSize, vocabSizePadded] on gpu
+    TensorConstPtr lengths;           // [batchSize, beamWidth] on gpu
+    TensorConstPtr badWordsPtrs;      // [batchSize][2, badWordsLength] on pinned
+    TensorConstPtr badWordsLens;      // [batchSize] on gpu
+    TensorConstPtr stopWordsPtrs;     // [batchSize][2, stopWordsLength] on pinned
+    TensorConstPtr stopWordsLens;     // [batchSize] on pinned
+    TensorConstPtr noRepeatNgramSize; // [batchSize] on gpu

     //! Parameters for beam search
     //! KV cache index for beam search, [batchSize, beamWidth, maxSeqLen] on gpu
diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
index e6b06677d69..020267cf62a 100644
--- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
+++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -160,10 +160,12 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
         std::move(lookaheadAlgoConfigs)};
 }

-void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
+void CreateNewDecoderRequests::newRequest(LlmRequest const& llmReq, SharedConstPtr inputIds, SizeType32 batchSlot,
     SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
+    runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
     runtime::decoder::DecoderState& decoderState, CudaStream const& runtimeStream, CudaStream const& decoderStream,
-    SizeType32 maxSequenceLength)
+    SizeType32 maxSequenceLength, nvinfer1::DataType logitsType, executor::DecodingConfig const& decodingConfig,
+    bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, OptionalRef<MedusaBuffers const> medusaBuffers)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -178,19 +180,18 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
     TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth,
         tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.",
             beamWidth, maxBeamWidth));
-    auto const& requestIds = request.ids;
-    auto const inputLength = request.inputLen;
-    auto const numDecodingEngineTokens = request.generatedTokensPerEngineStep;
+    auto const& requestIds = inputIds;
+    auto const inputLength = llmReq.getPromptLen();
+    auto const numDecodingEngineTokens = modelConfig.getMaxDecodingTokens();
     auto const numDecodingDraftEngineTokens = numDecodingEngineTokens - 1;
-    auto const maxNewTokens
-        = request.maxNewTokens.value_or(maxSequenceLength - inputLength - numDecodingDraftEngineTokens);
+    auto const maxNewTokens = llmReq.mMaxNewTokens;

     TLLM_CHECK_WITH_INFO(inputLength + maxNewTokens + numDecodingDraftEngineTokens <= maxSequenceLength,
         tc::fmtstr(
             "Input length (%d) + max new tokens (%d) + draft tokens (%d) must be less than max sequence length (%d).",
             inputLength, maxNewTokens, numDecodingDraftEngineTokens, maxSequenceLength));
     TLLM_CHECK(requestIds->getDataType() == TRTDataType<TokenIdType>::value);
-    auto const endId = request.endId.value_or(-1);
+    auto const endId = llmReq.mEndId.value_or(-1);

     // input
     auto& dJointInput = decoderState.getJointDecodingInput();
@@ -202,23 +203,23 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
     runtime::kernels::invokeFill(*endIdTensorPtr, endId, decoderStream);

     TensorPtr const embeddingBiasSlice = ITensor::slice(constPointerCast(dJointInput.embeddingBias), batchSlot, 1);
-    if (request.embeddingBias)
+    auto const embeddingBias = getEmbeddingBias(logitsType, llmReq.getEmbeddingBias().value());
+    if (embeddingBias)
     {
-        TLLM_CHECK(request.embeddingBias->getShape().nbDims == 2);
-        TLLM_CHECK(request.embeddingBias->getShape().d[0] == 1);
-        TLLM_CHECK_WITH_INFO(request.embeddingBias->getShape().d[1] == modelConfig.getVocabSize(),
+        TLLM_CHECK(embeddingBias->getShape().nbDims == 2);
+        TLLM_CHECK(embeddingBias->getShape().d[0] == 1);
+        TLLM_CHECK_WITH_INFO(embeddingBias->getShape().d[1] == modelConfig.getVocabSize(),
             "The embedding bias shape is not as expected. Expected last dimension to be same as vocab size: %d.",
             modelConfig.getVocabSize());
-        manager.copy(*request.embeddingBias, *embeddingBiasSlice);
+        manager.copy(*embeddingBias, *embeddingBiasSlice);
     }
     else
     {
         manager.setZero(*embeddingBiasSlice);
     }

-    auto setupWords = [](std::vector<TensorConstPtr>& jointWordsLists, TensorPtr const& requestWordsList,
-        SharedConstPtr& jointWordsPtrs, SharedConstPtr& jointWordsLens, SizeType32& jointMaxWordsLen,
-        SizeType32 batchSlot)
+    auto setupWords = [](TensorPtr const& requestWordsList, SharedConstPtr& jointWordsPtrs,
+        SharedConstPtr& jointWordsLens, SizeType32& jointMaxWordsLen, SizeType32 batchSlot)
     {
         if (requestWordsList)
         {
@@ -228,10 +229,6 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
             runtime::bufferCast<SizeType32>(*constPointerCast(jointWordsLens))[batchSlot] = wordsLen;
             // FIXME: this is monotonically growing size
             jointMaxWordsLen = std::max(static_cast<SizeType32>(wordsLen), jointMaxWordsLen);
-
-            // NOTE: jointWordsList is not used in gptDecoder, but required to keep WordsList's
-            // memory allocated
-            jointWordsLists[batchSlot] = requestWordsList;
         }
         else
         {
@@ -239,10 +236,10 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
         }
     };

-    setupWords(dJointInput.stopWordsLists, request.stopWordsList, dJointInput.stopWordsPtrs, dJointInput.stopWordsLens,
+    setupWords(llmReq.getStopWordsList().value(), dJointInput.stopWordsPtrs, dJointInput.stopWordsLens,
         dJointInput.maxStopWordsLen, batchSlot);

-    setupWords(dJointInput.badWordsLists, request.badWordsList, dJointInput.badWordsPtrs, dJointInput.badWordsLens,
+    setupWords(llmReq.getBadWordsList().value(), dJointInput.badWordsPtrs, dJointInput.badWordsLens,
         dJointInput.maxBadWordsLen, batchSlot);

     TensorPtr const sequenceLimitLength{
@@ -315,9 +312,10 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
     if (numDecodingEngineTokens > 1 || decoderState.getSpeculativeDecodingMode().isDraftTokensExternal())
     {
         TLLM_CHECK(beamWidth == 1);
-        newRequestSpeculativeDecoding(batchSlot, request, samplingConfig, modelConfig,
+        newRequestSpeculativeDecoding(batchSlot, llmReq, samplingConfig, modelConfig, worldConfig, bufferManager,
             decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, decoderStream,
-            decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens());
+            decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens(), decodingConfig,
+            speculativeDecodingFastLogits, isLeaderInOrchMode, medusaBuffers);
     }

     // fill outputIds with endIds
@@ -333,11 +331,13 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx,
-    runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
-    runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
-    CudaStream const& runtimeStream, CudaStream const& decoderStream,
-    SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens)
+void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx, LlmRequest const& request,
+    SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
+    runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager,
+    DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
+    CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
+    SizeType32 maxDecodingEngineTokens, executor::DecodingConfig const& decodingConfig,
+    bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, OptionalRef<MedusaBuffers const> medusaBuffers)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -362,15 +362,20 @@ void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx

     if (speculativeDecodingMode.isDraftTokensExternal())
     {
-        newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream);
+        newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream, modelConfig,
+            worldConfig, bufferManager, speculativeDecodingFastLogits, isLeaderInOrchMode);
     }
     else if (speculativeDecodingMode.isMedusa())
     {
-        newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens);
+        TLLM_CHECK(medusaBuffers);
+        const_cast<SamplingConfig&>(request.mSamplingConfig).topKMedusaHeads
+            = std::vector<std::vector<SizeType32>>{{medusaBuffers->mTopKs}};
+        newRequestMedusa(
+            batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens, medusaBuffers.value());
     }
     else if (speculativeDecodingMode.isLookaheadDecoding())
     {
-        newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream);
+        newRequestLookahead(batchIdx, jointDecodingInput, jointDecodingOutput, runtimeStream);
     }
     else if (speculativeDecodingMode.isExplicitDraftTokens())
     {
@@ -378,14 +383,15 @@ void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx
     }
     else if (speculativeDecodingMode.isEagle())
     {
-        newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream);
+        newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream, decodingConfig);
     }

     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx,
-    runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
-    DecodingInput& jointDecodingInput, CudaStream const& decoderStream)
+void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx, LlmRequest const& llmReq,
+    SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream,
+    runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
+    BufferManager const& bufferManager, bool speculativeDecodingFastLogits, bool isLeaderInOrchMode)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -393,18 +399,18 @@ void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx

     auto& dJointInput = jointDecodingInput;

-    auto const numDraftTokens = request.generatedTokensPerEngineStep - 1;
+    auto const numDraftTokens = modelConfig.getMaxDecodingTokens() - 1;

-    auto const useDraftLogits = request.draftLogits.has_value();
+    auto const useDraftLogits = llmReq.getDraftLogits().has_value();
     if (useDraftLogits)
     {
-        TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value());
+        retrieveDraftLogits(modelConfig, worldConfig, llmReq, speculativeDecodingFastLogits, isLeaderInOrchMode);

         TensorPtr draftLogitsReqBatchSlice
             = ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1);
         draftLogitsReqBatchSlice->squeeze(0);
         TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens);
-        manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice);
+        manager.copy(*llmReq.getDraftLogits().value(), *draftLogitsReqTokensSlice);
     }
     auto* useDraftLogitsHostPtr = runtime::bufferCast<bool>(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost);
     useDraftLogitsHostPtr[batchIdx] = useDraftLogits;
@@ -417,7 +423,8 @@ void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx
             = ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1);
         draftTokensReqBatchSlice->squeeze(0);
         TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens);
-        TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens}));
+        TensorPtr draftTokensView = ITensor::wrap(*llmReq.getDraftTokens(), ITensor::makeShape({numDraftTokens}));
+        // TODO: This will result in a sync copy unless the std::vector is pinned
         manager.copy(*draftTokensView, *draftTokensReqTokensSlice);
     }

@@ -437,8 +444,9 @@ void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-    DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens)
+void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, LlmRequest const& llmReq,
+    DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens,
+    MedusaBuffers const& medusaBuffers)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -453,23 +461,26 @@ void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::de
     runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream);
     TensorPtr targetTokensPerStepSlice
         = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
-    auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep;
+    int const generatedTokensPerEngineStep
+        = llmReq.hasDraftTokens() ? static_cast<SizeType32>(llmReq.getDraftTokens()->size()) + 1 : 1;
     TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens,
         "Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep,
         maxDecodingEngineTokens);
     runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream);

     TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1);
-    manager.copy(*request.medusaPaths, *pathsSlice);
+    auto const medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1);
+    manager.copy(*medusaPaths, *pathsSlice);

     TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1);
-    manager.copy(*request.medusaTreeIds, *treeIdsSlice);
+    auto const medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1);
+    manager.copy(*medusaTreeIds, *treeIdsSlice);

     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-    DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
+void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, DecodingInput& jointDecodingInput,
+    DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -483,9 +494,8 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime:
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx,
-    runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput,
-    CudaStream const& runtimeStream)
+void CreateNewDecoderRequests::newRequestExplicitDraftTokens(
+    SizeType32 batchIdx, LlmRequest const& llmReq, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -493,13 +503,14 @@ void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx

     TensorPtr positionIdsBaseSlice
         = ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1);
-    runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream);
+    runtime::kernels::invokeFill(*positionIdsBaseSlice, llmReq.getPromptLen(), runtimeStream);

     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }

-void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
-    runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
+void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, LlmRequest const& llmReq,
+    runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
+    executor::DecodingConfig const& decodingConfig)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

@@ -515,8 +526,8 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
         = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1);

     runtime::bufferCast<SizeType32>(*eagleNetCtxRequestTypesHostSlice)[0] = 0;
-    runtime::bufferCast<SizeType32>(*eagleNetCtxContextLengthsHostSlice)[0] = request.inputLen;
-    runtime::bufferCast<SizeType32>(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = request.inputLen;
+    runtime::bufferCast<SizeType32>(*eagleNetCtxContextLengthsHostSlice)[0] = llmReq.getPromptLen();
+    runtime::bufferCast<SizeType32>(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = llmReq.getPromptLen();

     TensorPtr eagleNetGenRequestTypesHostSlice
         = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenRequestTypesHost, batchIdx, 1);
@@ -526,19 +537,20 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
         = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenPastKeyValueLengthsHost, batchIdx, 1);

     runtime::bufferCast<SizeType32>(*eagleNetGenRequestTypesHostSlice)[0] = 1;
-    runtime::bufferCast<SizeType32>(*eagleNetGenContextLengthsHostSlice)[0] = request.inputLen;
-    runtime::bufferCast<SizeType32>(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = request.inputLen;
+    runtime::bufferCast<SizeType32>(*eagleNetGenContextLengthsHostSlice)[0] = llmReq.getPromptLen();
+    runtime::bufferCast<SizeType32>(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = llmReq.getPromptLen();

     auto const eagleModule = std::dynamic_pointer_cast<runtime::EagleModule>(
         modelConfig.getSpeculativeDecodingModulePtr());

     std::optional<executor::EagleChoices> eagleChoicesOpt;
-    if (request.eagleConfig)
+    auto const eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig();
+    if (eagleConfig)
     {
-        eagleChoicesOpt = request.eagleConfig->getEagleChoices();
+        eagleChoicesOpt = eagleConfig->getEagleChoices();
     }

-    if (!request.eagleConfig || !request.eagleConfig->useDynamicTree())
+    if (!eagleConfig || !eagleConfig->useDynamicTree())
     {
         TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1);
         TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1);
@@ -573,9 +585,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
     }
     inputIds->resize(decoderInputSize);

-    std::vector<decoder_batch::Request> decoderRequests;
-    decoderRequests.reserve(finishedContextRequests.size());
-
     std::vector<SharedConstPtr> lookaheadPrompt;
     std::vector<executor::LookaheadDecodingConfig> lookaheadAlgoConfigs;
     if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
@@ -593,77 +602,14 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
         auto const& reqTokens = llmReq->getTokens(0);
         TLLM_CHECK(reqTokens.size() == static_cast<decltype(reqTokens.size())>(promptLen));
         TensorPtr inputView = ITensor::slice(inputIds, inputOffset, promptLen);
+        // TODO: This will result in a sync copy unless the std::vector is pinned
         bufferManager.copy(reqTokens.data(), *inputView);
-        auto decoderRequest = decoder_batch::Request{inputView, promptLen, llmReq->mMaxNewTokens, llmReq->mEndId};

         llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs;
-        if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
-        {
-            if (llmReq->hasDraftTokens())
-            {
-                auto const& draftTokens = llmReq->getDraftTokens();
-                decoderRequest.draftTokens = bufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
-                auto const& draftLogits = llmReq->getDraftLogits();
-                if (draftLogits.has_value())
-                {
-                    decoderRequest.draftLogits
-                        = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), bufferManager);
-                }
-                decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1;
-            }
-            else
-            {
-                decoderRequest.generatedTokensPerEngineStep = 1;
-            }
-        }
-        else if (!modelConfig.getSpeculativeDecodingMode().isNone())
-        {
-            decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
-        }
-        if (modelConfig.getSpeculativeDecodingMode().isMedusa())
-        {
-            TLLM_CHECK(medusaBuffers);
-            llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs};
-            // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
-            // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
-            decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1);
-            decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1);
-        }
-        else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
-        {
-            lookaheadPrompt.emplace_back(ITensor::slice(decoderRequest.ids, 0, decoderRequest.inputLen));
-
-            auto const& lookaheadRuntimeConfig
-                = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
-            lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
-        }
-        else if (modelConfig.getSpeculativeDecodingMode().isEagle())
-        {
-            decoderRequest.eagleConfig
-                = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
-        }
-        if (llmReq->getEmbeddingBias().has_value())
-        {
-            decoderRequest.embeddingBias = getEmbeddingBias(logitsType, llmReq->getEmbeddingBias().value());
-        }
-        if (llmReq->getBadWordsList().has_value())
-        {
-            // Move to GPU and remove leading bs1 dimension since this is what decoderRequest expects
-            decoderRequest.badWordsList = bufferManager.copyFrom(*llmReq->getBadWordsList().value(), MemoryType::kGPU);
-            decoderRequest.badWordsList->squeeze(0);
-        }
-        if (llmReq->getStopWordsList().has_value())
-        {
-            decoderRequest.stopWordsList
-                = bufferManager.copyFrom(*llmReq->getStopWordsList().value(), MemoryType::kGPU);
-            decoderRequest.stopWordsList->squeeze(0);
-        }
-
-        newRequest(llmReq->mSeqSlot.value(), decoderRequest, llmReq->mSamplingConfig, modelConfig, decoderState,
-            runtimeStream, decoderStream, maxSequenceLength);
-
-        decoderRequests.push_back(decoderRequest);
+        newRequest(*llmReq, ITensor::slice(inputIds, inputOffset, promptLen), llmReq->mSeqSlot.value(),
+            llmReq->mSamplingConfig, modelConfig, worldConfig, bufferManager, decoderState, runtimeStream,
+            decoderStream, maxSequenceLength, logitsType, decodingConfig, mSpeculativeDecodingFastLogits,
+            mIsLeaderInOrchMode, medusaBuffers);

         inputOffset += promptLen;
     }
@@ -671,23 +617,25 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
     return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
 }

-std::shared_ptr<ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig,
-    WorldConfig const& worldConfig, std::shared_ptr<ITensor> const& tensor,
-    BufferManager const& bufferManager) const
+void CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig, WorldConfig const& worldConfig,
+    LlmRequest const& llmReq, bool speculativeDecodingFastLogits, bool isLeaderInOrchMode)
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

-    if (!mSpeculativeDecodingFastLogits)
+    if (!speculativeDecodingFastLogits)
     {
         TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-        return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL);
+        return;
     }

-    if (mIsLeaderInOrchMode)
+    if (isLeaderInOrchMode)
     {
+        // TODO: This should be moved out of llm request.
         te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo;
-        std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo));
+        std::memcpy(&fastLogitsInfo, llmReq.getDraftLogits().value()->data(), sizeof(fastLogitsInfo));
         auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value();
+        // TODO: This const_cast should be removed, ContextRequests should be non-const reference.
+        const_cast<LlmRequest&>(llmReq).setDraftLogits(std::move(logits));

         // Broadcast to other ranks if needed
         if (worldConfig.isTensorParallel())
@@ -699,7 +647,7 @@ std::shared_ptr<ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConf
             commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
         }
         TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-        return logits;
+        return;
     }

     // Get logits from leader rank
@@ -707,12 +655,13 @@ std::shared_ptr<ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConf
     int64_t dims[2];
     commSession.bcastValue(dims[0], 0);
     commSession.bcastValue(dims[1], 0);
-    auto const logitsDtype = modelConfig.getLogitsDtype();
-    auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
+    auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(
+        ITensor::makeShape({dims[0], dims[1]}), modelConfig.getLogitsDtype());
     commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
+    // TODO: Same with this const_cast, ContextRequests should be non-const reference.
+    const_cast<LlmRequest&>(llmReq).setDraftLogits(std::move(logits));

     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-    return logits;
 };

 } // namespace tensorrt_llm::batch_manager
diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp
index 57c90b43643..7d6834384c5 100644
--- a/cpp/tensorrt_llm/runtime/decoderState.cpp
+++ b/cpp/tensorrt_llm/runtime/decoderState.cpp
@@ -194,8 +194,6 @@ void DecoderState::setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeT
     dInput.maxLength = mMaxSequenceLength;
     dInput.maxAttentionWindow = maxAttentionWindow;
     dInput.sinkTokenLength = sinkTokenLength;
-    dInput.stopWordsLists.resize(mMaxBatchSize);
-    dInput.badWordsLists.resize(mMaxBatchSize);

     auto const maxBatchSizeShape = ITensor::makeShape({mMaxBatchSize});
     auto const maxBatchSizeXmaxBeamWidthShape = ITensor::makeShape({mMaxBatchSize, mMaxBeamWidth});