@@ -284,8 +284,8 @@ void verifyOutput(RequestList const& finishedRequestList,
 }
 
 // Pick a different endId at random from one of the expected tokens
-std::vector<TokenIdType> pickRandomEndIds(TestData const& testData, TrtGptModelType const& modelType,
-    std::vector<SizeType32> const& givenInputLengths, SizeType32 const maxNewTokens, bool replaceLogits)
+std::vector<TokenIdType> pickRandomEndIds(TestData const& testData, std::vector<SizeType32> const& givenInputLengths,
+    SizeType32 const maxNewTokens, bool replaceLogits)
 {
     auto const nbGivenInputs = testData.nbGivenInputs;
     auto const beamWidth = testData.beamWidth;
@@ -328,9 +328,9 @@ std::vector<TokenIdType> pickRandomEndIds(TestData const& testData, TrtGptModelT
     return endIds;
 }
 
-TestData loadTestData(ModelSpec const& modelSpec, TrtGptModelType const& modelType, ModelIds const modelIds,
-    BeamResult const& beamResult, ITensor const& givenInput, SizeType32 const maxBeamWidth, bool const useRandomEndId,
-    bool const replaceLogits, BufferManager& manager)
+TestData loadTestData(ModelSpec const& modelSpec, ModelIds const modelIds, BeamResult const& beamResult,
+    ITensor const& givenInput, SizeType32 const maxBeamWidth, bool const useRandomEndId, bool const replaceLogits,
+    BufferManager& manager)
 {
     auto const [givenInputLengths, nbGivenInputs, maxInputLength] = getGivenInputLengths(givenInput, modelIds.padId);
     auto const& [beamWidth, resultsFile, contextLogitsFile, genLogitsFile, cumLogProbsFile, logProbsFile] = beamResult;
@@ -353,7 +353,7 @@ TestData loadTestData(ModelSpec const& modelSpec, TrtGptModelType const& modelTy
 
     if (useRandomEndId)
     {
-        testData.endIds = pickRandomEndIds(testData, modelType, givenInputLengths, maxNewTokens, replaceLogits);
+        testData.endIds = pickRandomEndIds(testData, givenInputLengths, maxNewTokens, replaceLogits);
     }
     else
     {
@@ -409,9 +409,8 @@ TestData loadTestData(ModelSpec const& modelSpec, TrtGptModelType const& modelTy
 }
 
 std::tuple<std::vector<SizeType32>, std::unordered_map<SizeType32, TestData>> loadTestData(ModelSpec const& modelSpec,
-    TrtGptModelType const& modelType, ModelIds const modelIds, BeamResults const& resultsFilesBeamWidths,
-    ITensor const& givenInput, SizeType32 const maxBeamWidth, bool const useRandomEndId, bool const replaceLogits,
-    BufferManager& manager)
+    ModelIds const modelIds, BeamResults const& resultsFilesBeamWidths, ITensor const& givenInput,
+    SizeType32 const maxBeamWidth, bool const useRandomEndId, bool const replaceLogits, BufferManager& manager)
 {
     // Map between beam width, and expected results for that beam width
     std::unordered_map<SizeType32, TestData> beamWidthTestData;
@@ -424,8 +423,8 @@ std::tuple<std::vector<SizeType32>, std::unordered_map<SizeType32, TestData>> lo
         EXPECT_EQ(std::find(beamWidths.begin(), beamWidths.end(), beamWidth), beamWidths.end());
         beamWidths.push_back(beamWidth);
 
-        auto testData = loadTestData(modelSpec, modelType, modelIds, beamResult, givenInput, maxBeamWidth,
-            useRandomEndId, replaceLogits, manager);
+        auto testData = loadTestData(
+            modelSpec, modelIds, beamResult, givenInput, maxBeamWidth, useRandomEndId, replaceLogits, manager);
         beamWidthTestData.emplace(beamWidth, std::move(testData));
     }
 
@@ -435,9 +434,8 @@ std::tuple<std::vector<SizeType32>, std::unordered_map<SizeType32, TestData>> lo
 RequestList runGptModelInference(std::shared_ptr<TrtGptModel>& trtGptModel, std::vector<SizeType32> const& beamWidths,
     std::unordered_map<SizeType32, TestData> const& beamWidthTestData, SizeType32 batchSize, SizeType32 nbGivenInputs,
     SizeType32 maxInputLength, SizeType32 padId, std::vector<SizeType32> const& givenInputLengths,
-    TokenIdType const* givenInputData, ModelSpec const& modelSpec, TrtGptModelIfbTestType testType,
-    TrtGptModelType modelType, int maxReqPerStep, bool prepopulateKVCache, bool enableStreamingMode,
-    bool enableBlockReuse)
+    TokenIdType const* givenInputData, ModelSpec const& modelSpec, TrtGptModelIfbTestType testType, int maxReqPerStep,
+    bool prepopulateKVCache, bool enableStreamingMode, bool enableBlockReuse)
 {
     // Fill the requests using givenInput
     // requestList will have batchSize requests
@@ -641,8 +639,8 @@ void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds
 
     auto const maxBeamWidth = executorConfig.getMaxBeamWidth();
     // Load expected outputs for each beam width value
-    auto [beamWidths, beamWidthTestData] = loadTestData(modelSpec, modelType, modelIds, resultsFilesBeamWidths,
-        *givenInput, maxBeamWidth, useRandomEndId, modelSpec.mReplaceLogits, manager);
+    auto [beamWidths, beamWidthTestData] = loadTestData(modelSpec, modelIds, resultsFilesBeamWidths, *givenInput,
+        maxBeamWidth, useRandomEndId, modelSpec.mReplaceLogits, manager);
 
     int const worldSize = modelSpec.mTPSize * modelSpec.mPPSize * modelSpec.mCPSize;
     auto const worldConfig = WorldConfig::mpi(worldSize, modelSpec.mTPSize, modelSpec.mPPSize, modelSpec.mCPSize);
@@ -663,14 +661,14 @@ void runIfbTest(fs::path const& modelPath, ModelSpec const& modelSpec, ModelIds
     // Prepopulate KV cache for speculative decoding test
     bool const prepopulateKVCache = modelSpec.mMaxDraftTokens > 0;
     auto finishedRequestList = runGptModelInference(trtGptModel, beamWidths, beamWidthTestData, batchSize,
-        nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType, modelType,
-        maxReqPerStep, prepopulateKVCache, enableStreamingMode, modelSpec.mKVCacheReuse);
+        nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType, maxReqPerStep,
+        prepopulateKVCache, enableStreamingMode, modelSpec.mKVCacheReuse);
 
     if (prepopulateKVCache)
     {
         // Call the 2nd time with prefilled KV cache
         finishedRequestList = runGptModelInference(trtGptModel, beamWidths, beamWidthTestData, batchSize,
-            nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType, modelType,
+            nbGivenInputs, maxInputLength, padId, givenInputLengths, givenInputData, modelSpec, testType,
             maxReqPerStep, false, enableStreamingMode, modelSpec.mKVCacheReuse);
     }
 