@@ -34,9 +34,9 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
3434    OptFloat const & topPMin, std::optional<TokenIdType> const & topPResetIds, OptFloat const & topPDecay,
3535    std::optional<RandomSeedType> const & seed, OptFloat const & temperature, OptSize32 const & minTokens,
3636    OptFloat const & beamSearchDiversityRate, OptFloat const & repetitionPenalty, OptFloat const & presencePenalty,
37-     OptFloat const & frequencyPenalty, OptFloat  const & lengthPenalty, OptSize32  const & earlyStopping ,
38-     OptSize32 const & noRepeatNgramSize , OptSize32 const & numReturnSequences, OptFloat  const & minP ,
39-     OptVec<SizeType32> const & beamWidthArray)
37+     OptFloat const & frequencyPenalty, OptSize32  const & promptIgnoreLength, OptFloat  const & lengthPenalty ,
38+     OptSize32 const & earlyStopping , OptSize32 const & noRepeatNgramSize, OptSize32  const & numReturnSequences ,
39+     OptFloat  const & minP,  OptVec<SizeType32> const & beamWidthArray)
4040    : mBeamWidth (checkBeamWidth(beamWidth))
4141    , mTopK (checkTopK(topK))
4242    , mTopP (checkTopP(topP))
@@ -50,6 +50,7 @@ SamplingConfig::SamplingConfig(SizeType32 beamWidth, OptSize32 const& topK, OptF
5050    , mRepetitionPenalty (checkRepetitionPenalty(repetitionPenalty))
5151    , mPresencePenalty (presencePenalty)
5252    , mFrequencyPenalty (frequencyPenalty)
53+     , mPromptIgnoreLength (checkPromptIgnoreLength(promptIgnoreLength))
5354    , mLengthPenalty (checkLengthPenalty(lengthPenalty))
5455    , mEarlyStopping (checkEarlyStopping(earlyStopping))
5556    , mNoRepeatNgramSize (checkNoRepeatNgramSize(noRepeatNgramSize))
@@ -67,9 +68,10 @@ bool SamplingConfig::operator==(SamplingConfig const& other) const
6768        && mTemperature  == other.mTemperature  && mMinTokens  == other.mMinTokens 
6869        && mBeamSearchDiversityRate  == other.mBeamSearchDiversityRate  && mRepetitionPenalty  == other.mRepetitionPenalty 
6970        && mPresencePenalty  == other.mPresencePenalty  && mFrequencyPenalty  == other.mFrequencyPenalty 
70-         && mLengthPenalty  == other.mLengthPenalty  && mEarlyStopping  == other.mEarlyStopping 
71-         && mNoRepeatNgramSize  == other.mNoRepeatNgramSize  && mNumReturnSequences  == other.mNumReturnSequences 
72-         && mMinP  == other.mMinP  && mBeamWidthArray  == other.mBeamWidthArray ;
71+         && mPromptIgnoreLength  == other.mPromptIgnoreLength  && mLengthPenalty  == other.mLengthPenalty 
72+         && mEarlyStopping  == other.mEarlyStopping  && mNoRepeatNgramSize  == other.mNoRepeatNgramSize 
73+         && mNumReturnSequences  == other.mNumReturnSequences  && mMinP  == other.mMinP 
74+         && mBeamWidthArray  == other.mBeamWidthArray ;
7375}
7476
7577//  Getters
@@ -143,6 +145,11 @@ OptFloat SamplingConfig::getFrequencyPenalty() const
143145    return  mFrequencyPenalty ;
144146}
145147
148+ OptSize32 SamplingConfig::getPromptIgnoreLength () const 
149+ {
150+     return  mPromptIgnoreLength ;
151+ }
152+ 
146153OptFloat SamplingConfig::getLengthPenalty () const 
147154{
148155    return  mLengthPenalty ;
@@ -240,6 +247,11 @@ void SamplingConfig::setFrequencyPenalty(OptFloat const& frequencyPenalty)
240247    mFrequencyPenalty  = frequencyPenalty;
241248}
242249
250+ void  SamplingConfig::setPromptIgnoreLength (OptSize32 const & promptIgnoreLength)
251+ {
252+     mPromptIgnoreLength  = checkPromptIgnoreLength (promptIgnoreLength);
253+ }
254+ 
243255void  SamplingConfig::setLengthPenalty (OptFloat const & lengthPenalty)
244256{
245257    mLengthPenalty  = lengthPenalty; //  TODO: re-enable `checkLengthPenalty` later
@@ -362,6 +374,15 @@ OptFloat const& SamplingConfig::checkRepetitionPenalty(OptFloat const& repetitio
362374    return  repetitionpenalty;
363375}
364376
377+ OptSize32 const & SamplingConfig::checkPromptIgnoreLength (OptSize32 const & promptIgnoreLength)
378+ {
379+     if  (promptIgnoreLength.has_value ())
380+     {
381+         TLLM_CHECK (promptIgnoreLength.value () >= 0 );
382+     }
383+     return  promptIgnoreLength;
384+ }
385+ 
365386OptFloat const & SamplingConfig::checkLengthPenalty (OptFloat const & lengthPenalty)
366387{
367388    if  (lengthPenalty.has_value ())
0 commit comments