@@ -828,8 +828,10 @@ class GenericLlmRequest
         // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase
         mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                      : LlmRequestState::kCONTEXT_INIT;
-        mContextCurrentPosition = 0;
-        mPrepopulatedPromptLen = 0;
+        mContextCurrentPositionTarget = 0;
+        mContextCurrentPositionDraft = 0;
+        mPrepopulatedPromptLenTarget = 0;
+        mPrepopulatedPromptLenDraft = 0;
         mContextChunkSize = mPromptLen;
         mSeqSlot.reset();
     }
@@ -1049,7 +1051,7 @@ class GenericLlmRequest
 
     [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
     {
-        return mPrepopulatedPromptLen;
+        return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
     }
 
     void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
@@ -1066,7 +1068,10 @@ class GenericLlmRequest
             "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen,
             promptLen, mRequestId);
         TLLM_CHECK(prepopulatedPromptLen < promptLen);
-        mPrepopulatedPromptLen = prepopulatedPromptLen;
+
+        auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget;
+        auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
+        prePromptLen = prepopulatedPromptLen;
 
         if (prepopulatedPromptLen > 0)
         {
@@ -1081,7 +1086,7 @@ class GenericLlmRequest
                 chunkSize = flooredEndPosition - prepopulatedPromptLen;
                 TLLM_CHECK(chunkSize <= getContextChunkSize());
             }
-            setContextCurrentPosition(prepopulatedPromptLen);
+            contextCurrentPosition = prepopulatedPromptLen;
             setContextChunkSize(chunkSize);
 
             if (!isLastContextChunk())
@@ -1522,14 +1527,15 @@ class GenericLlmRequest
 
     void setContextCurrentPosition(SizeType32 contextCurrentPosition)
     {
-        mContextCurrentPosition = contextCurrentPosition;
+        mContextCurrentPositionDraft = contextCurrentPosition;
+        mContextCurrentPositionTarget = contextCurrentPosition;
     }
 
     /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning
     /// or end of the context is returned.
     [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept
     {
-        return mContextCurrentPosition;
+        return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
     }
 
     /// Return the length of the context that has not yet been processed.
@@ -1570,14 +1576,16 @@ class GenericLlmRequest
     {
         // The number of cached tokens is counted in mContextCurrentPosition,
         // so the start position of the context is mPrepopulatedPromptLen.
-        return mContextCurrentPosition == mPrepopulatedPromptLen;
+        return getContextCurrentPosition() == getPrepopulatedPromptLen();
     }
 
     /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
     void moveToNextContextChunk()
     {
         TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
-        mContextCurrentPosition += getContextChunkSize();
+
+        mContextCurrentPositionDraft += getContextChunkSize();
+        mContextCurrentPositionTarget += getContextChunkSize();
         setContextChunkSize(0);
     }
 
@@ -1843,6 +1851,16 @@ class GenericLlmRequest
         return mIsDummyRequest;
     }
 
+    void setUseDraftModel(bool useDraftModel)
+    {
+        mUseDraftModel = useDraftModel;
+    }
+
+    [[nodiscard]] bool useDraftModel() const
+    {
+        return mUseDraftModel;
+    }
+
     RequestIdType mRequestId;
     SizeType32 mPromptLen;
     SizeType32 mMaxNewTokens;
@@ -1885,7 +1903,8 @@ class GenericLlmRequest
     // Number of tokens already in KV cache before context phase.
     // A value > 0 indicates cached KV cache blocks were reused.
     // Up to inputLen - 1 tokens can be reused.
-    SizeType32 mPrepopulatedPromptLen{0};
+    SizeType32 mPrepopulatedPromptLenTarget{0};
+    SizeType32 mPrepopulatedPromptLenDraft{0};
 
     SizeType32 mMaxSentTokenLen;
 
@@ -1916,7 +1935,8 @@ class GenericLlmRequest
     // The size of the context chunk must be a multiple of the KV-Cache block size except the last one.
     // Value `0` means Chunked-Context is disabled.
     SizeType32 mContextChunkSize{0};
-    SizeType32 mContextCurrentPosition{0};
+    SizeType32 mContextCurrentPositionTarget{0};
+    SizeType32 mContextCurrentPositionDraft{0};
 
     std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
     VecLogProbs mCumLogProbs;           // [beamSize]
@@ -2017,6 +2037,8 @@ class GenericLlmRequest
 
     bool mIsDummyRequest{false};
 
+    bool mUseDraftModel{false};
+
 private:
     void initialize(VecTokens const& inputTokens, bool outputLogProbs)
     {
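
Taken together, the hunks above replace the single mPrepopulatedPromptLen / mContextCurrentPosition counters with per-model (target vs. draft) counters, add an mUseDraftModel flag that selects which counter the getters report, and keep the setter and chunk advancement writing to both. The following is a minimal, self-contained sketch of that pattern for illustration only: the RequestPositionSketch struct, its explicit chunkSize parameter, and the main() driver are invented here and are not part of GenericLlmRequest; only the member and accessor names mirror the diff.

#include <cassert>
#include <cstdint>

// Simplified illustration (not the real GenericLlmRequest): one request keeps
// separate context cursors for the target model and the draft model.
struct RequestPositionSketch
{
    using SizeType32 = std::int32_t;

    // Selects which cursor the read accessors report.
    void setUseDraftModel(bool useDraftModel) { mUseDraftModel = useDraftModel; }

    // Writes reset both cursors, mirroring setContextCurrentPosition() in the diff.
    void setContextCurrentPosition(SizeType32 pos)
    {
        mContextCurrentPositionDraft = pos;
        mContextCurrentPositionTarget = pos;
    }

    // Reads return the cursor of the currently selected model.
    SizeType32 getContextCurrentPosition() const
    {
        return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget;
    }

    // Advancing a chunk moves both cursors, as moveToNextContextChunk() does
    // (the real method reads the chunk size from the request instead of a parameter).
    void moveToNextContextChunk(SizeType32 chunkSize)
    {
        mContextCurrentPositionDraft += chunkSize;
        mContextCurrentPositionTarget += chunkSize;
    }

    bool mUseDraftModel{false};
    SizeType32 mContextCurrentPositionTarget{0};
    SizeType32 mContextCurrentPositionDraft{0};
};

int main()
{
    RequestPositionSketch req;
    req.setContextCurrentPosition(0);
    req.moveToNextContextChunk(64); // both cursors advance together

    req.setUseDraftModel(true);     // reads now report the draft-model cursor
    assert(req.getContextCurrentPosition() == 64);
    req.setUseDraftModel(false);    // and the target-model cursor again
    assert(req.getContextCurrentPosition() == 64);
    return 0;
}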