Merged
140 changes: 0 additions & 140 deletions cpp/include/tensorrt_llm/batch_manager/trtGptModelOptionalParams.h

This file was deleted.

@@ -25,7 +25,6 @@
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptDecoderBatched.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
11 changes: 6 additions & 5 deletions cpp/tensorrt_llm/batch_manager/trtEncoderModel.cpp
@@ -22,6 +22,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/tllmRuntime.h"
@@ -39,14 +40,14 @@ namespace tensorrt_llm::batch_manager

TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig,
runtime::RawEngine const& rawEngine, std::shared_ptr<nvinfer1::ILogger> logger,
-TrtGptModelOptionalParams const& optionalParams)
-: TrtGptModel(modelConfig, worldConfig, optionalParams)
+executor::ExecutorConfig const& executorConfig)
+: TrtGptModel(modelConfig, worldConfig, executorConfig)
, mModelConfig{modelConfig}
, mWorldConfig{worldConfig}
, mDevice{runtime::utils::initDevice(worldConfig)}
, mLogger{logger ? std::move(logger) : std::make_shared<TllmLogger>()}
, mRuntime{std::make_shared<TllmRuntime>(
-rawEngine, mLogger.get(), optionalParams.useGpuDirectStorage, optionalParams.gpuWeightsPercent)}
+rawEngine, mLogger.get(), executorConfig.getUseGpuDirectStorage(), executorConfig.getGpuWeightsPercent())}
, mNumMicroBatches{1}
, mNumBuffers{mNumMicroBatches}
, mCopyBufferManager{std::make_shared<CudaStream>()}
@@ -75,8 +76,8 @@ TrtEncoderModel::TrtEncoderModel(runtime::ModelConfig const& modelConfig, WorldC
// handling of maximizing utilization or pause/evict
// TODO: finer control on encoder requests scheduling
mCapacityScheduler = std::make_unique<tensorrt_llm::batch_manager::CapacityScheduler>(
-getMaxBatchSize() * mNumMicroBatches, optionalParams.schedulerConfig.getCapacitySchedulerPolicy(), false, false,
-LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
+getMaxBatchSize() * mNumMicroBatches, executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy(), false,
+false, LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);

mMicroBatchScheduler = std::make_unique<tensorrt_llm::batch_manager::MicroBatchScheduler>(
std::nullopt, mModelConfig.getMaxInputLen(), LlmRequestState::kENCODER_INIT, LlmRequestState::kCONTEXT_INIT);
3 changes: 1 addition & 2 deletions cpp/tensorrt_llm/batch_manager/trtEncoderModel.h
@@ -17,7 +17,6 @@

#pragma once

#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
#include "tensorrt_llm/runtime/rawEngine.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
#include "trtGptModel.h"
@@ -47,7 +46,7 @@ class TrtEncoderModel : public TrtGptModel

TrtEncoderModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
runtime::RawEngine const& rawEngine, std::shared_ptr<nvinfer1::ILogger> logger,
-TrtGptModelOptionalParams const& optionalParams);
+executor::ExecutorConfig const& executorConfig);

~TrtEncoderModel() override;

38 changes: 19 additions & 19 deletions cpp/tensorrt_llm/batch_manager/trtGptModel.h
@@ -18,7 +18,6 @@
#pragma once

#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stlUtils.h"
#include "tensorrt_llm/executor/executor.h"
@@ -52,23 +51,23 @@ class TrtGptModel : public executor::Model
using SizeType32 = tensorrt_llm::runtime::SizeType32;

