diff --git a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h b/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h deleted file mode 100644 index 72e0324ca23..00000000000 --- a/cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h +++ /dev/null @@ -1,140 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/batch_manager/kvCacheConfig.h" -#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" -#include "tensorrt_llm/executor/executor.h" -#include "tensorrt_llm/runtime/common.h" - -#include -#include -#include -#include - -namespace tensorrt_llm::batch_manager -{ - -class TrtGptModelOptionalParams -{ - using KvCacheConfig = kv_cache_manager::KvCacheConfig; - -public: - using SizeType32 = tensorrt_llm::runtime::SizeType32; - - // 23 parameters, 23 items in initialization list - explicit TrtGptModelOptionalParams(KvCacheConfig kvCacheConfig = KvCacheConfig{}, bool enableTrtOverlap = false, - std::optional> deviceIds = std::nullopt, bool normalizeLogProbs = true, - bool enableChunkedContext = true, - PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{}, - executor::DecodingConfig decodingConfig = executor::DecodingConfig{}, bool useGpuDirectStorage = false, - float gpuWeightsPercent = 1, std::optional maxBeamWidth = std::nullopt, - std::optional maxBatchSize = std::nullopt, std::optional maxNumTokens = std::nullopt, - executor::SchedulerConfig schedulerConfig = executor::SchedulerConfig{}, - executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig - = executor::ExtendedRuntimePerfKnobConfig{}, - std::optional debugConfig = std::nullopt, - uint64_t maxSeqIdleMicroseconds = executor::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds, - std::optional specDecConfig = std::nullopt, - std::optional guidedDecodingConfig = std::nullopt, - bool isLeaderInOrchMode = false, - std::optional> additionalModelOutputs = std::nullopt, - std::optional cacheTransceiverConfig = std::nullopt, - bool gatherGenerationLogits = false, bool promptTableOffloading = false) - : kvCacheConfig{std::move(kvCacheConfig)} - , enableTrtOverlap{enableTrtOverlap} - , deviceIds(std::move(deviceIds)) - , normalizeLogProbs{normalizeLogProbs} - , enableChunkedContext{enableChunkedContext} - , peftCacheManagerConfig(peftCacheManagerConfig) - , decodingConfig(std::move(decodingConfig)) - , useGpuDirectStorage(useGpuDirectStorage) - , gpuWeightsPercent(gpuWeightsPercent) - , maxBeamWidth(maxBeamWidth) - , maxBatchSize(maxBatchSize) - , maxNumTokens(maxNumTokens) - , schedulerConfig{std::move(schedulerConfig)} - , extendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig) - , debugConfig{std::move(debugConfig)} - , maxSeqIdleMicroseconds{maxSeqIdleMicroseconds} - , speculativeDecodingConfig{specDecConfig} - , 
guidedDecodingConfig{std::move(guidedDecodingConfig)} - , isLeaderInOrchMode{isLeaderInOrchMode} - , additionalModelOutputs{std::move(additionalModelOutputs)} - , cacheTransceiverConfig{std::move(cacheTransceiverConfig)} - , gatherGenerationLogits{gatherGenerationLogits} - , promptTableOffloading{promptTableOffloading} - { - if (guidedDecodingConfig) - { - guidedDecodingConfig->validate(); - } - } - - // 2 parameters, 23 items in initialization list - explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode) - : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), - executorConfig.getEnableTrtOverlap(), - executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(), - executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(), - PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())), - executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{}), - executorConfig.getUseGpuDirectStorage(), executorConfig.getGpuWeightsPercent(), - executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(), executorConfig.getMaxNumTokens(), - executorConfig.getSchedulerConfig(), executorConfig.getExtendedRuntimePerfKnobConfig(), - executorConfig.getDebugConfig(), executorConfig.getMaxSeqIdleMicroseconds(), - executorConfig.getSpecDecConfig(), executorConfig.getGuidedDecodingConfig(), isLeaderInOrchMode, - executorConfig.getAdditionalModelOutputs(), executorConfig.getCacheTransceiverConfig(), - executorConfig.getGatherGenerationLogits(), executorConfig.getPromptTableOffloading()) - { - } - - friend std::ostream& operator<<(std::ostream& os, TrtGptModelOptionalParams const& self); - - KvCacheConfig kvCacheConfig; - - bool enableTrtOverlap; - std::optional> deviceIds; - bool normalizeLogProbs; - bool enableChunkedContext; - PeftCacheManagerConfig peftCacheManagerConfig; - executor::DecodingConfig decodingConfig; - // Use GDS to load the engines? - bool useGpuDirectStorage; - // Percentage of weights on the gpu at runtime - float gpuWeightsPercent; - std::optional maxBeamWidth; - std::optional maxBatchSize; - std::optional maxNumTokens; - executor::SchedulerConfig schedulerConfig; - executor::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig; - std::optional debugConfig; - // Sequence is considered idle if not updated for this amount of time. 
- uint64_t maxSeqIdleMicroseconds; - std::optional speculativeDecodingConfig; - std::optional guidedDecodingConfig; - // This rank is the leader worker in orchestrator mode - bool isLeaderInOrchMode; - std::optional> additionalModelOutputs; - std::optional cacheTransceiverConfig; - bool gatherGenerationLogits; - // Whether to offload the prompt table to CPU and prefetching to GPU - bool promptTableOffloading; -}; - -} // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index e6b06677d69..7085da70b15 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -25,7 +25,6 @@ #include "tensorrt_llm/runtime/decoderState.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" -#include "tensorrt_llm/runtime/gptDecoderBatched.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" diff --git a/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp b/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp index 980423d7d8e..de1525b0773 100644 --- a/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp @@ -22,6 +22,7 @@ #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/tllmLogger.h" #include "tensorrt_llm/runtime/tllmRuntime.h" @@ -39,14 +40,14 @@ namespace tensorrt_llm::batch_manager TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig, runtime::RawEngine const& rawEngine, std::shared_ptr logger, - TrtGptModelOptionalParams const& optionalParams) - : TrtGptModel(modelConfig, worldConfig, optionalParams) + executor::ExecutorConfig const& executorConfig) + : TrtGptModel(modelConfig, worldConfig, executorConfig) , mModelConfig{modelConfig} , mWorldConfig{worldConfig} , mDevice{runtime::utils::initDevice(worldConfig)} , mLogger{logger ? 
std::move(logger) : std::make_shared()} , mRuntime{std::make_shared( - rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage, optionalParams.gpuWeightsPercent)} + rawEngine, mLogger.get(), executorConfig.getUseGpuDirectStorage(), executorConfig.getGpuWeightsPercent())} , mNumMicroBatches{1} , mNumBuffers{mNumMicroBatches} , mCopyBufferManager{std::make_shared()} @@ -75,8 +76,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC // handling of maximizing utilization or pause/evict // TODO: finer control on encoder requests scheduling mCapacityScheduler = std::make_unique( - getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, false, - LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT); + getMaxBatchSize() * mNumMicroBatches, executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), false, + false, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT); mMicroBatchScheduler = std::make_unique( std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT); diff --git a/cpp/tensorrt_llm/batch_manager/trtEncoderModel.h b/cpp/tensorrt_llm/batch_manager/trtEncoderModel.h index f7d2750262c..31f7d3d0c89 100644 --- a/cpp/tensorrt_llm/batch_manager/trtEncoderModel.h +++ b/cpp/tensorrt_llm/batch_manager/trtEncoderModel.h @@ -17,7 +17,6 @@ #pragma once -#include "tensorrt_llm/runtime/iGptDecoderBatched.h" #include "tensorrt_llm/runtime/rawEngine.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" #include "trtGptModel.h" @@ -47,7 +46,7 @@ class TrtEncoderModel : public TrtGptModel TrtEncoderModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, runtime::RawEngine const& rawEngine, std::shared_ptr logger, - TrtGptModelOptionalParams const& optionalParams); + executor::ExecutorConfig const& executorConfig); ~TrtEncoderModel() override; diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModel.h b/cpp/tensorrt_llm/batch_manager/trtGptModel.h index f8ab501a5d8..eb1a815a683 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModel.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModel.h @@ -18,7 +18,6 @@ #pragma once #include "tensorrt_llm/batch_manager/peftCacheManager.h" -#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/stlUtils.h" #include "tensorrt_llm/executor/executor.h" @@ -52,15 +51,15 @@ class TrtGptModel : public executor::Model using SizeType32 = tensorrt_llm::runtime::SizeType32; TrtGptModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, - TrtGptModelOptionalParams const& optionalParams) - : mMaxBatchSize{optionalParams.maxBatchSize.value_or(modelConfig.getMaxBatchSize())} - , mMaxBeamWidth{optionalParams.maxBeamWidth.value_or(modelConfig.getMaxBeamWidth())} + executor::ExecutorConfig const& executorConfig) + : mMaxBatchSize{executorConfig.getMaxBatchSize().value_or(modelConfig.getMaxBatchSize())} + , mMaxBeamWidth{executorConfig.getMaxBeamWidth()} , mMaxSequenceLen{modelConfig.getMaxSequenceLen()} , mMaxDraftLen{modelConfig.getMaxDecodingDraftTokens()} , mVocabSizePadded{modelConfig.getVocabSizePadded(worldConfig.getSize())} - , mNormalizeLogProbs{optionalParams.normalizeLogProbs} - , mEnableTrtOverlap{optionalParams.enableTrtOverlap} - , mCudaGraphMode{optionalParams.extendedRuntimePerfKnobConfig.getCudaGraphMode()} + , mNormalizeLogProbs{executorConfig.getNormalizeLogProbs()} + , 
mEnableTrtOverlap{executorConfig.getEnableTrtOverlap()} + , mCudaGraphMode{executorConfig.getExtendedRuntimePerfKnobConfig().getCudaGraphMode()} { TLLM_CHECK_WITH_INFO(mMaxBeamWidth <= modelConfig.getMaxBeamWidth(), "Runtime configured max beam width (%d) must not exceed engine max beam width (%d)", mMaxBeamWidth, @@ -68,7 +67,7 @@ class TrtGptModel : public executor::Model TLLM_CHECK_WITH_INFO(mMaxBatchSize <= modelConfig.getMaxBatchSize(), "Runtime configured max batch size (%d) must not exceed engine max batch size (%d)", mMaxBatchSize, modelConfig.getMaxBatchSize()); - if (optionalParams.enableTrtOverlap) + if (executorConfig.getEnableTrtOverlap()) { if (mMaxBeamWidth > 1) { @@ -85,10 +84,11 @@ class TrtGptModel : public executor::Model } mMaxAttentionWindow = 0; - if (optionalParams.kvCacheConfig.maxAttentionWindowVec.has_value()) + if (executorConfig.getKvCacheConfig().getMaxAttentionWindowVec().has_value()) { bool warning = false; - for (int maxAttenWin : optionalParams.kvCacheConfig.maxAttentionWindowVec.value()) + auto const& maxAttentionWindowVec = executorConfig.getKvCacheConfig().getMaxAttentionWindowVec(); + for (int maxAttenWin : maxAttentionWindowVec.value()) { mMaxAttentionWindowVec.push_back(std::min(maxAttenWin, mMaxSequenceLen)); mMaxAttentionWindow = std::max(mMaxAttentionWindow, mMaxAttentionWindowVec.back()); @@ -112,8 +112,8 @@ class TrtGptModel : public executor::Model mMaxAttentionWindow = mMaxSequenceLen; } - mSinkTokenLen = optionalParams.kvCacheConfig.sinkTokenLength.has_value() - ? optionalParams.kvCacheConfig.sinkTokenLength.value() + mSinkTokenLen = executorConfig.getKvCacheConfig().getSinkTokenLength().has_value() + ? executorConfig.getKvCacheConfig().getSinkTokenLength().value() : 0; mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism(); @@ -136,18 +136,18 @@ class TrtGptModel : public executor::Model TLLM_LOG_INFO("TRTGptModel normalizeLogProbs: %d", mNormalizeLogProbs); mMaxNumTokens = modelConfig.getMaxNumTokens(); - if (optionalParams.maxNumTokens && mMaxNumTokens) + if (executorConfig.getMaxNumTokens().has_value() && mMaxNumTokens) { - if (optionalParams.maxNumTokens.value() > mMaxNumTokens.value()) + if (executorConfig.getMaxNumTokens().value() > mMaxNumTokens.value()) { TLLM_LOG_WARNING( "Runtime configured max num tokens (%d) is larger than model max num tokens (%d) and will be " "ignored.", - optionalParams.maxNumTokens.value(), mMaxNumTokens.value()); + executorConfig.getMaxNumTokens().value(), mMaxNumTokens.value()); } else { - mMaxNumTokens = optionalParams.maxNumTokens; + mMaxNumTokens = executorConfig.getMaxNumTokens(); } } if (mMaxNumTokens) @@ -155,7 +155,7 @@ class TrtGptModel : public executor::Model TLLM_LOG_INFO("TRTGptModel maxNumTokens: %d", mMaxNumTokens.value()); } - if (optionalParams.enableChunkedContext) + if (executorConfig.getEnableChunkedContext()) { mMaxInputLen = mMaxSequenceLen - 1; TLLM_LOG_INFO( @@ -199,9 +199,9 @@ class TrtGptModel : public executor::Model using tensorrt_llm::common::stl_utils::toString; TLLM_LOG_INFO("Capacity Scheduler Policy: %s", - toString(optionalParams.schedulerConfig.getCapacitySchedulerPolicy()).c_str()); + toString(executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy()).c_str()); TLLM_LOG_INFO("Context Chunking Scheduler Policy: %s", - toString(optionalParams.schedulerConfig.getContextChunkingPolicy()).c_str()); + toString(executorConfig.getSchedulerConfig().getContextChunkingPolicy()).c_str()); } [[nodiscard]] std::optional getMaxNumTokens() const diff --git 
a/cpp/tensorrt_llm/batch_manager/trtGptModelFactory.h b/cpp/tensorrt_llm/batch_manager/trtGptModelFactory.h index 189d615ab14..bd4d7c76737 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelFactory.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelFactory.h @@ -17,12 +17,14 @@ #pragma once +#include "tensorrt_llm/batch_manager/trtGptModel.h" +#include "tensorrt_llm/batch_manager/trtGptModelInflightBatching.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/gptJsonConfig.h" #include "tensorrt_llm/runtime/modelConfig.h" #include "tensorrt_llm/runtime/rawEngine.h" #include "tensorrt_llm/runtime/tllmLogger.h" #include "tensorrt_llm/runtime/worldConfig.h" -#include "trtGptModelInflightBatching.h" #include @@ -38,28 +40,31 @@ class TrtGptModelFactory using SizeType32 = tensorrt_llm::runtime::SizeType32; static std::shared_ptr create(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, - TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams()) + executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode) { auto const jsonConfig = runtime::GptJsonConfig::parse(trtEnginePath / "config.json"); - auto worldConfig = getWorldConfig(jsonConfig, optionalParams.deviceIds); + auto const& deviceIds = executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(); + auto const worldConfig = getWorldConfig(jsonConfig, deviceIds); auto const enginePath = trtEnginePath / jsonConfig.engineFilename(worldConfig); auto const& modelConfig = jsonConfig.getModelConfig(); - return create(runtime::RawEngine(enginePath), modelConfig, worldConfig, modelType, optionalParams); + return create( + runtime::RawEngine(enginePath), modelConfig, worldConfig, modelType, executorConfig, isLeaderInOrchMode); } static std::shared_ptr create(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, runtime::GptJsonConfig const& jsonConfig, runtime::WorldConfig const& worldConfig, - TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams()) + executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode) { auto const enginePath = trtEnginePath / jsonConfig.engineFilename(worldConfig); auto const& modelConfig = jsonConfig.getModelConfig(); - return create(runtime::RawEngine(enginePath), modelConfig, worldConfig, modelType, optionalParams); + return create( + runtime::RawEngine(enginePath), modelConfig, worldConfig, modelType, executorConfig, isLeaderInOrchMode); } static std::shared_ptr create(runtime::RawEngine const& rawEngine, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, TrtGptModelType modelType, - TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams()) + executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode) { auto logger = std::make_shared(); auto const device = worldConfig.getDevice(); @@ -69,12 +74,13 @@ class TrtGptModelFactory if ((modelType == TrtGptModelType::InflightBatching) || (modelType == TrtGptModelType::InflightFusedBatching)) { - TrtGptModelOptionalParams const& fixedOptionalParams - = TrtGptModelInflightBatching::optionalParamsAreValid(modelConfig, optionalParams) - ? 
optionalParams - : TrtGptModelInflightBatching::fixOptionalParams(modelConfig, optionalParams); - return std::make_shared(logger, modelConfig, worldConfig, rawEngine, - (modelType == TrtGptModelType::InflightFusedBatching), fixedOptionalParams); + executor::ExecutorConfig const& fixedExecutorConfig + = TrtGptModelInflightBatching::executorConfigIsValid(modelConfig, executorConfig) + ? executorConfig + : TrtGptModelInflightBatching::fixExecutorConfig(modelConfig, executorConfig); + bool const ctxGenFusion = modelType == TrtGptModelType::InflightFusedBatching; + return std::make_shared( + logger, modelConfig, worldConfig, rawEngine, ctxGenFusion, fixedExecutorConfig, isLeaderInOrchMode); } throw std::runtime_error("Invalid modelType in trtGptModelFactory"); diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index b84ef6c48c2..a5f9d27937d 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -40,6 +40,7 @@ #include "tensorrt_llm/batch_manager/promptTuningBuffers.h" #include "tensorrt_llm/batch_manager/rnnStateManager.h" #include "tensorrt_llm/batch_manager/runtimeBuffers.h" +#include "tensorrt_llm/batch_manager/sequenceSlotManager.h" #include "tensorrt_llm/batch_manager/transformerBuffers.h" #include "tensorrt_llm/batch_manager/updateDecoderBuffers.h" #include "tensorrt_llm/batch_manager/utils/debugUtils.h" @@ -79,74 +80,79 @@ using namespace tensorrt_llm::runtime; namespace tc = tensorrt_llm::common; namespace tk = tensorrt_llm::kernels; -namespace texe = tensorrt_llm::executor; namespace tensorrt_llm::batch_manager { -bool TrtGptModelInflightBatching::optionalParamsAreValid( - ModelConfig const& modelConfig, TrtGptModelOptionalParams const& optionalParams) +bool TrtGptModelInflightBatching::executorConfigIsValid( + ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig) { - // Make sure logic in this function matches fixOptionalParams - if (optionalParams.kvCacheConfig.enableBlockReuse) + // Make sure logic in this function matches fixExecutorConfig + if (executorConfig.getKvCacheConfig().getEnableBlockReuse()) { if (!modelConfig.getPagedContextFMHA()) { return false; } - } - // Context logits cannot be returned for reused tokens, so disable reuse - if (modelConfig.computeContextLogits()) - { - return false; + // Context logits cannot be returned for reused tokens, so disable reuse + if (modelConfig.computeContextLogits()) + { + return false; + } } return true; } -TrtGptModelOptionalParams TrtGptModelInflightBatching::fixOptionalParams( - ModelConfig const& modelConfig, TrtGptModelOptionalParams const& optionalParams) +executor::ExecutorConfig TrtGptModelInflightBatching::fixExecutorConfig( + ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig) { - // Make sure logic in this function matches optionalParamsAreValid - auto fixedOptionalParams = TrtGptModelOptionalParams(optionalParams); - if (fixedOptionalParams.kvCacheConfig.enableBlockReuse) + // Make sure logic in this function matches executorConfigIsValid + if (executorConfig.getKvCacheConfig().getEnableBlockReuse()) { + auto kvCacheConfig = executorConfig.getKvCacheConfig(); + if (!modelConfig.getPagedContextFMHA()) { TLLM_LOG_WARNING( - "Fix optionalParams : KV cache reuse disabled because model was not built with paged context FMHA " + "Fixing executorConfig: KV cache reuse disabled because model was not 
built with paged context FMHA " "support"); - fixedOptionalParams.kvCacheConfig.enableBlockReuse = false; + kvCacheConfig.setEnableBlockReuse(false); } + if (modelConfig.computeContextLogits()) + { + TLLM_LOG_WARNING( + "Fixing executorConfig: KV cache reuse disabled because model was built to return context logits"); + kvCacheConfig.setEnableBlockReuse(false); + } + + auto fixedExecutorConfig = executor::ExecutorConfig(executorConfig); + fixedExecutorConfig.setKvCacheConfig(kvCacheConfig); + return fixedExecutorConfig; } - if (modelConfig.computeContextLogits()) - { - TLLM_LOG_WARNING( - "Fix optionalParams : KV cache reuse disabled because model was built to return context logits"); - fixedOptionalParams.kvCacheConfig.enableBlockReuse = false; - } - return fixedOptionalParams; + return executorConfig; } TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr logger, ModelConfig const& modelConfig, WorldConfig const& worldConfig, RawEngine const& rawEngine, bool ctxGenFusion, - TrtGptModelOptionalParams const& optionalParams) - : TrtGptModel(modelConfig, worldConfig, optionalParams) + executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode) + : TrtGptModel(modelConfig, worldConfig, executorConfig) , mModelConfig(modelConfig) , mWorldConfig(worldConfig) , mDevice{runtime::utils::initDevice(worldConfig)} - , mDecodingConfig{optionalParams.decodingConfig} - , mExtendedRuntimePerfKnobConfig{optionalParams.extendedRuntimePerfKnobConfig} - , mDebugConfig{optionalParams.debugConfig} - , mAdditionalModelOutputs{worldConfig.isLastPipelineParallelRank() ? optionalParams.additionalModelOutputs + , mDecodingConfig{executorConfig.getDecodingConfig().value_or(executor::DecodingConfig{})} + , mExtendedRuntimePerfKnobConfig{executorConfig.getExtendedRuntimePerfKnobConfig()} + , mDebugConfig{executorConfig.getDebugConfig()} + , mAdditionalModelOutputs{worldConfig.isLastPipelineParallelRank() ? executorConfig.getAdditionalModelOutputs() : std::nullopt} , mLogger{logger ? 
std::move(logger) : std::make_shared()} - , mRuntime{std::make_unique(rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage, - optionalParams.gpuWeightsPercent, modelConfig.useShapeInference())} + , mRuntime{std::make_unique(rawEngine, mLogger.get(), executorConfig.getUseGpuDirectStorage(), + executorConfig.getGpuWeightsPercent(), modelConfig.useShapeInference())} , mCopyBufferManager{std::make_shared()} , mCtxGenFusion(ctxGenFusion) , mOperatingBeamWidth{getMaxBeamWidth()} - , mGatherGenerationLogits{optionalParams.gatherGenerationLogits} - , mPromptTableOffloading{optionalParams.promptTableOffloading} + , mGatherGenerationLogits{executorConfig.getGatherGenerationLogits()} + , mPromptTableOffloading{executorConfig.getPromptTableOffloading()} + , mIsLeaderInOrchMode{isLeaderInOrchMode} { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); @@ -175,7 +181,9 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr(optionalParams.guidedDecodingConfig.value(), + mGuidedDecoder = std::make_unique(executorConfig.getGuidedDecodingConfig().value(), getMaxNumSequences(), mModelConfig.getVocabSizePadded(mWorldConfig.getSize()), mModelConfig.getLogitsDtype(), mRuntime->getBufferManager()); } @@ -241,8 +249,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr( - optionalParams.peftCacheManagerConfig, mModelConfig, mWorldConfig, mRuntime->getBufferManager()); + peftCacheManagerConfig, mModelConfig, mWorldConfig, mRuntime->getBufferManager()); } else { @@ -256,38 +266,38 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptrgetBufferManager(), optionalParams.kvCacheConfig); + = BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig); if (mModelConfig.useCrossAttention()) { - TLLM_CHECK_WITH_INFO(optionalParams.kvCacheConfig.crossKvCacheFraction.has_value(), + TLLM_CHECK_WITH_INFO(kvCacheConfig.crossKvCacheFraction.has_value(), "Must set crossKvCacheFraction for encoder-decoder model"); - auto const crossKvCacheFraction = optionalParams.kvCacheConfig.crossKvCacheFraction.value(); - mKvCacheManager = createKvCacheManager(optionalParams.kvCacheConfig, KvCacheType::kSELF, + auto const crossKvCacheFraction = kvCacheConfig.crossKvCacheFraction.value(); + mKvCacheManager = createKvCacheManager(kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes * (1.0f - crossKvCacheFraction), freeSecondaryMemBytes * (1.0f - crossKvCacheFraction), cacheTransPreAllocaSize); - mCrossKvCacheManager = createKvCacheManager(optionalParams.kvCacheConfig, KvCacheType::kCROSS, - freePrimaryMemBytes * crossKvCacheFraction, freeSecondaryMemBytes * crossKvCacheFraction, - cacheTransPreAllocaSize); + mCrossKvCacheManager + = createKvCacheManager(kvCacheConfig, KvCacheType::kCROSS, freePrimaryMemBytes * crossKvCacheFraction, + freeSecondaryMemBytes * crossKvCacheFraction, cacheTransPreAllocaSize); TLLM_LOG_INFO("This is an Encoder-Decoder model, set %0.1f cross KV cache fraction based on the config.", crossKvCacheFraction); } else { - TLLM_CHECK_WITH_INFO(!optionalParams.kvCacheConfig.crossKvCacheFraction.has_value(), + TLLM_CHECK_WITH_INFO(!kvCacheConfig.crossKvCacheFraction.has_value(), "Do not set crossKvCacheFraction for decoder-only model"); - mKvCacheManager = createKvCacheManager(optionalParams.kvCacheConfig, KvCacheType::kSELF, - freePrimaryMemBytes, freeSecondaryMemBytes, cacheTransPreAllocaSize); + mKvCacheManager = createKvCacheManager( + kvCacheConfig, KvCacheType::kSELF, freePrimaryMemBytes, freeSecondaryMemBytes, 
cacheTransPreAllocaSize); } mCacheTransceiver = CacheTransceiverFactory::createCacheTransceiver(mKvCacheManager.get(), mModelConfig, mWorldConfig, - executor::kv_cache::CacheState::AttentionType::kDEFAULT, optionalParams.cacheTransceiverConfig); + executor::kv_cache::CacheState::AttentionType::kDEFAULT, executorConfig.getCacheTransceiverConfig()); } if (mModelConfig.getSpeculativeDecodingMode().needsKVCacheRewind()) @@ -339,7 +349,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr(getMaxNumSequences(), optionalParams.maxSeqIdleMicroseconds); + = std::make_shared(getMaxNumSequences(), executorConfig.getMaxSeqIdleMicroseconds()); mMicroBatchScheduledRequests.resize(mNumMicroBatches); mDecoderFinishedEvents.resize(mNumMicroBatches); @@ -349,13 +359,13 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr ctxChunkConfig; - if (optionalParams.enableChunkedContext) + if (executorConfig.getEnableChunkedContext()) { TLLM_CHECK_WITH_INFO(modelConfig.isKVCacheEnabled() && mModelConfig.getPagedContextFMHA(), "Chunked context requires context FMHA, paged kv_cache and paged context FMHA all enabled at the same " @@ -368,10 +378,10 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr(mModelConfig.getMaxInputLen()); @@ -398,7 +408,7 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr(getMaxNumSequences(), - optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), mKvCacheManager != nullptr, + executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), mKvCacheManager != nullptr, mWorldConfig.isPipelineParallel()); mMicroBatchScheduler = std::make_unique(ctxChunkConfig, maxContextLength); @@ -438,9 +448,8 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptrfastLogits; if (mSpeculativeDecodingFastLogits && modelConfig.getSpeculativeDecodingMode().isNone() && mIsLeaderInOrchMode) { mDraftModelSendLogitsThread = std::make_unique(&utils::draftModelSendLogitsThread, mDevice, diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h index 831ada4f510..6a0ebcd9dc7 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h @@ -18,10 +18,9 @@ #pragma once #include "tensorrt_llm/batch_manager/common.h" -#include "tensorrt_llm/batch_manager/sequenceSlotManager.h" +#include "tensorrt_llm/batch_manager/kvCacheConfig.h" #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/executor/types.h" -#include "tensorrt_llm/runtime/gptDecoderBatched.h" #include "tensorrt_llm/runtime/modelConfig.h" #include "tensorrt_llm/runtime/rawEngine.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -37,6 +36,18 @@ class GptDecoderBatched; class AllReduceBuffers; class NcclCommunicator; class SpeculativeDecodingMode; + +namespace decoder +{ +class DecoderState; +} // namespace decoder + +namespace decoder_batch +{ +class Input; +class Output; +} // namespace decoder_batch + } // namespace tensorrt_llm::runtime namespace tensorrt_llm::mpi @@ -57,6 +68,7 @@ namespace rnn_state_manager { class RnnStateManager; } // namespace rnn_state_manager + class SequenceSlotManager; class DecoderStepAsyncSend; class DecoderSlotAsyncSend; @@ -133,7 +145,7 @@ class TrtGptModelInflightBatching : public TrtGptModel TrtGptModelInflightBatching(std::shared_ptr logger, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, 
runtime::RawEngine const& rawEngine, bool ctxGenFusion, - TrtGptModelOptionalParams const& optionalParams = TrtGptModelOptionalParams()); + executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode); ~TrtGptModelInflightBatching() override; @@ -191,10 +203,10 @@ class TrtGptModelInflightBatching : public TrtGptModel return mIterCounter; } - [[nodiscard]] static bool optionalParamsAreValid( - runtime::ModelConfig const& modelConfig, TrtGptModelOptionalParams const& optionalParams); - [[nodiscard]] static TrtGptModelOptionalParams fixOptionalParams( - runtime::ModelConfig const& modelConfig, TrtGptModelOptionalParams const& optionalParams); + [[nodiscard]] static bool executorConfigIsValid( + runtime::ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig); + [[nodiscard]] static executor::ExecutorConfig fixExecutorConfig( + runtime::ModelConfig const& modelConfig, executor::ExecutorConfig const& executorConfig); void prepareDisaggGenInitRequests(RequestList const& activeRequests, RequestVector& newGenReques); void checkDisaggGenTransferStatus(RequestList const& activeRequests); diff --git a/cpp/tensorrt_llm/executor/executorImpl.cpp b/cpp/tensorrt_llm/executor/executorImpl.cpp index 139cc8a9bc6..d26526d9d97 100644 --- a/cpp/tensorrt_llm/executor/executorImpl.cpp +++ b/cpp/tensorrt_llm/executor/executorImpl.cpp @@ -16,10 +16,8 @@ */ #include "tensorrt_llm/executor/executorImpl.h" -#include "tensorrt_llm/batch_manager/decoderBuffers.h" #include "tensorrt_llm/batch_manager/trtEncoderModel.h" #include "tensorrt_llm/batch_manager/trtGptModelFactory.h" -#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaProfilerUtils.h" #include "tensorrt_llm/common/logger.h" @@ -54,21 +52,37 @@ namespace tensorrt_llm::executor namespace { -void checkOptionalParams( - batch_manager::TrtGptModelOptionalParams& optionalParams, runtime::ModelConfig const& modelConfig) +[[nodiscard]] bool executorConfigIsValid(ExecutorConfig const& executorConfig, runtime::ModelConfig const& modelConfig) { + // Make sure logic in this function matches fixExecutorConfig + if (executorConfig.getEnableChunkedContext()) + { + if (modelConfig.isRnnBased() || !modelConfig.isKVCacheEnabled() || !modelConfig.getPagedContextFMHA()) + { + return false; + } + } + return true; +} + +[[nodiscard]] ExecutorConfig fixExecutorConfig( + ExecutorConfig const& executorConfig, runtime::ModelConfig const& modelConfig) +{ + // Make sure logic in this function matches executorConfigIsValid + auto fixedExecutorConfig = executorConfig; // Disable chunked context when not supported - if (optionalParams.enableChunkedContext) + if (executorConfig.getEnableChunkedContext()) { if (modelConfig.isRnnBased() || !modelConfig.isKVCacheEnabled() || !modelConfig.getPagedContextFMHA()) { - optionalParams.enableChunkedContext = false; + fixedExecutorConfig.setEnableChunkedContext(false); TLLM_LOG_WARNING( "Chunked context is not supported for this configuration and will be disabled. 
" "Related configs: RNNBased: %d, KVCacheEnabled: %d, PagedContextFMHA: %d", modelConfig.isRnnBased(), modelConfig.isKVCacheEnabled(), modelConfig.getPagedContextFMHA()); } } + return fixedExecutorConfig; } SizeType32 getNumChildRequests(Request const& request) @@ -488,19 +502,22 @@ std::shared_ptr Executor::Impl::createModel(runtime::RawEngine const& raw }(); bool const isLeaderInOrchMode = (mCommMode == CommunicationMode::kORCHESTRATOR) && mIsLeader; - auto optionalParams = batch_manager::TrtGptModelOptionalParams(executorConfig, isLeaderInOrchMode); - checkOptionalParams(optionalParams, modelConfig); - return batch_manager::TrtGptModelFactory::create(rawEngine, modelConfig, worldConfig, gptModelType, optionalParams); + auto const& fixedExecutorConfig = executorConfigIsValid(executorConfig, modelConfig) + ? executorConfig + : fixExecutorConfig(executorConfig, modelConfig); + + return batch_manager::TrtGptModelFactory::create( + rawEngine, modelConfig, worldConfig, gptModelType, fixedExecutorConfig, isLeaderInOrchMode); } std::shared_ptr Executor::Impl::createEncoderModel(runtime::RawEngine const& rawEngine, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig, ExecutorConfig const& executorConfig) { - auto optionalParams = batch_manager::TrtGptModelOptionalParams{}; - optionalParams.schedulerConfig = executorConfig.getSchedulerConfig(); + auto fixedExecutorConfig = ExecutorConfig{}; + fixedExecutorConfig.setSchedulerConfig(executorConfig.getSchedulerConfig()); return std::make_shared( - modelConfig, worldConfig, rawEngine, std::make_shared(), optionalParams); + modelConfig, worldConfig, rawEngine, std::make_shared(), fixedExecutorConfig); } void Executor::Impl::setOrchLeaderComm( diff --git a/cpp/tensorrt_llm/pybind/bindings.cpp b/cpp/tensorrt_llm/pybind/bindings.cpp index ebda5773abb..ad6018fdd41 100644 --- a/cpp/tensorrt_llm/pybind/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/bindings.cpp @@ -24,7 +24,7 @@ #include #include "tensorrt_llm/batch_manager/kvCacheConfig.h" -#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h" +#include "tensorrt_llm/batch_manager/peftCacheManagerConfig.h" #include "tensorrt_llm/common/quantization.h" #include "tensorrt_llm/pybind/batch_manager/algorithms.h" #include "tensorrt_llm/pybind/batch_manager/bindings.h" @@ -40,7 +40,6 @@ #include "tensorrt_llm/runtime/cudaStream.h" #include "tensorrt_llm/runtime/gptJsonConfig.h" #include "tensorrt_llm/runtime/ipcNvlsMemory.h" -#include "tensorrt_llm/runtime/ipcUtils.h" #include "tensorrt_llm/runtime/memoryCounters.h" #include "tensorrt_llm/runtime/samplingConfig.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -511,43 +510,6 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m) .value("DISAGG_GENERATION_TRANS_COMPLETE", tb::LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE) .value("DISAGG_CONTEXT_INIT_AND_TRANS", tb::LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS); - auto gptModelParamsGetState = [&kvCacheConfigGetState](tb::TrtGptModelOptionalParams const& params) - { - auto kvCacheState = kvCacheConfigGetState(params.kvCacheConfig); - return py::make_tuple(kvCacheState, params.enableTrtOverlap, params.deviceIds, params.normalizeLogProbs, - params.enableChunkedContext, params.decodingConfig.getDecodingMode()); - }; - auto gptModelParamsSetState = [&kvCacheConfigSetState](py::tuple t) - { - auto kvCacheConfig = kvCacheConfigSetState(t[0]); - return tb::TrtGptModelOptionalParams(kvCacheConfig, t[1].cast(), - t[2].cast>>(), t[3].cast(), t[4].cast(), - tb::PeftCacheManagerConfig{}, 
- tensorrt_llm::executor::DecodingConfig(t[5].cast>())); - }; - - py::class_(m, "TrtGptModelOptionalParams") - .def(py::init> const&, bool, bool, - tb::PeftCacheManagerConfig const&>(), - py::arg_v("kv_cache_config", tbk::KvCacheConfig{}, "KvCacheConfig()"), - py::arg("enable_trt_overlap") = false, py::arg("device_ids") = std::nullopt, - py::arg("normalize_log_probs") = true, py::arg("enable_chunked_context") = false, - py::arg_v("peft_cache_manager_config", tb::PeftCacheManagerConfig{}, "PeftCacheManagerConfig()")) - .def(py::init(), py::arg("executor_config"), - py::arg("is_leader_in_orch_mode") = false) - .def_readwrite("kv_cache_config", &tb::TrtGptModelOptionalParams::kvCacheConfig) - .def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap) - .def_readwrite("device_ids", &tb::TrtGptModelOptionalParams::deviceIds) - .def_readwrite("enable_chunked_context", &tb::TrtGptModelOptionalParams::enableChunkedContext) - .def_readwrite("normalize_log_probs", &tb::TrtGptModelOptionalParams::normalizeLogProbs) - .def_readwrite("decoding_config", &tb::TrtGptModelOptionalParams::decodingConfig) - .def_readwrite("use_gpu_direct_storage", &tb::TrtGptModelOptionalParams::useGpuDirectStorage) - .def_readwrite("gpu_weights_percent", &tb::TrtGptModelOptionalParams::gpuWeightsPercent) - .def_readwrite("max_beam_width", &tb::TrtGptModelOptionalParams::maxBeamWidth) - .def_readwrite("scheduler_config", &tb::TrtGptModelOptionalParams::schedulerConfig) - .def_readwrite("cache_transceiver_config", &tb::TrtGptModelOptionalParams::cacheTransceiverConfig) - .def(py::pickle(gptModelParamsGetState, gptModelParamsSetState)); - py::class_(m, "MemoryCounters") .def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference) .def_property_readonly("gpu", &tr::MemoryCounters::getGpu) diff --git a/cpp/tests/batch_manager/trtEncoderModelTest.cpp b/cpp/tests/batch_manager/trtEncoderModelTest.cpp index 6037a93148b..dce09539bd4 100644 --- a/cpp/tests/batch_manager/trtEncoderModelTest.cpp +++ b/cpp/tests/batch_manager/trtEncoderModelTest.cpp @@ -160,11 +160,9 @@ void runEncoderTest(std::unique_ptr& bufferManager, ModelConfig c requestList.push_back(request); } - TrtGptModelOptionalParams optionalParams; tensorrt_llm::executor::ExecutorConfig executorConfig{}; - optionalParams.schedulerConfig = executorConfig.getSchedulerConfig(); auto trtEncoderModel = std::make_shared( - modelConfig, worldConfig, runtime::RawEngine(engineBuffer.data(), engineBuffer.size()), logger, optionalParams); + modelConfig, worldConfig, runtime::RawEngine(engineBuffer.data(), engineBuffer.size()), logger, executorConfig); trtEncoderModel->forward(requestList); diff --git a/cpp/tests/batch_manager/trtGptModelRealDecoderTest.cpp b/cpp/tests/batch_manager/trtGptModelRealDecoderTest.cpp index 56b83f7b4cc..6b2a744a788 100644 --- a/cpp/tests/batch_manager/trtGptModelRealDecoderTest.cpp +++ b/cpp/tests/batch_manager/trtGptModelRealDecoderTest.cpp @@ -14,6 +14,7 @@ #include "tensorrt_llm/batch_manager/trtGptModelFactory.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/memoryUtils.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/plugins/api/tllmPlugin.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/tllmLogger.h" @@ -621,7 +622,7 @@ RequestList runGptModelInference(std::shared_ptr& trtGptModel, std: void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds const modelIds, TrtGptModelType modelType, std::vector 
const& batchSizes, BeamResults const& resultsFilesBeamWidths, - TrtGptModelIfbTestType testType, int maxReqPerStep, TrtGptModelOptionalParams const& optionalParams, + TrtGptModelIfbTestType testType, int maxReqPerStep, texec::ExecutorConfig const& executorConfig, bool enableStreamingMode, bool useRandomEndId) { auto manager = BufferManager(std::make_shared()); @@ -638,8 +639,7 @@ void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds ASSERT_EQ(inputShape.nbDims, 2); ASSERT_GT(inputShape.d[0], 0); - TLLM_CHECK(optionalParams.maxBeamWidth.has_value()); - auto const maxBeamWidth = optionalParams.maxBeamWidth.value(); + auto const maxBeamWidth = executorConfig.getMaxBeamWidth(); // Load expected outputs for each beam width value auto [beamWidths, beamWidthTestData] = loadTestData(modelSpec, modelType, modelIds, resultsFilesBeamWidths, *givenInput, maxBeamWidth, useRandomEndId, modelSpec.mReplaceLogits, manager); @@ -653,7 +653,7 @@ void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds { std::cout << "=== batchSize:" << batchSize << " ===\n"; - auto trtGptModel = TrtGptModelFactory::create(modelPath, modelType, optionalParams); + auto trtGptModel = TrtGptModelFactory::create(modelPath, modelType, executorConfig, false); if (modelSpec.mKVCacheType == KVCacheType::kDISABLED) { @@ -923,38 +923,45 @@ TEST_P(ParamTest, Test) } } - TrtGptModelOptionalParams modelOptionalParams; - modelOptionalParams.kvCacheConfig.maxTokens = std::get<5>(GetParam()); - modelOptionalParams.kvCacheConfig.enableBlockReuse = modelSpec.mMaxDraftTokens > 0 || modelSpec.mKVCacheReuse; - modelOptionalParams.kvCacheConfig.freeGpuMemoryFraction = std::get<6>(GetParam()); - modelOptionalParams.kvCacheConfig.hostCacheSize = std::get<11>(GetParam()); - modelOptionalParams.enableTrtOverlap = std::get<7>(GetParam()); - modelOptionalParams.enableChunkedContext = std::get<8>(GetParam()); - modelOptionalParams.normalizeLogProbs = false; - modelOptionalParams.maxBeamWidth = beamConfig.maxBeamWidth; - modelOptionalParams.gatherGenerationLogits = modelSpec.mCollectGenerationLogits; - modelOptionalParams.extendedRuntimePerfKnobConfig.setCudaGraphMode(cudaGraphMode); - texec::CapacitySchedulerPolicy capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kMAX_UTILIZATION; - if (modelSpec.mCapacitySchedulerPolicy) - { - capacitySchedulerPolicy = modelSpec.mCapacitySchedulerPolicy.value(); - } - modelOptionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy}; + auto executorConfig = texec::ExecutorConfig{}; + + auto const maxTokens = std::get<5>(GetParam()); + auto const enableBlockReuse = modelSpec.mMaxDraftTokens > 0 || modelSpec.mKVCacheReuse; + auto const freeGpuMemoryFraction = std::get<6>(GetParam()); + auto const hostCacheSize = std::get<11>(GetParam()); + auto const kvCacheConfig = texec::KvCacheConfig{ + enableBlockReuse, maxTokens, std::nullopt, std::nullopt, freeGpuMemoryFraction, hostCacheSize}; + executorConfig.setKvCacheConfig(kvCacheConfig); + + executorConfig.setEnableTrtOverlap(std::get<7>(GetParam())); + executorConfig.setEnableChunkedContext(std::get<8>(GetParam())); + executorConfig.setNormalizeLogProbs(false); + executorConfig.setMaxBeamWidth(beamConfig.maxBeamWidth); + executorConfig.setGatherGenerationLogits(modelSpec.mCollectGenerationLogits); + auto extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig{}; + extendedRuntimePerfKnobConfig.setCudaGraphMode(cudaGraphMode); + 
executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig); + + auto const capacitySchedulerPolicy + = modelSpec.mCapacitySchedulerPolicy.value_or(texec::CapacitySchedulerPolicy::kMAX_UTILIZATION); + executorConfig.setSchedulerConfig(texec::SchedulerConfig{capacitySchedulerPolicy}); if (modelSpec.mSpecDecodingMode == SpeculativeDecodingMode::LookaheadDecoding()) { - modelOptionalParams.decodingConfig.setLookaheadDecodingConfig(texec::LookaheadDecodingConfig(5, 5, 5)); + auto decodingConfig = texec::DecodingConfig{}; + decodingConfig.setLookaheadDecodingConfig(texec::LookaheadDecodingConfig(5, 5, 5)); + executorConfig.setDecodingConfig(decodingConfig); } for (auto beamWidth : beamWidths) { - if (modelOptionalParams.enableTrtOverlap && beamWidth > 1) + if (executorConfig.getEnableTrtOverlap() && beamWidth > 1) { GTEST_SKIP() << "TrtOverlap is not supported with beam search"; } } - if (modelOptionalParams.enableTrtOverlap && modelSpec.mMaxDraftTokens > 0) + if (executorConfig.getEnableTrtOverlap() && modelSpec.mMaxDraftTokens > 0) { GTEST_SKIP() << "TrtOverlap is not supported with speculative decoding"; } @@ -967,7 +974,7 @@ TEST_P(ParamTest, Test) << " is not equal to the system world size"; } - runIfbTest(modelPath, modelSpec, modelIds, modelType, batchSizes, beamResults, testType, 2, modelOptionalParams, + runIfbTest(modelPath, modelSpec, modelIds, modelType, batchSizes, beamResults, testType, 2, executorConfig, enableStreamingMode, useRandomEndId); } diff --git a/cpp/tests/batch_manager/trtGptModelTest.cpp b/cpp/tests/batch_manager/trtGptModelTest.cpp index d45aedcabb1..ede1fc433ea 100644 --- a/cpp/tests/batch_manager/trtGptModelTest.cpp +++ b/cpp/tests/batch_manager/trtGptModelTest.cpp @@ -179,13 +179,13 @@ TEST_F(TrtGptModelTest, Forward) std::vector finished(mMaxNumRequests, false); - TrtGptModelOptionalParams optionalParams; - optionalParams.enableTrtOverlap = false; - optionalParams.maxBeamWidth = mBeamWidth; - optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION}; + executor::ExecutorConfig executorConfig; + executorConfig.setEnableTrtOverlap(false); + executorConfig.setMaxBeamWidth(mBeamWidth); + executorConfig.setSchedulerConfig(executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION}); auto trtGptModel = std::make_shared( - mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams); + mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false); // Generate one token for the requests in request_table // We need to sync with decoder @@ -217,13 +217,13 @@ TEST_F(TrtGptModelLoraTest, Forward) std::vector finished(mMaxNumRequests, false); - TrtGptModelOptionalParams optionalParams; - optionalParams.enableTrtOverlap = false; - optionalParams.maxBeamWidth = mBeamWidth; - optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION}; + executor::ExecutorConfig executorConfig; + executorConfig.setEnableTrtOverlap(false); + executorConfig.setMaxBeamWidth(mBeamWidth); + executorConfig.setSchedulerConfig(executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION}); auto trtGptModel = std::make_shared( - mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams); + mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false); // Generate one token for the requests in request_table trtGptModel->forwardAsync(requestList); @@ -238,14 +238,18 @@ 
TEST_F(TrtGptModelLoraTest, Forward) TEST_F(TrtGptModelTest, ForwardMaxNewTokens) { - TrtGptModelOptionalParams optionalParams; - optionalParams.enableTrtOverlap = false; - optionalParams.kvCacheConfig.maxTokens = 10000; - optionalParams.maxBeamWidth = mBeamWidth; - optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT}; + executor::ExecutorConfig executorConfig; + executorConfig.setEnableTrtOverlap(false); + executorConfig.setMaxBeamWidth(mBeamWidth); + executorConfig.setSchedulerConfig( + executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT}); + + executor::KvCacheConfig kvCacheConfig; + kvCacheConfig.setMaxTokens(10000); + executorConfig.setKvCacheConfig(kvCacheConfig); auto trtGptModel = std::make_shared( - mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams); + mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false); SamplingConfig inSamplingConfig; inSamplingConfig.temperature = std::vector{2.0f}; @@ -285,17 +289,18 @@ TEST_F(TrtGptModelTest, ForwardMaxNewTokens) TEST_F(TrtGptModelTest, MaxNumTokensInChunked) { - TrtGptModelOptionalParams optionalParams; - optionalParams.enableTrtOverlap = false; - optionalParams.enableChunkedContext = true; - optionalParams.maxBeamWidth = mBeamWidth; - optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT}; + executor::ExecutorConfig executorConfig; + executorConfig.setEnableTrtOverlap(false); + executorConfig.setEnableChunkedContext(true); + executorConfig.setMaxBeamWidth(mBeamWidth); + executorConfig.setSchedulerConfig( + executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT}); auto modelConfig = mModelConfig; mModelConfig.setMaxNumTokens(200); auto trtGptModelIfb = std::make_shared( - mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams); + mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false); std::vector> trtGptModels{trtGptModelIfb}; for (auto trtGptModel : trtGptModels) @@ -339,14 +344,18 @@ TEST_F(TrtGptModelTest, MaxNumTokensInChunked) TEST_F(TrtGptModelTest, ForwardEndId) { - TrtGptModelOptionalParams optionalParams; - optionalParams.enableTrtOverlap = false; - optionalParams.kvCacheConfig.maxTokens = 10000; - optionalParams.maxBeamWidth = mBeamWidth; - optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT}; + executor::ExecutorConfig executorConfig; + executorConfig.setEnableTrtOverlap(false); + executorConfig.setMaxBeamWidth(mBeamWidth); + executorConfig.setSchedulerConfig( + executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT}); + + executor::KvCacheConfig kvCacheConfig; + kvCacheConfig.setMaxTokens(10000); + executorConfig.setKvCacheConfig(kvCacheConfig); auto trtGptModel = std::make_shared( - mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams); + mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false); SamplingConfig inSamplingConfig; inSamplingConfig.temperature = std::vector{2.0f}; @@ -389,14 +398,17 @@ TEST_F(TrtGptModelTest, ForwardEndId) TEST_F(TrtGptModelTest, ForwardNoEoS) { - TrtGptModelOptionalParams optionalParams; - optionalParams.enableTrtOverlap = false; - optionalParams.kvCacheConfig.maxTokens = 10000; - optionalParams.maxBeamWidth = mBeamWidth; - optionalParams.schedulerConfig = 
executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kSTATIC_BATCH};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kSTATIC_BATCH});
+
+    executor::KvCacheConfig kvCacheConfig;
+    kvCacheConfig.setMaxTokens(10000);
+    executorConfig.setKvCacheConfig(kvCacheConfig);
     auto trtGptModel = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     SamplingConfig inSamplingConfig;
     inSamplingConfig.topP = {0.9};
@@ -453,13 +465,13 @@ TEST_F(TrtGptModelTest, ForwardFinished)
     std::vector finishedFalse(mMaxNumRequests, false);
     std::vector finishedTrue(mMaxNumRequests, true);
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION});
     auto trtGptModel = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     // Generate one token for the requests in request_table
     trtGptModel->forwardAsync(requestList);
@@ -484,14 +496,18 @@ TEST_F(TrtGptModelTest, ForwardFinished)
 TEST_F(TrtGptModelTest, ForwardStopWords)
 {
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.kvCacheConfig.maxTokens = 10000;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(
+        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT});
+
+    executor::KvCacheConfig kvCacheConfig;
+    kvCacheConfig.setMaxTokens(10000);
+    executorConfig.setKvCacheConfig(kvCacheConfig);
     auto trtGptModel = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     SamplingConfig inSamplingConfig;
     inSamplingConfig.temperature = std::vector{2.0f};
@@ -617,14 +633,18 @@ TEST_F(TrtGptModelTest, ForwardStopWords)
 TEST_F(TrtGptModelTest, ForwardBadWords)
 {
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.kvCacheConfig.maxTokens = 10000;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(
+        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT});
+
+    executor::KvCacheConfig kvCacheConfig;
+    kvCacheConfig.setMaxTokens(10000);
+    executorConfig.setKvCacheConfig(kvCacheConfig);
     auto trtGptModel = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     SamplingConfig inSamplingConfig;
     inSamplingConfig.temperature = std::vector{2.0f};
@@ -742,14 +762,18 @@ TEST_F(TrtGptModelTest, ForwardBadWords)
 TEST_F(TrtGptModelTest, ForwardEmbeddingBias)
 {
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.kvCacheConfig.maxTokens = 10000;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(
+        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT});
+
+    executor::KvCacheConfig kvCacheConfig;
+    kvCacheConfig.setMaxTokens(10000);
+    executorConfig.setKvCacheConfig(kvCacheConfig);
     auto trtGptModelIfb = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     std::vector> trtGptModels{trtGptModelIfb};
@@ -874,20 +898,23 @@ class TrtGptModelIfbHelper : public TrtGptModelInflightBatching
 TEST_F(TrtGptModelTest, KVCacheReuseChunked)
 {
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.enableChunkedContext = true;
-    optionalParams.kvCacheConfig.enableBlockReuse = true;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setEnableChunkedContext(true);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(
+        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT});
+
+    executor::KvCacheConfig kvCacheConfig;
+    kvCacheConfig.setEnableBlockReuse(true);
+    executorConfig.setKvCacheConfig(kvCacheConfig);
-    auto modelConfig = mModelConfig;
     mModelConfig.setMaxNumTokens(384);
     for (int const numBlocksExpectedReused : {1, 2})
     {
         auto trtGptModelIfb = std::make_shared(
-            mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+            mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
         auto const cacheManager = trtGptModelIfb->getKVCacheManager();
         auto const tokensPerBlock = cacheManager->getTokensPerBlock();
         constexpr int numPrefillBlocks = 2;
@@ -938,13 +965,13 @@ TEST_F(TrtGptModelTest, PauseRequestStats)
     RequestList requestList{llmRequest};
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setSchedulerConfig(executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION});
     auto trtGptModel = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     // Generate one token for the requests in request_table
     // We need to sync with decoder
@@ -1053,16 +1080,19 @@ TEST_F(TrtGptModelLogitsTest, ReturnContextLogitsWithChunkedContext)
         modelConfig.setMaxNumTokens(128);
     }
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.kvCacheConfig.enableBlockReuse = true;
-    optionalParams.enableChunkedContext = enableChunkedContext;
-    optionalParams.maxBeamWidth = mBeamWidth;
-    optionalParams.schedulerConfig
-        = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(mBeamWidth);
+    executorConfig.setEnableChunkedContext(enableChunkedContext);
+    executorConfig.setSchedulerConfig(
+        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT});
+
+    executor::KvCacheConfig kvCacheConfig;
+    kvCacheConfig.setEnableBlockReuse(true);
+    executorConfig.setKvCacheConfig(kvCacheConfig);
     auto trtGptModelIfb = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, modelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     // Prepare input tokens
     std::vector input_ids;
@@ -1091,12 +1121,13 @@ TEST_F(TrtGptModelLogitsTest, ReturnContextLogitsWithChunkedContext)
         = bufferCast(*(finishList.front()->getContextLogitsHost()));
     float const* const enableChunkedContextLogits
         = bufferCast(*(finishList.back()->getContextLogitsHost()));
-    for (int i = 0; i < promptLength; i++)
+    for (int tokenIdx = 0; tokenIdx < promptLength; tokenIdx++)
     {
-        for (int j = 0; j < vocabSizePadded; j++)
+        for (int vocabIdx = 0; vocabIdx < vocabSizePadded; vocabIdx++)
         {
-            size_t idx = i * vocabSizePadded + j;
-            EXPECT_EQ(disableChunkedContextLogits[idx], enableChunkedContextLogits[idx]);
+            size_t idx = tokenIdx * vocabSizePadded + vocabIdx;
+            EXPECT_NEAR(disableChunkedContextLogits[idx], enableChunkedContextLogits[idx], 1e-0)
+                << "tokenIdx=" << tokenIdx << " vocabIdx=" << vocabIdx;
         }
     }
     finishList.clear();
@@ -1141,18 +1172,21 @@ TEST_F(LlamaModelLADTest, SeamlessLookaheadDecoding)
         requestId += 1;
     }
-    TrtGptModelOptionalParams optionalParams;
-    optionalParams.enableChunkedContext = false;
-    optionalParams.enableTrtOverlap = false;
-    optionalParams.maxBeamWidth = 1;
-    optionalParams.schedulerConfig = executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION};
+    executor::ExecutorConfig executorConfig;
+    executorConfig.setEnableChunkedContext(false);
+    executorConfig.setEnableTrtOverlap(false);
+    executorConfig.setMaxBeamWidth(1);
+    executorConfig.setSchedulerConfig(
+        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kMAX_UTILIZATION});
     if (initLADConfig)
     {
-        optionalParams.decodingConfig.setLookaheadDecodingConfig(executor::LookaheadDecodingConfig(5, 5, 5));
+        executor::DecodingConfig decodingConfig;
+        decodingConfig.setLookaheadDecodingConfig(executor::LookaheadDecodingConfig(5, 5, 5));
+        executorConfig.setDecodingConfig(decodingConfig);
     }
     auto trtGptModel = std::make_shared(
-        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, optionalParams);
+        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);
     // Generate tokens for the requests in request_table
     // We need to sync with decoder
diff --git a/cpp/tests/unit_tests/executor/executorTestSmall.cpp b/cpp/tests/unit_tests/executor/executorTestSmall.cpp
index 25c21feb9f6..472f56e8a6a 100644
--- a/cpp/tests/unit_tests/executor/executorTestSmall.cpp
+++ b/cpp/tests/unit_tests/executor/executorTestSmall.cpp
@@ -1,4 +1,4 @@
-#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
+#include "tensorrt_llm/batch_manager/trtGptModelInflightBatching.h"
 #include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/rawEngine.h"
@@ -6,7 +6,6 @@
 #include "tests/utils/common.h"
 #include "tests/utils/engines.h"
 #include "tests/utils/executorUtils.h"
-#include
 #include "gtest/gtest.h"
@@ -92,18 +91,18 @@ std::unique_ptr> SetupDecoderTest(TrivialConstantDeco
     modelConfig.setPagedContextFMHA(true);
     auto const worldConfig = runtime::WorldConfig();
-    auto optionalParams = batch_manager::TrtGptModelOptionalParams{};
-    auto kvCacheConfig = batch_manager::kv_cache_manager::KvCacheConfig{};
+    auto kvCacheConfig = executor::KvCacheConfig{};
+    kvCacheConfig.setMaxTokens(DecoderTestShared::kKvCacheMaxTokens);
+
+    auto const executorConfig
+        = tensorrt_llm::executor::ExecutorConfig(params.maxBeamWidth, executor::SchedulerConfig(), kvCacheConfig, true,
+            true, 1, 1, executor::BatchingType::kINFLIGHT, params.maxBatchSize, params.maxNumTokens, std::nullopt,
+            std::nullopt, std::nullopt, std::nullopt, false, 1, std::nullopt, executor::ExtendedRuntimePerfKnobConfig(),
+            std::nullopt, 0, executor::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds, std::nullopt, std::nullopt);
-    kvCacheConfig.maxTokens = DecoderTestShared::kKvCacheMaxTokens;
-    optionalParams.kvCacheConfig = kvCacheConfig;
     auto model = std::make_shared(
-        logger, modelConfig, worldConfig, engine, false, optionalParams);
-    auto const executorConfig = tensorrt_llm::executor::ExecutorConfig(params.maxBeamWidth, executor::SchedulerConfig(),
-        executor::KvCacheConfig{}, true, true, 1, 1, executor::BatchingType::kINFLIGHT, params.maxBatchSize,
-        params.maxNumTokens, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, 1, std::nullopt,
-        executor::ExtendedRuntimePerfKnobConfig(), std::nullopt, 0,
-        executor::ExecutorConfig::kDefaultMaxSeqIdleMicroseconds, std::nullopt, std::nullopt);
+        logger, modelConfig, worldConfig, engine, false, executorConfig, false);
+
     return std::make_unique>(
         logger, rng, std::make_shared(model, executorConfig), randomLogits);
 }
diff --git a/cpp/tests/unit_tests/executor/executorTestSmallArbitraryOutputTensors.cpp b/cpp/tests/unit_tests/executor/executorTestSmallArbitraryOutputTensors.cpp
index 7b79becb9b6..b64bd775fe3 100644
--- a/cpp/tests/unit_tests/executor/executorTestSmallArbitraryOutputTensors.cpp
+++ b/cpp/tests/unit_tests/executor/executorTestSmallArbitraryOutputTensors.cpp
@@ -1,6 +1,5 @@
 #include "include/tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/batch_manager/trtGptModelInflightBatching.h"
-#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
 #include "tensorrt_llm/executor/types.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
@@ -128,9 +127,8 @@ std::unique_ptr> SetupDecoderTest(
         std::vector{
             executor::AdditionalModelOutput{DecoderTestShared::kTopKTensorName, params.gatherContext}});
-    auto optionalParams = batch_manager::TrtGptModelOptionalParams{executorConfig, false};
     auto model = std::make_shared(
-        logger, modelConfig, worldConfig, engine, false, optionalParams);
+        logger, modelConfig, worldConfig, engine, false, executorConfig, false);
     return std::make_unique>(
         logger, rng, std::make_shared(model, executorConfig), randomLogits);
diff --git a/tests/unittest/bindings/test_bindings_ut.py b/tests/unittest/bindings/test_bindings_ut.py
index ba2c601e4f2..37ef77e5b59 100644
--- a/tests/unittest/bindings/test_bindings_ut.py
+++ b/tests/unittest/bindings/test_bindings_ut.py
@@ -402,68 +402,6 @@ def test_llm_request():
     assert torch.equal(llm_request.draft_logits, logits)
-def test_trt_gpt_model_optional_params():
-    opt_params = _tb.TrtGptModelOptionalParams()
-
-    kv_cache_config = _tb.KvCacheConfig(10, [10], 0, 0.5, False)
-    opt_params.kv_cache_config = kv_cache_config
-    assert opt_params.kv_cache_config.free_gpu_memory_fraction == kv_cache_config.free_gpu_memory_fraction
-
-    assert not opt_params.enable_trt_overlap
-    opt_params.enable_trt_overlap = True
-    assert opt_params.enable_trt_overlap
-
-    assert opt_params.device_ids is None
-    opt_params.device_ids = [0, 1]
-    assert opt_params.device_ids == [0, 1]
-
-    assert not opt_params.enable_chunked_context
-    opt_params.enable_chunked_context = True
-    assert opt_params.enable_chunked_context
-
-    assert opt_params.normalize_log_probs
-    opt_params.normalize_log_probs = False
-    assert not opt_params.normalize_log_probs
-
-    assert not opt_params.decoding_config.decoding_mode
-    opt_params.decoding_config.decoding_mode = _tb.executor.DecodingMode.TopKTopP(
-    )
-    assert opt_params.decoding_config.decoding_mode.isTopKandTopP()
-
-    assert not opt_params.max_beam_width
-    opt_params.max_beam_width = 4
-    assert opt_params.max_beam_width == 4
-
-    assert opt_params.scheduler_config.capacity_scheduler_policy == _tb.executor.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
-    assert opt_params.scheduler_config.context_chunking_policy is None
-    opt_params.scheduler_config = _tb.executor.SchedulerConfig(
-        _tb.executor.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
-        _tb.executor.ContextChunkingPolicy.FIRST_COME_FIRST_SERVED)
-    assert opt_params.scheduler_config.capacity_scheduler_policy == _tb.executor.CapacitySchedulerPolicy.GUARANTEED_NO_EVICT
-    assert opt_params.scheduler_config.context_chunking_policy == _tb.executor.ContextChunkingPolicy.FIRST_COME_FIRST_SERVED
-
-
-def test_trt_gpt_model_optional_params_ctor():
-    kv_cache_config = _tb.KvCacheConfig(10, [10], 0, 0.5, False)
-    enable_trt_overlap = True
-    device_ids = [0, 1]
-    normalize_log_probs = False
-    enable_chunked_context = True
-    peft_cache_manager_config = _tb.PeftCacheManagerConfig()
-
-    opt_params = _tb.TrtGptModelOptionalParams(kv_cache_config,
-                                               enable_trt_overlap, device_ids,
-                                               normalize_log_probs,
-                                               enable_chunked_context,
-                                               peft_cache_manager_config)
-    assert opt_params.kv_cache_config.free_gpu_memory_fraction == kv_cache_config.free_gpu_memory_fraction
-    assert opt_params.enable_trt_overlap
-    assert opt_params.device_ids == device_ids
-    assert opt_params.normalize_log_probs == normalize_log_probs
-    assert opt_params.enable_chunked_context == enable_chunked_context
-    assert opt_params.gpu_weights_percent == 1
-
-
 def test_KvCacheConfig_pickle():
     cache = _tb.KvCacheConfig(free_gpu_memory_fraction=0.4)
     cache1 = pickle.dumps(cache)
@@ -472,29 +410,6 @@ def test_KvCacheConfig_pickle():
     assert cache2 == cache
-def test_TrtGptModelOptionalParams_pickle():
-    kv_cache_config = _tb.KvCacheConfig(10, [10], 0, 0.5, False)
-    enable_trt_overlap = True
-    device_ids = [0, 1]
-    normalize_log_probs = False
-    enable_chunked_context = True
-    peft_cache_manager_config = _tb.PeftCacheManagerConfig()
-
-    params1 = _tb.TrtGptModelOptionalParams(kv_cache_config, enable_trt_overlap,
-                                            device_ids, normalize_log_probs,
-                                            enable_chunked_context,
-                                            peft_cache_manager_config)
-
-    params2 = pickle.loads(pickle.dumps(params1))
-
-    assert params2.kv_cache_config.free_gpu_memory_fraction == kv_cache_config.free_gpu_memory_fraction
-    assert params2.enable_trt_overlap
-    assert params2.device_ids == device_ids
-    assert params2.normalize_log_probs == normalize_log_probs
-    assert params2.enable_chunked_context == enable_chunked_context
-    assert params2.gpu_weights_percent == 1
-
-
 def test_Mpicomm():
     size1 = _tb.MpiComm.size()
     rank1 = _tb.MpiComm.rank()
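
For reference, a minimal sketch (not part of the patch) of the construction pattern the migrated C++ tests converge on: all settings go through executor::ExecutorConfig setters, and the config is passed straight to the model constructor. The fixture members (mLogger, mModelConfig, mWorldConfig, mRawEngine, mBeamWidth) are assumed from the hunks above, and the trailing false mirrors the boolean that the removed batch_manager::TrtGptModelOptionalParams{executorConfig, false} call used to carry.

    // Illustrative sketch only: setter-based ExecutorConfig replacing the old
    // TrtGptModelOptionalParams aggregate used by these tests.
    executor::ExecutorConfig executorConfig;
    executorConfig.setEnableTrtOverlap(false);
    executorConfig.setMaxBeamWidth(mBeamWidth);
    executorConfig.setSchedulerConfig(
        executor::SchedulerConfig{executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT});

    // KV-cache settings that previously lived on optionalParams.kvCacheConfig.
    executor::KvCacheConfig kvCacheConfig;
    kvCacheConfig.setMaxTokens(10000);
    executorConfig.setKvCacheConfig(kvCacheConfig);

    // The config is handed to the model constructor directly; the final false
    // replaces the flag formerly passed via TrtGptModelOptionalParams.
    auto trtGptModel = std::make_shared<TrtGptModelInflightBatching>(
        mLogger, mModelConfig, mWorldConfig, *mRawEngine, true, executorConfig, false);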