diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index d0daf9e4350..a0234cbbe49 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -31,6 +31,7 @@ #include "tensorrt_llm/runtime/worldConfig.h" #include +#include #include #include #include @@ -68,6 +69,9 @@ using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType; using BlocksPerWindow = std::map>; +// Type alias for multimodal hash key (hash array + start offset) +using MmKey = std::pair, SizeType32>; + template using OptionalRef = tensorrt_llm::common::OptionalRef; @@ -107,6 +111,10 @@ struct BlockKey std::optional loraTaskId = std::nullopt; VecUniqueTokens uniqueTokens; + // Extra keys for multimodal data (similar to VLLM's approach) + // Each extra key is a pair of (mm_hash, start_offset_in_block) + std::vector extraKeys; + BlockKey() = default; explicit BlockKey(VecTokens const& tokens, std::optional loraTaskId = std::nullopt) @@ -119,23 +127,25 @@ struct BlockKey } } - BlockKey(bool usesExtraIds, std::optional loraTaskId, VecUniqueTokens uniqueTokens) - : usesExtraIds(usesExtraIds) + explicit BlockKey(bool usesExtraIds, std::optional loraTaskId, VecUniqueTokens uniqueTokens, + std::vector extraKeys = {}) + : usesExtraIds{usesExtraIds} , loraTaskId{loraTaskId} , uniqueTokens{std::move(uniqueTokens)} + , extraKeys{std::move(extraKeys)} { } bool operator==(BlockKey const& other) const noexcept { - return ( - usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens); + return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId + && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys); } int partialMatch(BlockKey const& other) const noexcept { SizeType32 numMatched{0}; - if (loraTaskId == other.loraTaskId) + if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys) { auto [matchEnd, otherMatchEnd] = std::mismatch( uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end()); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index ba3b2a94ede..d30ba27be3a 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -76,14 +76,82 @@ std::list> chopVectorIntoBlocks( return blockedVectors; } +inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept +{ + return static_cast((hashPart >> (24 - byteIdx * 8)) & 0xFF); +} + +std::vector generateBlockHashExtraKeys( + tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx) +{ + auto const multimodalHashes = llmRequest.getMultimodalHashes(); + auto const multimodalPositions = llmRequest.getMultimodalPositions(); + auto const multimodalLengths = llmRequest.getMultimodalLengths(); + + if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes) + || (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty() + || !(*multimodalLengths) || (*multimodalLengths)->empty()) + { + return {}; + } + + if ((*multimodalHashes)->size() != (*multimodalPositions)->size() + || (*multimodalPositions)->size() != (*multimodalLengths)->size()) + { + TLLM_LOG_WARNING("Multimodal data arrays have mismatched sizes"); + 
return {}; + } + + std::vector extraKeys; // MmKey = std::pair, SizeType32> + extraKeys.reserve((*multimodalPositions)->size()); + std::array mmHashArray; + + for (size_t i = 0; i < (*multimodalPositions)->size(); ++i) + { + auto const& startPos = (*(*multimodalPositions))[i]; + auto const& length = (*(*multimodalLengths))[i]; + auto const& mmHashVector = (*(*multimodalHashes))[i]; + + TLLM_CHECK_WITH_INFO(mmHashVector.size() == 8, "Multimodal hash vector has unexpected size: %zu (expected 8)", + mmHashVector.size()); + + // mmHashVector[j] comes from Python's int(hex_chunk, 16) + // where hex_chunk like "00010203" means 0x00 is MSB and 0x03 is LSB (big endian) + // Convert 8x 32-bit integers into a 32-byte array preserving Blake3 hash byte order + // Example: hashPart = 0x00010203 → mmHashArray[0:3] = [0x00, 0x01, 0x02, 0x03] + for (size_t j = 0; j < 8; ++j) + { + auto const& hashPart = mmHashVector[j]; + for (uint8_t byteIdx = 0; byteIdx < 4; ++byteIdx) + { + mmHashArray[j * 4 + byteIdx] = getNthByte(hashPart, byteIdx); + } + } + + // Check if this multimodal content overlaps with the current block + if (endTokenIdx > startPos && startTokenIdx < startPos + length) + { + SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos; + extraKeys.emplace_back(mmHashArray, mmStartInBlock); + } + } + + return extraKeys; +} + std::vector buildBlockKeys( std::list& blockedUniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest) { std::vector blockKeys; + + SizeType32 currentTokenIdx = 0; for (auto& uniqueTokens : blockedUniqueTokens) { - blockKeys.emplace_back( - llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), std::move(uniqueTokens)); + auto extraKeys = generateBlockHashExtraKeys(llmRequest, currentTokenIdx, currentTokenIdx + uniqueTokens.size()); + currentTokenIdx += uniqueTokens.size(); + + blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), + std::move(uniqueTokens), std::move(extraKeys)); } return blockKeys; } @@ -92,9 +160,11 @@ std::vector buildBlockKeys( namespace tensorrt_llm::batch_manager::kv_cache_manager { - size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept { + // Hashing algorithm adapted from StackOverflow: + // https://stackoverflow.com/questions/664014/what-integer-hash-function-are-good-that-accepts-an-integer-hash-key + // Constants provide very good distribution - each input bit affects each output bit with ~50% probability. 
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9); for (auto const& uniqueToken : blockKey.uniqueTokens) @@ -122,7 +192,36 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no c = c ^ (c >> 31); seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2); } - // TODO: support external hashes for multimodal + + // Add extra keys for multimodal data mixing in external multimodal item hash and token offset within this sequence + // block + if (!blockKey.extraKeys.empty()) + { + for (auto const& [mmHash, startOffset] : blockKey.extraKeys) + { + // Hash the multimodal hash array in 32-bit chunks (more efficient) + for (size_t i = 0; i < 32; i += 4) + { + // Combine 4 bytes into a 32-bit word (construct as little endian order) + uint32_t word = static_cast(mmHash[i]) | (static_cast(mmHash[i + 1]) << 8) + | (static_cast(mmHash[i + 2]) << 16) | (static_cast(mmHash[i + 3]) << 24); + + // Mix the word into the seed + word = ((word >> 16) ^ word) * 0x45d9f3b; + word = ((word >> 16) ^ word) * 0x45d9f3b; + word = (word >> 16) ^ word; + seed ^= word + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + // Hash the start offset + uint64_t e = static_cast(startOffset); + e = (e ^ (e >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); + e = (e ^ (e >> 27)) * UINT64_C(0x94d049bb133111eb); + e = e ^ (e >> 31); + seed ^= e + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + } + return seed; } diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 08ab45145d5..ba10a17b26d 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -1034,6 +1034,182 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } +TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest) +{ + using VecTokenExtraIds = LlmRequest::VecTokenExtraIds; + + auto constexpr numLayers = 12; + auto constexpr numKvHeads = 6; + auto constexpr sizePerHead = 16; + auto constexpr tokensPerBlock = 4; + auto constexpr maxBlocksPerSeq = 4; + auto constexpr blocksInPrimaryPool = 16; + auto constexpr blocksInSecondaryPool = 0; + auto constexpr maxNumSequences = 8; + auto const stream = std::make_shared(); + auto constexpr onboardBlocks = true; + auto constexpr numReturnSequences = 1; + auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto constexpr beamWidth = 1; + + auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; + + BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, + maxNumSequences, stream, maxAttentionWindow, beamWidth, + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, + onboardBlocks); + blockManager.allocatePools(false); + + EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); + EXPECT_EQ(blockManager.getMaxNumBlocks(), blocksInPrimaryPool); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + SizeType32 constexpr maxNewTokens{0}; + tr::SamplingConfig const samplingConfig{beamWidth}; + bool constexpr isStreaming{false}; + + // Create multimodal hash data (256-bit hash = 8 int32 values) + auto multimodalHashes = std::make_shared>>(std::vector>{ + {0x12345678, -0x6F543211, 0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555, 0x66666666} // Hash 1 + }); + auto 
multimodalPositions + = std::make_shared>(std::vector{2}); // Start at token 2 + auto multimodalLengths = std::make_shared>(std::vector{4}); // Length 4 tokens + // assume prompt id starts from 100 + auto inputTokens = std::make_shared(VecTokens{100, 101, 102, 103, 104, 105, 0, 1, 2}); + auto const inputLength = static_cast(inputTokens->size()); + LlmRequest::RequestIdType requestId{0}; + auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + + GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + /////////////////////////////////////////////////////////////////////////// + // add request and then remove it + auto constexpr beamIdx = 0; + auto promptLen0 = llmRequest0->getNumTokens(beamIdx); + auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); + EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); + llmRequest0->addNewToken(3, beamIdx); + llmRequest0->addNewToken(4, beamIdx); + auto numTokens = llmRequest0->getNumTokens(beamIdx); + auto numBlocks = tc::ceilDiv(numTokens, tokensPerBlock); + EXPECT_EQ(numBlocks, 3); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + + // Input: [100, 101, 102, 103, 104, 105, 0, 1, 2] (9 tokens) + // Multimodal: starts at token 2, length 4 → [102, 103, 104, 105] + + // Block 0: [100, 101, 102, 103] ← Contains multimodal (102, 103) + // Block 1: [104, 105, 0, 1] ← Contains multimodal (104, 105) + // Block 2: [2, 3, 4] ← No multimodal + blockManager.releaseBlocks(seq0, llmRequest0); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + /////////////////////////////////////////////////////////////////////////// + // new request with same tokens and same multimodal hash - should reuse + requestId = 1; + auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + // should reuse blocks 0, 1 and get new block 3 + auto promptLen1 = 
llmRequest1->getNumTokens(beamIdx); + auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock); + EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3})); + llmRequest1->addNewToken(3, beamIdx); + llmRequest1->addNewToken(4, beamIdx); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + // block 3 matches block 2 and will be freed + blockManager.releaseBlocks(seq1, llmRequest1); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 2: Different multimodal hash + requestId = 2; + auto multimodalHashes2 + = std::make_shared>>(std::vector>{ + {0x45678123, 0x23456789, 0x34567890, 0x12121212, 0x56565656, 0x78787878, 0x54545454, 0x67676767} // Hash 2 + }); + auto multimodalPositions2 + = std::make_shared>(std::vector{2}); // Start at token 2 + auto multimodalLengths2 = std::make_shared>(std::vector{4}); // Length 4 tokens + auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + multimodalHashes2, multimodalPositions2, multimodalLengths2, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + + GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + // no reuse, get new blocks 4, 5, 6 + auto promptLen2 = llmRequest2->getNumTokens(beamIdx); + auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0); + EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6})); + llmRequest2->addNewToken(9, beamIdx); + numTokens = llmRequest2->getNumTokens(beamIdx); + numBlocks = tc::ceilDiv(numTokens, tokensPerBlock); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 3: Multiple multimodal hashes and partial reuse + requestId = 3; + auto multimodalHashes3 + = std::make_shared>>(std::vector>{ + {0x12345678, -0x6F543211, 0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555, 0x66666666}, // Hash 1 + {0x45678123, 0x23456789, 0x34567890, 0x12121212, 0x56565656, 0x78787878, 0x54545454, 0x67676767} // Hash 2 + }); + auto multimodalPositions3 + = std::make_shared>(std::vector{2, 4}); // Start at token 2 and 4 + auto multimodalLengths3 + = std::make_shared>(std::vector{2, 2}); // Length 2 tokens + + auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + 
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + multimodalHashes3, multimodalPositions3, multimodalLengths3, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, + std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, + std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences); + GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + // reuse block 0, get new blocks 7, 8 + auto promptLen3 = llmRequest3->getNumTokens(beamIdx); + auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + EXPECT_EQ(llmRequest3->getContextCurrentPosition(), + tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset + EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 7, 8})); + llmRequest3->addNewToken(11, beamIdx); + numTokens = llmRequest3->getNumTokens(beamIdx); + numBlocks = tc::ceilDiv(numTokens, tokensPerBlock); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks * 2); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2); + + // clean up + blockManager.releaseBlocks(seq2, llmRequest2); + blockManager.releaseBlocks(seq3, llmRequest3); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); +} + TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest) { // tc::Logger::getLogger()->setLevel(tc::Logger::Level::DEBUG); diff --git a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py index 1dc86cdd1d2..d6387f81908 100644 --- a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py +++ b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py @@ -26,6 +26,83 @@ from torchvision.transforms import Normalize, Resize, ToTensor from tensorrt_llm._torch.modules.embedding import Embedding +from tensorrt_llm.inputs.multimodal import MultimodalParams +from tensorrt_llm.logger import logger + + +def find_uncached_mm_embeds( + mm_embeds: List[torch.Tensor], + multimodal_params: List[MultimodalParams]) -> torch.Tensor: + """ + Find the uncached multimodal mm_embeds from multimodal_params for each batch. + Args: + - mm_embeds: List[torch.Tensor] + - multimodal_params: List[MultimodalParams] + Returns: + - sliced_mm_embeds: List[torch.Tensor] + When kv_cache reuse is disabled or model not enabled/support kv_cache reuse, return the full mm_embeds. + Note: + - Current implementation assumes chunk prefill is disabled. To support chunk prefill, we might need to slightly modify the logic (see TODO below). + """ + # Current support two batching modes: + # 1. Pre-concatenated mm_embeds for each batch, i.e., len(mm_embeds) == 1 + # 2. Individual mm_embeds for each multimodal param, i.e., len(mm_embeds) == len(multimodal_params) + if len(mm_embeds) > 1 and len(mm_embeds) != len(multimodal_params): + raise ValueError( + f"Number of mm_embeds ({len(mm_embeds)}) does not match number of multimodal params ({len(multimodal_params)})." 
+ ) + + if not multimodal_params or multimodal_params[0].multimodal_runtime is None: + # No slicing, return the full mm_embeds + return mm_embeds + + total_cached_mm_tokens = sum([ + param.multimodal_runtime.num_cached_mm_tokens + for param in multimodal_params + ]) + if total_cached_mm_tokens == 0: + # No cached tokens, return the full mm_embeds + # TODO: support chunk prefill for multimodal, then we need to extract full mm_embeds for each CHUNK + logger.debug( + "No multimodal cached tokens can be reused, return the full mm_embeds" + ) + return mm_embeds + + if total_cached_mm_tokens == sum([ + param.multimodal_runtime.total_mm_tokens + for param in multimodal_params + ]): + # All tokens are cached, return empty list + logger.debug( + "All multimodal tokens cached, skipping vision encoder forward") + return [] + + # Partial caching, return the sliced mm_embeds + current_pos = 0 + slices = [] + for param in multimodal_params: + runtime = param.multimodal_runtime + slices.append((current_pos + runtime.num_cached_mm_tokens, + current_pos + runtime.total_mm_tokens)) + if len(mm_embeds + ) == 1: # pre-concatenated mm_embeds, need global offset + current_pos += runtime.total_mm_tokens + + sliced_mm_embeds = [] + if len(mm_embeds) == 1: + for start, end in slices: + sliced_mm_embeds.append(mm_embeds[0][start:end]) + else: # slice each mm_embeds individually + for i, (start, end) in enumerate(slices): + sliced_mm_embeds.append(mm_embeds[i][start:end]) + + if len(mm_embeds) == 1: + sliced_mm_embeds = [torch.cat(sliced_mm_embeds, dim=0)] + + logger.debug( + f"Partial caching, return sliced_mm_embeds: {sliced_mm_embeds[0].shape}" + ) + return sliced_mm_embeds def fuse_input_embeds( @@ -69,6 +146,12 @@ def fuse_input_embeds( text_token_mask = ~mm_token_mask text_token_indices = torch.where(text_token_mask)[0] mm_token_indices = torch.where(mm_token_mask)[0] + if len(mm_token_indices) != mm_embed.shape[0]: + raise ValueError( + f"Multimodal token count mismatch: found {len(mm_token_indices)} image tokens in input_ids " + f"but received {mm_embed.shape[0]} image embeddings. 
" + "This is likely due to KV cache reuse, chunk prefill, or other optimizations that " + "cause token count mismatches within the inference batch.") text_embed = embedding_layer(input_ids[text_token_indices]) input_embeds = torch.empty(input_ids.shape[0], diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py index 2d63a4bbf92..25a2778f8b8 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py +++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py @@ -18,7 +18,8 @@ from ..attention_backend import AttentionMetadata from ..model_config import ModelConfig from .modeling_auto import AutoModelForCausalLM -from .modeling_multimodal_utils import fuse_input_embeds +from .modeling_multimodal_utils import (find_uncached_mm_embeds, + fuse_input_embeds) from .modeling_utils import register_auto_model DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1' @@ -601,6 +602,8 @@ def forward( mrope_config = self._parse_and_concat_mrope_config( multimodal_params, num_context_requests, num_generation_requests) + mm_embeds = find_uncached_mm_embeds( + mm_embeds, multimodal_params[:num_context_requests]) if 'mrope_position_deltas' in kwargs: mrope_config['mrope_position_deltas'] = kwargs[ diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 98eb2e870d4..90529a2dd94 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -21,7 +21,8 @@ from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, torch_dtype_to_str, trace_func) -from tensorrt_llm.inputs.multimodal import MultimodalParams +from tensorrt_llm.inputs.multimodal import (MultimodalParams, + MultimodalRuntimeData) from tensorrt_llm.logger import logger from tensorrt_llm.lora_manager import LoraConfig, LoraModelConfig from tensorrt_llm.mapping import Mapping @@ -1143,8 +1144,16 @@ def _prepare_tp_inputs( num_cached_tokens_per_seq.append(past_seen_token_num) # Multimodal + # TODO: enable chunk prefill for multimodal (maybe need to pass prompt_tokens to MultimodalRuntimeData) + py_multimodal_runtime = MultimodalRuntimeData( + mm_token_lengths=request.multimodal_lengths, + mm_token_positions=request.multimodal_positions, + num_cached_tokens=past_seen_token_num + ) if request.multimodal_hashes is not None else None + multimodal_params = MultimodalParams( - multimodal_data=request.py_multimodal_data) + multimodal_data=request.py_multimodal_data, + multimodal_runtime=py_multimodal_runtime) multimodal_params.to_device("multimodal_data", "cuda", pin_memory=True) diff --git a/tensorrt_llm/inputs/multimodal.py b/tensorrt_llm/inputs/multimodal.py index a6b29a9f018..19d55ae7744 100644 --- a/tensorrt_llm/inputs/multimodal.py +++ b/tensorrt_llm/inputs/multimodal.py @@ -82,6 +82,72 @@ def to_tensor(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: torch.tensor(self.multimodal_lengths, dtype=torch.int32)) +@dataclass +class MultimodalRuntimeData: + """Runtime data for tracking multimodal token caching and reuse per request sequence. + + This class tracks which multimodal tokens are cached vs. need to be processed + for each request sequence during KV cache reuse scenarios. 
+ + Attributes: + num_cached_tokens: Total number of cached tokens for this sequence + mm_token_lengths: Length of each multimodal token chunk + mm_token_positions: Starting positions of each multimodal token chunk + prompt_tokens: Current iteration of prompt tokens for this sequence (optional). Need it for chunk prefill if enabled (#TODO) + num_cached_mm_tokens: Number of multimodal tokens that are cached in this iteration (computed) + total_mm_tokens: Total number of multimodal tokens in this sequence (computed) + """ + num_cached_tokens: int + mm_token_lengths: List[int] + mm_token_positions: List[int] + + # TODO: support chunk prefill for multimodal + # When chunk prefill is enabled, we need to pass the prompt tokens for current chunk and mask to find the included mm tokens + prompt_tokens: Optional[List[int]] = None + + num_cached_mm_tokens: Optional[int] = None + total_mm_tokens: Optional[int] = None + + def __post_init__(self): + # Validate input data + if len(self.mm_token_positions) != len(self.mm_token_lengths): + raise ValueError( + f"mm_token_positions ({len(self.mm_token_positions)}) and mm_token_lengths ({len(self.mm_token_lengths)}) must have the same length" + ) + + if self.num_cached_tokens < 0: + raise ValueError( + f"num_cached_tokens must be non-negative, got {self.num_cached_tokens}" + ) + + if any(length <= 0 for length in self.mm_token_lengths): + raise ValueError( + f"All mm_token_lengths must be positive, got {self.mm_token_lengths}" + ) + + if any(pos < 0 for pos in self.mm_token_positions): + raise ValueError( + f"All mm_token_positions must be non-negative, got {self.mm_token_positions}" + ) + + if self.num_cached_mm_tokens is None: + # Compute cached multimodal tokens based on positions and cached tokens + self.num_cached_mm_tokens = 0 + for pos, length in zip(self.mm_token_positions, + self.mm_token_lengths): + if pos + length <= self.num_cached_tokens: + self.num_cached_mm_tokens += length + elif pos < self.num_cached_tokens: + # Partial overlap - only count the cached portion + self.num_cached_mm_tokens += self.num_cached_tokens - pos + + if self.num_cached_mm_tokens > self.num_cached_tokens: + raise ValueError( + f"num_cached_mm_tokens ({self.num_cached_mm_tokens}) must be less than or equal to " + f"num_cached_tokens ({self.num_cached_tokens})") + self.total_mm_tokens = sum(self.mm_token_lengths) + + @dataclass class MultimodalParams: """Unified container for multimodal parameters. 
@@ -117,6 +183,7 @@ class MultimodalParams: multimodal_input: Optional[MultimodalInput] = None multimodal_data: Optional[Dict[str, Any]] = field(default_factory=dict) + multimodal_runtime: Optional[MultimodalRuntimeData] = None def __post_init__(self): """Ensure default values are properly set.""" diff --git a/tests/unittest/_torch/multimodal/test_kvcache_reuse.py b/tests/unittest/_torch/multimodal/test_kvcache_reuse.py new file mode 100644 index 00000000000..0eb0d5f9ca4 --- /dev/null +++ b/tests/unittest/_torch/multimodal/test_kvcache_reuse.py @@ -0,0 +1,257 @@ +from unittest.mock import Mock + +import pytest +import torch + +# Import the function to test +from tensorrt_llm._torch.models.modeling_multimodal_utils import \ + find_uncached_mm_embeds +from tensorrt_llm.inputs.multimodal import (MultimodalParams, + MultimodalRuntimeData) + + +class TestMultimodalRuntimeData: + """Test cases for MultimodalRuntimeData computation logic, specifically num_cached_mm_tokens.""" + + def test_fully_cached_multimodal_tokens(self): + """Test when all multimodal tokens are cached.""" + runtime = MultimodalRuntimeData( + num_cached_tokens=20, + mm_token_lengths=[5, 8, 7], # Total: 20 tokens + mm_token_positions=[0, 5, 13] # Positions: 0-5, 5-13, 13-20 + ) + + # All tokens should be cached since num_cached_tokens (20) >= all positions + lengths + assert runtime.num_cached_mm_tokens == 20 + assert runtime.total_mm_tokens == 20 + + def test_no_cached_multimodal_tokens(self): + """Test when no multimodal tokens are cached.""" + runtime = MultimodalRuntimeData( + num_cached_tokens=10, + mm_token_lengths=[5, 8, 7], # Total: 20 tokens + mm_token_positions=[10, 18, 30] # All positions > num_cached_tokens + ) + + # No multimodal tokens should be cached + assert runtime.num_cached_mm_tokens == 0 + assert runtime.total_mm_tokens == 20 + + def test_complex_scenario_with_multiple_chunks(self): + """Test a complex scenario with many chunks and various caching states.""" + runtime = MultimodalRuntimeData( + num_cached_tokens=30, + mm_token_lengths=[3, 4, 5, 6, 7, 8], # Total: 33 tokens + mm_token_positions=[ + 0, 5, 10, 15, 25, 35 + ] # Positions: 0-3, 5-9, 10-15, 15-21, 25-32, 35-43 + ) + + # Expected caching: + # Chunk 0: fully cached (3 tokens) + # Chunk 1: fully cached (4 tokens) + # Chunk 2: fully cached (5 tokens) + # Chunk 3: fully cached (6 tokens) + # Chunk 4: partially cached (30-25=5 out of 7 tokens) + # Chunk 5: not cached + expected_cached = 3 + 4 + 5 + 6 + 5 # 23 tokens + assert runtime.num_cached_mm_tokens == expected_cached + assert runtime.total_mm_tokens == 33 + + +class TestFindUncachedMmEmbed: + """Focused test cases for find_uncached_mm_embeds function - testing edge cases and potential bugs.""" + + def create_mock_runtime(self, num_cached_mm_tokens: int, + total_mm_tokens: int): + """Helper to create a mock MultimodalRuntimeData.""" + runtime = Mock(spec=MultimodalRuntimeData) + runtime.num_cached_mm_tokens = num_cached_mm_tokens + runtime.total_mm_tokens = total_mm_tokens + return runtime + + def create_multimodal_params(self, num_cached_mm_tokens: int, + total_mm_tokens: int): + """Helper to create MultimodalParams with runtime data.""" + runtime = self.create_mock_runtime(num_cached_mm_tokens, + total_mm_tokens) + return MultimodalParams(multimodal_runtime=runtime) + + def test_mm_embed_not_batched(self): + """ + Test individual batching mode where each mm_embed corresponds to one param. + This tests the case where len(mm_embeds) == len(multimodal_params) > 1. 
+ """ + mm_embeds = [ + torch.randn(10, 512), # Batch 1: 10 tokens + torch.randn(15, 512), # Batch 2: 15 tokens + torch.randn(8, 512) # Batch 3: 8 tokens + ] + multimodal_params = [ + self.create_multimodal_params(3, 10), # 3 cached, 7 uncached + self.create_multimodal_params(8, 15), # 8 cached, 7 uncached + self.create_multimodal_params(0, 8) # 0 cached, 8 uncached + ] + + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + + # Should return individual slices for each batch + assert len(result) == 3 + assert result[0].shape == (7, 512) # 10 - 3 = 7 + assert result[1].shape == (7, 512) # 15 - 8 = 7 + assert result[2].shape == (8, 512) # 8 - 0 = 8 + + # Verify the slices are correct + torch.testing.assert_close(result[0], mm_embeds[0][3:10]) + torch.testing.assert_close(result[1], mm_embeds[1][8:15]) + torch.testing.assert_close(result[2], mm_embeds[2][0:8]) + + def test_mm_embed_batched(self): + """ + Test batching (concatenated) mm_embeds with fused mm_embeds for each batch. + This tests the case where len(mm_embeds) == 1 + """ + mm_embeds = [torch.randn(33, + 512)] # Pre-concatenated: 10 + 13 + 10 tokens + multimodal_params = [ + self.create_multimodal_params(4, 10), # 4 cached, 6 uncached + self.create_multimodal_params(7, 13), # 7 cached, 6 uncached + self.create_multimodal_params(3, 10) # 3 cached, 7 uncached + ] + + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + + # Expected slices: + # Batch 1: [4:10] = 6 tokens + # Batch 2: [10+7:10+13] = [17:23] = 6 tokens + # Batch 3: [23+3:23+10] = [26:33] = 7 tokens + # Total: 6 + 6 + 7 = 19 tokens + assert len(result) == 1 + assert result[0].shape == (19, 512) + + # Verify the slices are correct + expected = torch.cat( + [ + mm_embeds[0][4:10], # Batch 1: 6 tokens + mm_embeds[0][17:23], # Batch 2: 6 tokens + mm_embeds[0][26:33] # Batch 3: 7 tokens + ], + dim=0) + torch.testing.assert_close(result[0], expected) + + def test_mixed_caching_with_fully_cached_batches(self): + """ + Test mixed scenarios where some batches are fully cached (should be skipped). + """ + mm_embeds = [torch.randn(25, 512)] # Pre-concatenated: 8 + 9 + 8 tokens + multimodal_params = [ + self.create_multimodal_params(8, + 8), # All cached - should be skipped + self.create_multimodal_params(3, 9), # 3 cached, 6 uncached + self.create_multimodal_params(8, + 8) # All cached - should be skipped + ] + + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + + # Only batch 2 should contribute: [8+3:8+9] = [11:17] = 6 tokens + assert len(result) == 1 + assert result[0].shape == (6, 512) + + # Verify the slice is correct + torch.testing.assert_close(result[0], mm_embeds[0][11:17]) + + def test_all_batches_fully_cached(self): + """ + Test edge case where all batches are fully cached. + """ + mm_embeds = [torch.randn(30, + 512)] # Pre-concatenated: 10 + 10 + 10 tokens + multimodal_params = [ + self.create_multimodal_params(10, 10), # All cached + self.create_multimodal_params(10, 10), # All cached + self.create_multimodal_params(10, 10) # All cached + ] + + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + + # Should return empty list + assert result == [] + + def test_no_batches_cached(self): + """ + Test edge case where no batches have any cached tokens. 
+ """ + mm_embeds = [torch.randn(30, + 512)] # Pre-concatenated: 10 + 10 + 10 tokens + multimodal_params = [ + self.create_multimodal_params(0, 10), # No cached + self.create_multimodal_params(0, 10), # No cached + self.create_multimodal_params(0, 10) # No cached + ] + + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + + # Should return the full embeddings + assert result == mm_embeds + + def test_error_handling_mismatched_counts(self): + """ + Test error handling when mm_embeds and multimodal_params counts don't match + in individual batching mode. + """ + mm_embeds = [torch.randn(10, 512), torch.randn(15, 512)] # 2 embeddings + multimodal_params = [self.create_multimodal_params(0, + 10)] # Only 1 param + + with pytest.raises( + ValueError, + match= + "Number of mm_embeds \\(2\\) does not match number of multimodal params \\(1\\)" + ): + find_uncached_mm_embeds(mm_embeds, multimodal_params) + + def test_single_batch_scenarios(self): + """ + Test various single batch scenarios. + """ + # Single batch, no caching + mm_embeds = [torch.randn(20, 512)] + multimodal_params = [self.create_multimodal_params(0, 20)] + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + assert result == mm_embeds + + # Single batch, partial caching + multimodal_params = [self.create_multimodal_params(5, 20)] + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + assert len(result) == 1 + assert result[0].shape == (15, 512) + torch.testing.assert_close(result[0], mm_embeds[0][5:20]) + + # Single batch, all cached + multimodal_params = [self.create_multimodal_params(20, 20)] + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + assert result == [] + + def test_different_devices(self): + """ + Test with tensors on different devices (if CUDA is available). + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Test CPU tensors + mm_embeds = [torch.randn(10, 512, device='cpu')] + multimodal_params = [self.create_multimodal_params(3, 10)] + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + assert result[0].device == mm_embeds[0].device + + # Test CUDA tensors + mm_embeds = [torch.randn(10, 512, device='cuda')] + multimodal_params = [self.create_multimodal_params(3, 10)] + result = find_uncached_mm_embeds(mm_embeds, multimodal_params) + assert result[0].device == mm_embeds[0].device + + +if __name__ == "__main__": + pytest.main([__file__])