TrtGptModel(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
-TrtGptModelOptionalParams const& optionalParams)
-: mMaxBatchSize{optionalParams.maxBatchSize.value_or(modelConfig.getMaxBatchSize())}
-, mMaxBeamWidth{optionalParams.maxBeamWidth.value_or(modelConfig.getMaxBeamWidth())}
+executor::ExecutorConfig const& executorConfig)
+: mMaxBatchSize{executorConfig.getMaxBatchSize().value_or(modelConfig.getMaxBatchSize())}
+, mMaxBeamWidth{executorConfig.getMaxBeamWidth()}
, mMaxSequenceLen{modelConfig.getMaxSequenceLen()}
, mMaxDraftLen{modelConfig.getMaxDecodingDraftTokens()}
, mVocabSizePadded{modelConfig.getVocabSizePadded(worldConfig.getSize())}
-, mNormalizeLogProbs{optionalParams.normalizeLogProbs}
-, mEnableTrtOverlap{optionalParams.enableTrtOverlap}
-, mCudaGraphMode{optionalParams.extendedRuntimePerfKnobConfig.getCudaGraphMode()}
+, mNormalizeLogProbs{executorConfig.getNormalizeLogProbs()}
+, mEnableTrtOverlap{executorConfig.getEnableTrtOverlap()}
+, mCudaGraphMode{executorConfig.getExtendedRuntimePerfKnobConfig().getCudaGraphMode()}
{
TLLM_CHECK_WITH_INFO(mMaxBeamWidth <= modelConfig.getMaxBeamWidth(),
"Runtime configured max beam width (%d) must not exceed engine max beam width (%d)", mMaxBeamWidth,
modelConfig.getMaxBeamWidth());
TLLM_CHECK_WITH_INFO(mMaxBatchSize <= modelConfig.getMaxBatchSize(),
"Runtime configured max batch size (%d) must not exceed engine max batch size (%d)", mMaxBatchSize,
modelConfig.getMaxBatchSize());
-if (optionalParams.enableTrtOverlap)
+if (executorConfig.getEnableTrtOverlap())
{
if (mMaxBeamWidth > 1)
{
@@ -85,10 +84,11 @@ class TrtGptModel : public executor::Model
}

mMaxAttentionWindow = 0;
-if (optionalParams.kvCacheConfig.maxAttentionWindowVec.has_value())
+if (executorConfig.getKvCacheConfig().getMaxAttentionWindowVec().has_value())
{
bool warning = false;
-for (int maxAttenWin : optionalParams.kvCacheConfig.maxAttentionWindowVec.value())
+auto const& maxAttentionWindowVec = executorConfig.getKvCacheConfig().getMaxAttentionWindowVec();
+for (int maxAttenWin : maxAttentionWindowVec.value())
{
mMaxAttentionWindowVec.push_back(std::min(maxAttenWin, mMaxSequenceLen));
mMaxAttentionWindow = std::max(mMaxAttentionWindow, mMaxAttentionWindowVec.back());
@@ -112,8 +112,8 @@ class TrtGptModel : public executor::Model
mMaxAttentionWindow = mMaxSequenceLen;
}

-mSinkTokenLen = optionalParams.kvCacheConfig.sinkTokenLength.has_value()
-? optionalParams.kvCacheConfig.sinkTokenLength.value()
+mSinkTokenLen = executorConfig.getKvCacheConfig().getSinkTokenLength().has_value()
+? executorConfig.getKvCacheConfig().getSinkTokenLength().value()
: 0;

mMaxNumSequences = mMaxBatchSize * worldConfig.getPipelineParallelism();
@@ -136,26 +136,26 @@
TLLM_LOG_INFO("TRTGptModel normalizeLogProbs: %d", mNormalizeLogProbs);

mMaxNumTokens = modelConfig.getMaxNumTokens();
-if (optionalParams.maxNumTokens && mMaxNumTokens)
+if (executorConfig.getMaxNumTokens().has_value() && mMaxNumTokens)
{
-if (optionalParams.maxNumTokens.value() > mMaxNumTokens.value())
+if (executorConfig.getMaxNumTokens().value() > mMaxNumTokens.value())
{
TLLM_LOG_WARNING(
"Runtime configured max num tokens (%d) is larger than model max num tokens (%d) and will be "
"ignored.",
-optionalParams.maxNumTokens.value(), mMaxNumTokens.value());
+executorConfig.getMaxNumTokens().value(), mMaxNumTokens.value());
}
else
{
-mMaxNumTokens = optionalParams.maxNumTokens;
+mMaxNumTokens = executorConfig.getMaxNumTokens();
}
}
if (mMaxNumTokens)
{
TLLM_LOG_INFO("TRTGptModel maxNumTokens: %d", mMaxNumTokens.value());
}

-if (optionalParams.enableChunkedContext)
+if (executorConfig.getEnableChunkedContext())
{
mMaxInputLen = mMaxSequenceLen - 1;
TLLM_LOG_INFO(
@@ -199,9 +199,9 @@ class TrtGptModel : public executor::Model
using tensorrt_llm::common::stl_utils::toString;

TLLM_LOG_INFO("Capacity Scheduler Policy: %s",
-toString(optionalParams.schedulerConfig.getCapacitySchedulerPolicy()).c_str());
+toString(executorConfig.getSchedulerConfig().getCapacitySchedulerPolicy()).c_str());
TLLM_LOG_INFO("Context Chunking Scheduler Policy: %s",
-toString(optionalParams.schedulerConfig.getContextChunkingPolicy()).c_str());
+toString(executorConfig.getSchedulerConfig().getContextChunkingPolicy()).c_str());
}

[[nodiscard]] std::optional<SizeType32> getMaxNumTokens() const