Skip to content

Commit 3830a1e

Browse files
committed
refactor: Add setBeamWidth method to DecoderState
- Introduced the setBeamWidth method in DecoderState to allow setting the beam width for a specific request in a batch.
- Updated CreateNewDecoderRequests to use the new setBeamWidth method, improving code clarity and maintainability.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 31019da commit 3830a1e

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

cpp/include/tensorrt_llm/runtime/decoderState.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ class DecoderState
173173
//! @brief Workspace for beam search in streaming mode.
174174
[[nodiscard]] BeamSearchBuffers const& getBeamSearchBuffers() const;
175175

176+
//! @brief Set the beam width for a specific request in the batch.
177+
//! @param batchIdx The index of the request in the batch.
178+
//! @param beamWidth The beam width for the specified request.
179+
void setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth);
180+
176181
//! @brief Cache indirection input for beam search.
177182
[[nodiscard]] TensorPtr getCacheIndirectionInput() const;
178183

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,8 +581,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
581581
{
582582
llmReq->mSamplingConfig.normalizeLogProbs = mIsNormalizeLogProbs;
583583

584-
auto& dJointInput = decoderState.getJointDecodingInput();
585-
586584
TLLM_CHECK(llmReq->mSeqSlot.has_value());
587585
auto const batchSlot = llmReq->mSeqSlot.value();
588586
auto const batchSize = decoderState.getMaxBatchSize();
@@ -595,7 +593,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
595593
TLLM_CHECK_WITH_INFO(beamWidth <= maxBeamWidth,
596594
tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.",
597595
beamWidth, maxBeamWidth));
598-
dJointInput.beamWidths.at(batchSlot) = beamWidth;
596+
decoderState.setBeamWidth(batchSlot, beamWidth);
599597

600598
auto const promptLen = llmReq->getPromptLen();
601599

@@ -626,6 +624,8 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
626624
decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
627625
}
628626

627+
auto& dJointInput = decoderState.getJointDecodingInput();
628+
629629
auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep;
630630
initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens,
631631
maxSequenceLength, decoderBufferManager);

cpp/tensorrt_llm/runtime/decoderState.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,11 @@ void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationS
642642
mJointDecodingInput->generationSteps = generationSteps;
643643
}
644644

645+
// Records the beam width of a single request in the joint decoding input.
// batchIdx selects which entry of beamWidths is overwritten; beamWidth is the value stored there.
void DecoderState::setBeamWidth(SizeType32 batchIdx, SizeType32 beamWidth)
646+
{
647+
// NOTE(review): .at() is presumably a bounds-checked accessor (std::vector-like container) —
// confirm the type of beamWidths. No validation of beamWidth itself happens here.
mJointDecodingInput->beamWidths.at(batchIdx) = beamWidth;
648+
}
649+
645650
DecodingInput& DecoderState::getJointDecodingInput() const
646651
{
647652
return *mJointDecodingInput;

0 commit comments

Comments
 (0)