
Commit f1a4443

Funatiq authored and evezhier committed
[None] [refactor] Minor cleanup and improvements (NVIDIA#7619)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 2569453 commit f1a4443

9 files changed: +38 -49 lines changed


cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h

Lines changed: 4 additions & 13 deletions
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,19 +20,15 @@
 #include "tensorrt_llm/batch_manager/common.h"
 #include "tensorrt_llm/common/algorithm.h"
 #include "tensorrt_llm/common/optionalRef.h"
-#include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
 #include "tensorrt_llm/runtime/worldConfig.h"
 
 namespace tensorrt_llm::runtime
 {
-class DecodingInput;
-class DecodingOutput;
-class GptDecoderBatched;
 class SamplingConfig;
-class SpeculativeDecodingMode;
 
 namespace decoder
 {
@@ -56,10 +52,6 @@ class CreateNewDecoderRequests : Algorithm
     using CudaStream = tensorrt_llm::runtime::CudaStream;
     using TensorPtr = runtime::ITensor::SharedPtr;
     using SharedConstPtr = runtime::ITensor::SharedConstPtr;
-    using DecodingInput = runtime::DecodingInput;
-    using DecodingOutput = runtime::DecodingOutput;
-    using SpeculativeDecodingMode = runtime::SpeculativeDecodingMode;
-    using GptDecoderBatched = runtime::GptDecoderBatched;
     template <typename T>
     using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
 
@@ -70,16 +62,15 @@ class CreateNewDecoderRequests : Algorithm
     {
     }
 
-    std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
+    [[nodiscard]] std::tuple<TensorPtr, std::vector<SamplingConfig>, std::vector<SharedConstPtr>,
         std::vector<executor::LookaheadDecodingConfig>>
     operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
         executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
         nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers, runtime::decoder::DecoderState& decoderState,
         CudaStream const& runtimeStream, CudaStream const& decoderStream, SizeType32 maxSequenceLength,
         SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers) const;
 
-    [[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>,
-        std::vector<executor::LookaheadDecodingConfig>>
+    [[nodiscard]] std::tuple<std::vector<SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
     createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
         executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
         nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
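
Note on the [[nodiscard]] additions above: the attribute turns a silently dropped return value into a compiler warning. A minimal sketch (not from this PR; the names are illustrative):

    #include <tuple>
    #include <vector>

    [[nodiscard]] std::tuple<int, std::vector<float>> makeBatch()
    {
        return {42, {1.0F, 2.0F}};
    }

    void caller()
    {
        makeBatch();                              // warning: ignoring return value declared 'nodiscard'
        auto const [count, logits] = makeBatch(); // fine: the result is consumed
        static_cast<void>(count);
        static_cast<void>(logits);
    }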

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 10 additions & 9 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,6 +29,8 @@
 #include <cassert>
 #include <chrono>
 #include <cstdint>
+#include <cstring>
+#include <list>
 #include <memory>
 #include <optional>
 #include <utility>
@@ -56,9 +58,9 @@ enum class LlmRequestState : int32_t
     /// used in layer-wise transmission
     kDISAGG_GENERATION_TRANS_COMPLETE = 12, ///< Kv cache transmission are finished
     kGENERATION_IN_PROGRESS = 13,           ///< Generation phase is in progress
-    kGENERATION_TO_COMPLETE = 14,           ///< Generation phase is to be completed
 
     // schedulable states ends
+    kGENERATION_TO_COMPLETE = 14,           ///< Generation phase is to be completed
     kGENERATION_COMPLETE = 20,              ///< Generation phase completed
     kDISAGG_CONTEXT_TRANS_IN_PROGRESS = 21, ///< Waiting context-only request transmitting the kv cache,
     /// after computation finished
@@ -1075,7 +1077,6 @@ class GenericLlmRequest
         TLLM_CHECK_WITH_INFO(prepopulatedPromptLen < promptLen,
             "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
             promptLen, mRequestId);
-        TLLM_CHECK(prepopulatedPromptLen < promptLen);
 
         auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
         auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
@@ -1116,9 +1117,9 @@ class GenericLlmRequest
         mDraftLogits = draftLogits;
     }
 
-    [[nodiscard]] SizeType32 getNumDraftTokens() const
+    [[nodiscard]] SizeType32 getNumDraftTokens() const noexcept
     {
-        return hasDraftTokens() ? mDraftTokens->size() : 0;
+        return hasDraftTokens() ? static_cast<SizeType32>(mDraftTokens->size()) : 0;
     }
 
     void discardDraftTokens(SizeType32 numTokensToDiscard)
@@ -1379,17 +1380,17 @@ class GenericLlmRequest
         mGenerationLogitsFragments.push_back(genLogits);
     }
 
-    SizeType32 getGenerationLogitsFragmentsSize()
+    [[nodiscard]] SizeType32 getGenerationLogitsFragmentsSize() const noexcept
    {
-        return mGenerationLogitsFragments.size();
+        return static_cast<SizeType32>(mGenerationLogitsFragments.size());
    }
 
-    void clearGenerationLogitsFragments()
+    void clearGenerationLogitsFragments() noexcept
    {
        mGenerationLogitsFragments.clear();
    }
 
-    bool hasAdditionalOutputs()
+    [[nodiscard]] bool hasAdditionalOutputs() const noexcept
    {
        return !mAdditionalContextOutputTensors.empty() || !mAdditionalGenerationOutputTensors.empty();
    }
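
The getNumDraftTokens() change above adds an explicit narrowing conversion: container sizes are std::size_t, while the accessor returns a 32-bit SizeType32. A minimal sketch, assuming SizeType32 is a 32-bit signed integer as in the runtime headers:

    #include <cstdint>
    #include <vector>

    using SizeType32 = std::int32_t;

    // The static_cast documents the size_t -> SizeType32 narrowing and silences
    // conversion warnings; noexcept records that the query cannot throw.
    [[nodiscard]] SizeType32 numDraftTokens(std::vector<std::int32_t> const& draftTokens) noexcept
    {
        return static_cast<SizeType32>(draftTokens.size());
    }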

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 2 additions & 1 deletion
@@ -1478,7 +1478,8 @@ class CacheTransceiverConfig
 class ExecutorConfig
 {
 public:
-    static constexpr uint64_t kDefaultMaxSeqIdleMicroseconds = 180000000;
+    static constexpr uint64_t kDefaultMaxSeqIdleMicroseconds
+        = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::minutes(3)).count();
 
     static constexpr SizeType32 kDefaultIterStatsMaxIterations = 1000;
 
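
The new kDefaultMaxSeqIdleMicroseconds expression spells out the intent (3 minutes) instead of a magic number, and it still evaluates to the same value. A quick illustrative check:

    #include <chrono>
    #include <cstdint>

    static constexpr std::uint64_t kDefaultMaxSeqIdleMicroseconds
        = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::minutes(3)).count();

    // 3 min * 60 s/min * 1'000'000 us/s = 180'000'000 us, matching the old literal.
    static_assert(kDefaultMaxSeqIdleMicroseconds == 180000000);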

cpp/include/tensorrt_llm/runtime/lookaheadModule.h

Lines changed: 1 addition & 3 deletions
@@ -19,7 +19,6 @@
 #include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/speculativeDecodingModule.h"
-#include <memory>
 
 namespace tensorrt_llm::runtime
 {
@@ -29,7 +28,6 @@ class LookaheadModule : public SpeculativeDecodingModule
 public:
     explicit LookaheadModule(SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens) noexcept
         : SpeculativeDecodingModule(maxDraftPathLen, maxDecodingDraftTokens, maxDecodingDraftTokens)
-        , mExecutionConfig()
     {
     }
 
@@ -43,7 +41,7 @@ class LookaheadModule : public SpeculativeDecodingModule
         mExecutionConfig = config;
     }
 
-    executor::LookaheadDecodingConfig const getExecutionConfig() const
+    [[nodiscard]] executor::LookaheadDecodingConfig const& getExecutionConfig() const
     {
         return mExecutionConfig;
     }
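
The getExecutionConfig() change above returns the stored config by const reference rather than by value; the old signature's top-level const on the returned copy had no practical effect and forced a copy on every call. A minimal sketch with an illustrative Config type:

    struct Config
    {
        int windowSize{};
        int ngramSize{};
    };

    class Module
    {
    public:
        // Returns a reference to the member; no copy is made at the call site.
        [[nodiscard]] Config const& getConfig() const noexcept
        {
            return mConfig;
        }

    private:
        Config mConfig{};
    };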

cpp/include/tensorrt_llm/runtime/modelConfig.h

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 #include "tensorrt_llm/runtime/lookaheadModule.h"
 #include "tensorrt_llm/runtime/loraModule.h"
 #include "tensorrt_llm/runtime/speculativeDecodingMode.h"
+#include "tensorrt_llm/runtime/speculativeDecodingModule.h"
 
 #include <NvInferRuntime.h>
 #include <array>

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ using namespace tensorrt_llm::runtime;
 
 namespace tc = tensorrt_llm::common;
 namespace te = tensorrt_llm::executor;
-namespace tk = tensorrt_llm::kernels;
 namespace tr = tensorrt_llm::runtime;
 
 namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/runtime/bufferView.h

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ class BufferView : virtual public IBuffer
 
         if (offset + size > mBuffer->getSize())
         {
-            throw std::out_of_range(std::string("slice ") + std::to_string(offset + size) + " exceeds buffer size "
-                + std::to_string(mBuffer->getSize()));
+            throw std::out_of_range(std::string("offset ") + std::to_string(offset) + std::string(" + size ")
+                + std::to_string(size) + " exceeds buffer size " + std::to_string(mBuffer->getSize()));
         }
     }
 
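
With the bufferView.h change above, the out_of_range message reports both operands instead of only their sum, e.g. "offset 8 + size 8 exceeds buffer size 12" rather than "slice 16 exceeds buffer size 12". A standalone sketch of the check (illustrative function name):

    #include <cstddef>
    #include <stdexcept>
    #include <string>

    void checkSliceBounds(std::size_t offset, std::size_t size, std::size_t bufferSize)
    {
        if (offset + size > bufferSize)
        {
            throw std::out_of_range(std::string("offset ") + std::to_string(offset) + " + size "
                + std::to_string(size) + " exceeds buffer size " + std::to_string(bufferSize));
        }
    }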

cpp/tests/e2e_tests/batch_manager/trtGptModelRealDecoderTest.cpp

Lines changed: 17 additions & 19 deletions
@@ -284,8 +284,8 @@ void verifyOutput(RequestList const& finishedRequestList,
 }
 
 // Pick a different endId at random from one of the expected tokens
-std::vector<TokenIdType> pickRandomEndIds(TestData const& testData, TrtGptModelType const& modelType,
-    std::vector<SizeType32> const& givenInputLengths, SizeType32 const maxNewTokens, bool replaceLogits)
+std::vector<TokenIdType> pickRandomEndIds(TestData const& testData, std::vector<SizeType32> const& givenInputLengths,
+    SizeType32 const maxNewTokens, bool replaceLogits)
 {
     auto const nbGivenInputs = testData.nbGivenInputs;
     auto const beamWidth = testData.beamWidth;
@@ -328,9 +328,9 @@ std::vector<TokenIdType> pickRandomEndIds(TestData const& testData, TrtGptModelT
     return endIds;
 }
 
-TestData loadTestData(ModelSpec const& modelSpec, TrtGptModelType const& modelType, ModelIds const modelIds,
-    BeamResult const& beamResult, ITensor const& givenInput, SizeType32 const maxBeamWidth, bool const useRandomEndId,
-    bool const replaceLogits, BufferManager& manager)
+TestData loadTestData(ModelSpec const& modelSpec, ModelIds const modelIds, BeamResult const& beamResult,
+    ITensor const& givenInput, SizeType32 const maxBeamWidth, bool const useRandomEndId, bool const replaceLogits,
+    BufferManager& manager)
 {
     auto const [givenInputLengths, nbGivenInputs, maxInputLength] = getGivenInputLengths(givenInput, modelIds.padId);
     auto const& [beamWidth, resultsFile, contextLogitsFile, genLogitsFile, cumLogProbsFile, logProbsFile] = beamResult;
@@ -353,7 +353,7 @@ TestData loadTestData(ModelSpec const& modelSpec, TrtGptModelTy
 
     if (useRandomEndId)
     {
-        testData.endIds = pickRandomEndIds(testData, modelType, givenInputLengths, maxNewTokens, replaceLogits);
+        testData.endIds = pickRandomEndIds(testData, givenInputLengths, maxNewTokens, replaceLogits);
     }
     else
     {
@@ -409,9 +409,8 @@ TestData loadTestData(ModelSpec const& modelSpec, TrtGptModelTy
 }
 
 std::tuple<std::vector<SizeType32>, std::unordered_map<SizeType32, TestData>> loadTestData(ModelSpec const& modelSpec,
-    TrtGptModelType const& modelType, ModelIds const modelIds, BeamResults const& resultsFilesBeamWidths,
-    ITensor const& givenInput, SizeType32 const maxBeamWidth, bool const useRandomEndId, bool const replaceLogits,
-    BufferManager& manager)
+    ModelIds const modelIds, BeamResults const& resultsFilesBeamWidths, ITensor const& givenInput,
+    SizeType32 const maxBeamWidth, bool const useRandomEndId, bool const replaceLogits, BufferManager& manager)
 {
     // Map between beam width, and expected results for that beam width
     std::unordered_map<SizeType32, TestData> beamWidthTestData;
@@ -424,8 +423,8 @@ std::tuple<std::vector<SizeType32>, std::unordered_map<SizeType32, TestData>> lo
         EXPECT_EQ(std::find(beamWidths.begin(), beamWidths.end(), beamWidth), beamWidths.end());
         beamWidths.push_back(beamWidth);
 
-        auto testData = loadTestData(modelSpec, modelType, modelIds, beamResult, givenInput, maxBeamWidth,
-            useRandomEndId, replaceLogits, manager);
+        auto testData = loadTestData(
+            modelSpec, modelIds, beamResult, givenInput, maxBeamWidth, useRandomEndId, replaceLogits, manager);
         beamWidthTestData.emplace(beamWidth, std::move(testData));
     }
 
@@ -435,9 +434,8 @@ std::tuple<std::vector<SizeType32>, std::unordered_map<SizeType32, TestData>> lo
 RequestList runGptModelInference(std::shared_ptr<TrtGptModel>& trtGptModel, std::vector<SizeType32> const& beamWidths,
     std::unordered_map<SizeType32, TestData> const& beamWidthTestData, SizeType32 batchSize, SizeType32 nbGivenInputs,
     SizeType32 maxInputLength, SizeType32 padId, std::vector<SizeType32> const& givenInputLengths,
-    TokenIdType const* givenInputData, ModelSpec const& modelSpec, TrtGptModelIfbTestType testType,
-    TrtGptModelType modelType, int maxReqPerStep, bool prepopulateKVCache, bool enableStreamingMode,
-    bool enableBlockReuse)
+    TokenIdType const* givenInputData, ModelSpec const& modelSpec, TrtGptModelIfbTestType testType, int maxReqPerStep,
+    bool prepopulateKVCache, bool enableStreamingMode, bool enableBlockReuse)
 {
     // Fill the requests using givenInput
     // requestList will have batchSize requests
@@ -641,8 +639,8 @@ void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds
 
     auto const maxBeamWidth = executorConfig.getMaxBeamWidth();
     // Load expected outputs for each beam width value
-    auto [beamWidths, beamWidthTestData] = loadTestData(modelSpec, modelType, modelIds, resultsFilesBeamWidths,
-        *givenInput, maxBeamWidth, useRandomEndId, modelSpec.mReplaceLogits, manager);
+    auto [beamWidths, beamWidthTestData] = loadTestData(modelSpec, modelIds, resultsFilesBeamWidths, *givenInput,
+        maxBeamWidth, useRandomEndId, modelSpec.mReplaceLogits, manager);
 
     int const worldSize = modelSpec.mTPSize * modelSpec.mPPSize * modelSpec.mCPSize;
     auto const worldConfig = WorldConfig::mpi(worldSize, modelSpec.mTPSize, modelSpec.mPPSize, modelSpec.mCPSize);
@@ -663,14 +661,14 @@ void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds
     // Prepopulate KV cache for speculative decoding test
     bool const prepopulateKVCache = modelSpec.mMaxDraftTokens > 0;
     auto finishedRequestList = runGptModelInference(trtGptModel, beamWidths, beamWidthTestData, batchSize,
-        nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType, modelType,
-        maxReqPerStep, prepopulateKVCache, enableStreamingMode, modelSpec.mKVCacheReuse);
+        nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType, maxReqPerStep,
+        prepopulateKVCache, enableStreamingMode, modelSpec.mKVCacheReuse);
 
     if (prepopulateKVCache)
     {
         // Call the 2nd time with prefilled KV cache
         finishedRequestList = runGptModelInference(trtGptModel, beamWidths, beamWidthTestData, batchSize,
-            nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType, modelType,
+            nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType,
             maxReqPerStep, false, enableStreamingMode, modelSpec.mKVCacheReuse);
     }
 
cpp/tests/unit_tests/batch_manager/llmRequestTest.cpp

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ TEST_F(LlmRequestTest, fromExecutorRequest)
     EXPECT_EQ(llmReq.getState(), tb::LlmRequestState::kCONTEXT_INIT);
     EXPECT_FALSE(llmReq.mSeqSlot);
     // No speculative decoding config, draft tokens should be empty
-    EXPECT_EQ(llmReq.getDraftTokens()->size(), 0);
+    EXPECT_EQ(llmReq.getNumDraftTokens(), 0);
     EXPECT_FALSE(llmReq.getEmbeddingBias().has_value());
     EXPECT_FALSE(llmReq.getBadWordsList().has_value());
     EXPECT_FALSE(llmReq.getStopWordsList().has_value());
