diff --git a/benchmarks/cpp/disaggServerBenchmark.cpp b/benchmarks/cpp/disaggServerBenchmark.cpp index d0b5fb8c864..ab009802757 100644 --- a/benchmarks/cpp/disaggServerBenchmark.cpp +++ b/benchmarks/cpp/disaggServerBenchmark.cpp @@ -636,6 +636,8 @@ class DisaggExecutorServer : texec::DecodingMode::Auto(), benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices)); executorConfig.setExtendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig); + executorConfig.setCacheTransceiverConfig( + texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::DEFAULT)); constexpr int maxIterationsForRequestStats = 1000; if (mEnableCollectKvCacheTransferTime) { diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index 6f9c2f82dd6..c39fee6f940 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -70,28 +70,20 @@ class BaseCacheTransceiver class CacheTransceiver : public BaseCacheTransceiver { public: - enum class CommType : std::uint8_t - { - UNKNOWN = 0, - MPI = 1, - UCX = 2, - NIXL = 3 - }; - - CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType, + CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType = executor::kv_cache::CacheState::AttentionType::kDEFAULT, std::optional cacheTransceiverConfig = std::nullopt); - CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType, - std::vector numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, - runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType, + CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, std::vector numKvHeadsPerLayer, + SizeType32 sizePerHead, SizeType32 tokensPerBlock, runtime::WorldConfig const& worldConfig, + nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType = executor::kv_cache::CacheState::AttentionType::kDEFAULT, std::optional cacheTransceiverConfig = std::nullopt) - : CacheTransceiver(cacheManager, commType, + : CacheTransceiver(cacheManager, executor::kv_cache::CacheState::ModelConfig{numKvHeadsPerLayer, sizePerHead, tokensPerBlock}, worldConfig, dataType, attentionType, cacheTransceiverConfig) { @@ -118,7 +110,6 @@ class CacheTransceiver : public BaseCacheTransceiver void setContextState(LlmRequest* llmRequest); - CommType mCommType; std::unique_ptr mDataResponder; std::unique_ptr mDataRequester; std::vector>> mResponderFutures; diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 1cd651cd07c..bba3c31a014 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -1430,18 +1430,29 @@ class LogitsPostProcessorConfig class CacheTransceiverConfig { public: - explicit CacheTransceiverConfig(std::optional maxNumTokens = std::nullopt); + enum class BackendType : std::uint8_t + { + DEFAULT = 0, + MPI = 1, + UCX = 2, + NIXL = 3 + }; + explicit CacheTransceiverConfig( + std::optional backendType = std::nullopt, std::optional maxNumTokens = std::nullopt); bool operator==(CacheTransceiverConfig const& other) const; + void setBackendType(std::optional backendType); + void 
setMaxTokensInBuffer(std::optional maxTokensInBuffer); - [[nodiscard]] std::optional getMaxNumTokens() const; - void setMaxNumTokens(size_t maxNumTokens); + [[nodiscard]] std::optional getMaxTokensInBuffer() const; + [[nodiscard]] std::optional getBackendType() const; private: + std::optional mBackendType; /// @brief The maximum number of tokens that the CacheTransceiver's pre-allocated buffer can hold. If the number of /// kvCache tokens to be transferred for a single request is greater than this value, the performance of the cache /// transfer may be degraded. - std::optional mMaxNumTokens; + std::optional mMaxTokensInBuffer; }; /// @brief Configuration class for the model executor diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp index 51b06feaf71..1a3aed54f41 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @@ -210,7 +210,7 @@ CacheTransBufferManager::CacheTransBufferManager( { auto poolIdx = mCacheManager->getBlockManager().getLayerPoolIdx(layerId); auto windowSize = static_cast(mCacheManager->getBlockManager().getPoolWindowSize(poolIdx)); - auto validTokenNum = windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value(); + auto validTokenNum = (windowSize < maxNumTokens.value() ? windowSize : maxNumTokens.value()); bufferSizeFromMaxNumToken += validTokenNum * kvCacheByteSizePerTokenPerLayer; } } @@ -230,26 +230,37 @@ CacheTransBufferManager::CacheTransBufferManager( TLLM_LOG_INFO( "CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, " "mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d " - "mUseFabricMemory:%d", + "mUseFabricMemory:%d mDataType:%d", maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize, - mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory); - bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache(); + mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory, mDataType); - TLLM_CHECK_WITH_INFO(to_allocate, "CacheTransBufferManager: to_allocate is false"); allocateBuffer(); } -size_t CacheTransBufferManager::preAllocBufferSize(std::optional maxNumTokens) +size_t CacheTransBufferManager::preAllocBufferSize( + std::map const& cacheSizeBytesPerTokenPerWindow, + std::optional const& cacheTransceiverConfig) { - bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache(); - if (!to_allocate) + if (!cacheTransceiverConfig.has_value()) { return 0; } + if (!cacheTransceiverConfig->getBackendType().has_value()) + { + return 0; + } + auto maxNumTokens = cacheTransceiverConfig->getMaxTokensInBuffer(); size_t TransferBufferSize = common::getEnvMemSizeForKVCacheTransferBuffer(); if (maxNumTokens.has_value()) { - TransferBufferSize = maxNumTokens.value(); + TransferBufferSize = 0; + for (auto const& [windowSize, cacheSizeBytesPerToken] : cacheSizeBytesPerTokenPerWindow) + { + auto validTokenNum + = (static_cast(windowSize) < maxNumTokens.value() ? 
static_cast(windowSize) + : maxNumTokens.value()); + TransferBufferSize += validTokenNum * cacheSizeBytesPerToken; + } } bool useFabricMemory = FabricMemory::supportFbaricMemory() && (!(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())); @@ -329,6 +340,14 @@ std::tuple, size_t, bool> CacheTransBuf size_t bufferCoverTargetNum = std::min( static_cast(targetNum), mTransferBufferSize / (targetBufferEleSize * common::getDTypeSize(mDataType))); TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum); + if (bufferCoverTargetNum < static_cast(targetNum)) + { + TLLM_LOG_WARNING( + "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic buffer, " + "it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may " + "be degraded", + bufferCoverTargetNum, targetNum); + } if (bufferId.has_value()) { TLLM_CHECK(static_cast(bufferId.value()) < concurrenceResource.mBuffers.size()); diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h index d534e2b4ac6..e7b050388fe 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h @@ -18,6 +18,7 @@ #pragma once #include "tensorrt_llm/batch_manager/kvCacheManager.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iTensor.h" #include @@ -59,7 +60,8 @@ class CacheTransBufferManager CacheTransBufferManager( KVCacheManager::BaseKVCacheManager* cacheManager, std::optional maxNumTokens = std::nullopt); - static size_t preAllocBufferSize(std::optional maxNumTokens = std::nullopt); + static size_t preAllocBufferSize(std::map const& cacheSizeBytesPerTokenPerWindow, + std::optional const& cacheTransceiverConfig = std::nullopt); std::optional assignBufferIndexForSend(); void freeBufferIndexForSend(std::optional bufferId); diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 3dd85b7dd4f..599a89cef03 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -62,41 +62,49 @@ std::unique_ptr CacheTransceiverFactory::createCacheTransc runtime::WorldConfig const& worldConfig, executor::kv_cache::CacheState::AttentionType attentionType, std::optional cacheTransceiverConfig) { - - std::optional commType; - if (common::getEnvUseUCXKvCache()) - { - commType = CacheTransceiver::CommType::UCX; - TLLM_LOG_INFO("Enable UCX KV cache transport."); - } - else if (common::getEnvUseNixlKvCache()) + if (!cacheTransceiverConfig.has_value() || !cacheTransceiverConfig.value().getBackendType().has_value()) { - commType = CacheTransceiver::CommType::NIXL; - TLLM_LOG_INFO("Enable NIXL KV cache transport."); + TLLM_LOG_INFO("CacheTransceiver is disabled."); + return nullptr; } - else if (common::getEnvUseMPIKvCache()) + auto backendType = cacheTransceiverConfig.value().getBackendType(); + if (backendType.value() == executor::CacheTransceiverConfig::BackendType::DEFAULT) { - commType = CacheTransceiver::CommType::MPI; - TLLM_LOG_INFO("Enable MPI KV cache transport."); + if (common::getEnvUseUCXKvCache()) + { + backendType = executor::CacheTransceiverConfig::BackendType::UCX; + TLLM_LOG_INFO("Enable UCX KV cache transport."); + } + else if (common::getEnvUseNixlKvCache()) + { + backendType = 
executor::CacheTransceiverConfig::BackendType::NIXL; + TLLM_LOG_INFO("Enable NIXL KV cache transport."); + } + else if (common::getEnvUseMPIKvCache()) + { + backendType = executor::CacheTransceiverConfig::BackendType::MPI; + TLLM_LOG_INFO("Enable MPI KV cache transport."); + TLLM_LOG_WARNING("MPI KV cache transport is deprecated, please use UCX or NIXL instead."); + } + else + { + backendType = executor::CacheTransceiverConfig::BackendType::UCX; + } } + cacheTransceiverConfig.value().setBackendType(backendType); - if (commType) - { - executor::kv_cache::CacheState::ModelConfig cacheStateCfg{ - modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()}; + executor::kv_cache::CacheState::ModelConfig cacheStateCfg{ + modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()}; - return std::make_unique(cacheManager, commType.value(), cacheStateCfg, worldConfig, - modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig); - } - return nullptr; + return std::make_unique( + cacheManager, cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig); } -CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, CommType commType, +CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType, std::optional cacheTransceiverConfig) - : mCommType{commType} - , mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session())) + : mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session())) , mCacheTransceiverConfig{cacheTransceiverConfig} { using tensorrt_llm::batch_manager::kv_cache_manager::CacheFormatter; @@ -138,59 +146,59 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa } } bool isMLA = attentionType == executor::kv_cache::CacheState::AttentionType::kMLA; - if (mCommType == CommType::MPI || mCommType == CommType::UCX || mCommType == CommType::NIXL) - { - std::optional maxNumTokens = std::nullopt; - if (mCacheTransceiverConfig.has_value()) - { - maxNumTokens = mCacheTransceiverConfig.value().getMaxNumTokens(); - } - mCacheTransBufferManager - = std::make_unique(cacheManager, maxNumTokens); - if (mCommType == CommType::UCX) - { - std::lock_guard lock(mDllMutex); - mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME); - TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly."); - auto load_sym = [](void* handle, char const* name) - { - void* ret = dllGetSym(handle, name); - TLLM_CHECK_WITH_INFO(ret != nullptr, - "Unable to load UCX wrapper library symbol, possible cause is that TensorRT-LLM library is not " - "built with UCX support, please rebuild in UCX-enabled environment."); - return ret; - }; - std::unique_ptr (*makeUcxConnectionManager)(); - *(void**) (&makeUcxConnectionManager) = load_sym(mWrapperLibHandle, "makeUcxConnectionManager"); - mManager = makeUcxConnectionManager(); - TLLM_LOG_INFO("UCX Connection Manager created"); - } - else if (mCommType == CommType::NIXL) - { - mManager = std::make_unique( - mCacheTransBufferManager.get()); - TLLM_LOG_INFO("NIXL Connection Manager created"); - } - else - { - mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world()); - mManager = std::make_unique(mMpiWorldComm); - TLLM_LOG_INFO("MPI 
Connection Manager created"); - } + TLLM_CHECK_WITH_INFO(mCacheTransceiverConfig.has_value(), "CacheTransceiverConfig is not set."); + auto backendType = mCacheTransceiverConfig.value().getBackendType(); + TLLM_CHECK_WITH_INFO( + backendType.has_value() && (backendType.value() != executor::CacheTransceiverConfig::BackendType::DEFAULT), + " CacheTransceiverConfig::BackendType is not set."); - using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter; - auto makeFormatter = [cacheManager, isMLA, this]() - { return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); }; + std::optional maxNumTokens = mCacheTransceiverConfig.value().getMaxTokensInBuffer(); - mDataResponder = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); - mDataRequester = std::make_unique( - std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + mCacheTransBufferManager = std::make_unique(cacheManager, maxNumTokens); + if (backendType.value() == executor::CacheTransceiverConfig::BackendType::UCX) + { + std::lock_guard lock(mDllMutex); + mWrapperLibHandle = dllOpen(UCX_WRAPPER_LIB_NAME); + TLLM_CHECK_WITH_INFO(mWrapperLibHandle != nullptr, "UCX wrapper library is not open correctly."); + auto load_sym = [](void* handle, char const* name) + { + void* ret = dllGetSym(handle, name); + TLLM_CHECK_WITH_INFO(ret != nullptr, + "Unable to load UCX wrapper library symbol, possible cause is that TensorRT-LLM library is not " + "built with UCX support, please rebuild in UCX-enabled environment."); + return ret; + }; + std::unique_ptr (*makeUcxConnectionManager)(); + *(void**) (&makeUcxConnectionManager) = load_sym(mWrapperLibHandle, "makeUcxConnectionManager"); + mManager = makeUcxConnectionManager(); + TLLM_LOG_INFO("UCX Connection Manager created"); + } + else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL) + { + mManager = std::make_unique( + mCacheTransBufferManager.get()); + TLLM_LOG_INFO("NIXL Connection Manager created"); + } + else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI) + { + mMpiWorldComm = std::addressof(tensorrt_llm::mpi::MpiComm::world()); + mManager = std::make_unique(mMpiWorldComm); + TLLM_LOG_INFO("MPI Connection Manager created"); } else { - TLLM_THROW("Unsupported communication type."); + TLLM_THROW("Unsupported cache transceiver backend type "); } + + using tensorrt_llm::batch_manager::kv_cache_manager::MLACacheFormatter; + auto makeFormatter = [cacheManager, isMLA, this]() + { return createCacheFormatter(cacheManager, mCacheTransBufferManager.get(), isMLA); }; + + mDataResponder = std::make_unique( + std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + mDataRequester = std::make_unique( + std::make_unique(mManager.get(), *mCacheState, worldConfig.getRank(), makeFormatter())); + initializeCommState(); } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 540dee9148b..ba3b2a94ede 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -2235,13 +2235,8 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfi cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken; } - auto const extraCostMemoryBytes = extraCostMemory - * std::accumulate(cacheSizeBytesPerTokenPerWindow.cbegin(), 
cacheSizeBytesPerTokenPerWindow.cend(), - SizeType32{0}, [](SizeType32 acc, auto const cost) { return acc + cost.second; }); - - TLLM_LOG_DEBUG( - "extraCostMemoryBytes [all windows] [Gib]: %0.2f", extraCostMemoryBytes / static_cast(1 << 30)); - + TLLM_LOG_DEBUG("extraCostMemory [Gib]: %0.2f", extraCostMemory / static_cast(1 << 30)); + allottedPrimaryMemBytes = allottedPrimaryMemBytes - extraCostMemory; auto const tokensPerBlock = modelConfig.getTokensPerBlock(); auto const calculatePrimaryBlocks = [&](SizeType32 windowSize, float windowSizeShare, SizeType32 cacheSizeBytesPerToken) diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 1bc80ac2156..b36f0856fd5 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -264,10 +264,35 @@ TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptr const& maxAttentionWindowVec, bool isCrossAttention, SizeType32 kvFactor) + { + auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange( + worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank(), isCrossAttention); + auto numKvHeadsPerLayer = std::vector(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd); + auto windowSizeLayers + = BaseKVCacheManager::groupLayersByWindowSize(maxAttentionWindowVec, modelConfig.getNbLayers()); + std::map cacheSizeBytesPerTokenPerWindow; + for (auto const& [windowSize, managedLayers] : windowSizeLayers) + { + auto const cacheSizePerToken = BaseKVCacheManager::calculateCacheSizePerTokenForSingleWindowSize( + modelConfig, managedLayers, isCrossAttention, kvFactor); + auto const cacheSizeBytesPerToken + = cacheSizePerToken * BufferDataType(modelConfig.getKvDataType()).getSize(); + cacheSizeBytesPerTokenPerWindow[windowSize] = cacheSizeBytesPerToken; + } + + return cacheSizeBytesPerTokenPerWindow; + }; auto cacheTransceiverConfig = executorConfig.getCacheTransceiverConfig().value_or(executor::CacheTransceiverConfig()); - auto cacheTransPreAllocaSize - = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize(cacheTransceiverConfig.getMaxNumTokens()); + + auto const cacheSizeBytesPerTokenPerWindow = calculateCacheSizePerToken( + mModelConfig, mWorldConfig, getMaxAttentionWindowVec(), mModelConfig.useCrossAttention(), 2); + auto cacheTransPreAllocaSize = kv_cache_manager::CacheTransBufferManager::preAllocBufferSize( + cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig); auto const [freePrimaryMemBytes, freeSecondaryMemBytes] = BaseKVCacheManager::calculateFreeMemBytes(mRuntime->getBufferManager(), kvCacheConfig); @@ -879,8 +904,9 @@ void TrtGptModelInflightBatching::forwardSync() { // TODO: skip if sending layer-wise { - TLLM_CHECK_WITH_INFO( - mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration."); + TLLM_CHECK_WITH_INFO(mCacheTransceiver, + "Disaggregated serving is not enabled, please check the configuration of " + "cacheTransceiverConfig."); mCacheTransceiver->respondAndSendAsync(llmReq.get()); } mSeqSlotManager->freeSequenceSlot(llmReq->mRequestId); @@ -1780,8 +1806,8 @@ void TrtGptModelInflightBatching::executeStep( bufferCast(*mBuffers[bufferId]->transformerBuffers->contextProgressHost)[0] = progress.get(); if (progress) { - TLLM_CHECK_WITH_INFO( - mCacheTransceiver, "Disaggregated serving is not enabled, please check the configuration."); + 
TLLM_CHECK_WITH_INFO(mCacheTransceiver, + "Disaggregated serving is not enabled, please check the configuration of cacheTransceiverConfig."); mCacheTransceiver->respondAndSendLayerWise(layerWiseRequests, progress); } } diff --git a/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp b/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp index 1f392ef0583..6919d213642 100644 --- a/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp +++ b/cpp/tensorrt_llm/executor/cacheTransceiverConfig.cpp @@ -21,24 +21,36 @@ namespace tensorrt_llm::executor { -CacheTransceiverConfig::CacheTransceiverConfig(std::optional maxNumTokens) - : mMaxNumTokens(maxNumTokens) +CacheTransceiverConfig::CacheTransceiverConfig( + std::optional backendType, std::optional maxNumTokens) + : mBackendType(backendType) + , mMaxTokensInBuffer(maxNumTokens) { } bool CacheTransceiverConfig::operator==(CacheTransceiverConfig const& other) const { - return mMaxNumTokens == other.mMaxNumTokens; + return mMaxTokensInBuffer == other.mMaxTokensInBuffer && mBackendType == other.mBackendType; } -std::optional CacheTransceiverConfig::getMaxNumTokens() const +void CacheTransceiverConfig::setBackendType(std::optional backendType) { - return mMaxNumTokens; + mBackendType = backendType; } -void CacheTransceiverConfig::setMaxNumTokens(size_t maxNumTokens) +void CacheTransceiverConfig::setMaxTokensInBuffer(std::optional maxTokensInBuffer) { - mMaxNumTokens = maxNumTokens; + mMaxTokensInBuffer = maxTokensInBuffer; +} + +std::optional CacheTransceiverConfig::getBackendType() const +{ + return mBackendType; +} + +std::optional CacheTransceiverConfig::getMaxTokensInBuffer() const +{ + return mMaxTokensInBuffer; } } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 2ea6c26dc73..65718f0405d 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -1258,19 +1258,22 @@ size_t Serialization::serializedSize(SchedulerConfig const& schedulerConfig) // CacheTransceiverConfig CacheTransceiverConfig Serialization::deserializeCacheTransceiverConfig(std::istream& is) { - auto maxNumTokens = su::deserialize>(is); - return CacheTransceiverConfig{maxNumTokens}; + auto backendType = su::deserialize>(is); + auto maxTokensInBuffer = su::deserialize>(is); + return CacheTransceiverConfig{backendType, maxTokensInBuffer}; } void Serialization::serialize(CacheTransceiverConfig const& cacheTransceiverConfig, std::ostream& os) { - su::serialize(cacheTransceiverConfig.getMaxNumTokens(), os); + su::serialize(cacheTransceiverConfig.getBackendType(), os); + su::serialize(cacheTransceiverConfig.getMaxTokensInBuffer(), os); } size_t Serialization::serializedSize(CacheTransceiverConfig const& cacheTransceiverConfig) { size_t totalSize = 0; - totalSize += su::serializedSize(cacheTransceiverConfig.getMaxNumTokens()); + totalSize += su::serializedSize(cacheTransceiverConfig.getBackendType()); + totalSize += su::serializedSize(cacheTransceiverConfig.getMaxTokensInBuffer()); return totalSize; } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp index 87b0a26a79e..d92336e6bdf 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -80,21 +81,15 @@ void 
tb::CacheTransceiverBindings::initBindings(py::module_& m) .def("check_gen_transfer_status", &BaseCacheTransceiver::checkGenTransferStatus) .def("check_gen_transfer_complete", &BaseCacheTransceiver::checkGenTransferComplete); - py::enum_(m, "CommType") - .value("UNKNOWN", tb::CacheTransceiver::CommType::UNKNOWN) - .value("MPI", tb::CacheTransceiver::CommType::MPI) - .value("UCX", tb::CacheTransceiver::CommType::UCX) - .value("NIXL", tb::CacheTransceiver::CommType::NIXL); - py::enum_(m, "AttentionType") .value("DEFAULT", executor::kv_cache::CacheState::AttentionType::kDEFAULT) .value("MLA", executor::kv_cache::CacheState::AttentionType::kMLA); py::classh(m, "CacheTransceiver") - .def(py::init, SizeType32, SizeType32, runtime::WorldConfig, nvinfer1::DataType, - executor::kv_cache::CacheState::AttentionType, std::optional>(), - py::arg("cache_manager"), py::arg("comm_type"), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), + .def(py::init, SizeType32, SizeType32, + runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType, + std::optional>(), + py::arg("cache_manager"), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), py::arg("tokens_per_block"), py::arg("world_config"), py::arg("dtype"), py::arg("attention_type"), py::arg("cache_transceiver_config") = std::nullopt); @@ -102,5 +97,5 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m) .def(py::init>(), py::arg("cache_manager"), py::arg("max_num_tokens") = std::nullopt) .def_static("pre_alloc_buffer_size", &tb::kv_cache_manager::CacheTransBufferManager::preAllocBufferSize, - py::arg("max_num_tokens") = std::nullopt); + py::arg("cache_size_bytes_per_token_per_window"), py::arg("cache_transceiver_config") = py::none()); } diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp index 71a0b4af724..bc0d997e337 100644 --- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp @@ -407,21 +407,46 @@ void initConfigBindings(pybind11::module_& m) "stop_token_ids", &tle::GuidedDecodingConfig::getStopTokenIds, &tle::GuidedDecodingConfig::setStopTokenIds) .def(py::pickle(guidedDecodingConfigGetstate, guidedDecodingConfigSetstate)); - auto cacheTransceiverConfigGetstate - = [](tle::CacheTransceiverConfig const& self) { return py::make_tuple(self.getMaxNumTokens()); }; + auto cacheTransceiverConfigGetstate = [](tle::CacheTransceiverConfig const& self) + { return py::make_tuple(self.getBackendType(), self.getMaxTokensInBuffer()); }; auto cacheTransceiverConfigSetstate = [](py::tuple const& state) { - if (state.size() != 1) + if (state.size() != 2) { throw std::runtime_error("Invalid CacheTransceiverConfig state!"); } - return tle::CacheTransceiverConfig(state[0].cast>()); + return tle::CacheTransceiverConfig( + state[0].cast(), state[1].cast>()); }; + py::enum_(m, "CacheTransceiverBackendType") + .value("DEFAULT", tle::CacheTransceiverConfig::BackendType::DEFAULT) + .value("MPI", tle::CacheTransceiverConfig::BackendType::MPI) + .value("UCX", tle::CacheTransceiverConfig::BackendType::UCX) + .value("NIXL", tle::CacheTransceiverConfig::BackendType::NIXL) + .def(py::init( + [](std::string const& str) + { + if (str == "DEFAULT" || str == "default") + return tle::CacheTransceiverConfig::BackendType::DEFAULT; + if (str == "MPI" || str == "mpi") + return tle::CacheTransceiverConfig::BackendType::MPI; + if (str == "UCX" || str == "ucx") + return 
tle::CacheTransceiverConfig::BackendType::UCX; + if (str == "NIXL" || str == "nixl") + return tle::CacheTransceiverConfig::BackendType::NIXL; + throw std::runtime_error("Invalid backend type: " + str); + })); + + py::implicitly_convertible(); + py::class_(m, "CacheTransceiverConfig") - .def(py::init>(), py::arg("max_num_tokens") = py::none()) - .def_property("max_num_tokens", &tle::CacheTransceiverConfig::getMaxNumTokens, - &tle::CacheTransceiverConfig::setMaxNumTokens) + .def(py::init, std::optional>(), + py::arg("backend") = std::nullopt, py::arg("max_tokens_in_buffer") = std::nullopt) + .def_property( + "backend", &tle::CacheTransceiverConfig::getBackendType, &tle::CacheTransceiverConfig::setBackendType) + .def_property("max_tokens_in_buffer", &tle::CacheTransceiverConfig::getMaxTokensInBuffer, + &tle::CacheTransceiverConfig::setMaxTokensInBuffer) .def(py::pickle(cacheTransceiverConfigGetstate, cacheTransceiverConfigSetstate)); auto executorConfigGetState = [](py::object const& self) diff --git a/cpp/tests/executor/disaggExecutorTest.cpp b/cpp/tests/executor/disaggExecutorTest.cpp index 49c8c00f048..75ab6dccb44 100644 --- a/cpp/tests/executor/disaggExecutorTest.cpp +++ b/cpp/tests/executor/disaggExecutorTest.cpp @@ -662,6 +662,8 @@ TEST_P(DisaggParamsTest, DisaggTokenComparison) KvCacheConfig kvCacheConfig{true, std::nullopt, std::nullopt, std::nullopt, freeGpuMemoryFraction}; executorConfig.setKvCacheConfig(kvCacheConfig); executorConfig.setRequestStatsMaxIterations(1000); + executorConfig.setCacheTransceiverConfig( + texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::DEFAULT)); auto manager = tr::BufferManager(std::make_shared()); auto const& givenInput = tr::utils::loadNpy(manager, inputPath.string(), tr::MemoryType::kCPU); auto [givenInputLengths, nbGivenInputs, maxInputLength] = getGivenInputLengths(*givenInput, modelIds.padId); @@ -894,6 +896,8 @@ TEST_P(DisaggOrchestratorParamsTest, DisaggTokenComparison) spawnProcess ? 
std::nullopt : std::optional>(participantIdsEachInstance.at(in)), orchestratorConfig}; executorConfig.setParallelConfig(parallelConfig); + executorConfig.setCacheTransceiverConfig( + texec::CacheTransceiverConfig(texec::CacheTransceiverConfig::BackendType::DEFAULT)); if (in < contextNum) { ctxExecutorConfigs.push_back(executorConfig); @@ -994,6 +998,8 @@ TEST_P(ConditionalDisaggParamsTest, DisaggTokenComparison) KvCacheConfig kvCacheConfig{true, std::nullopt, std::nullopt, std::nullopt, freeGpuMemoryFraction}; executorConfig.setKvCacheConfig(kvCacheConfig); executorConfig.setRequestStatsMaxIterations(1000); + executorConfig.setCacheTransceiverConfig( + texec::CacheTransceiverConfig(CacheTransceiverConfig::BackendType::DEFAULT)); auto manager = tr::BufferManager(std::make_shared()); auto const& givenInput = tr::utils::loadNpy(manager, inputPath.string(), tr::MemoryType::kCPU); auto [givenInputLengths, nbGivenInputs, maxInputLength] = getGivenInputLengths(*givenInput, modelIds.padId); diff --git a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp index 996b7b97237..27e1590e6a2 100644 --- a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp @@ -18,6 +18,7 @@ #include "tensorrt_llm/batch_manager/cacheTransBuffer.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/iTensor.h" #include @@ -110,8 +111,13 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize) size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend() ? tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum() : 1; - size_t bufferSizeBytes = CacheTransBufferManager::preAllocBufferSize(maxNumTokens) - * kvCacheSizePerToken(4, 2, 64, CacheType::kSELFKONLY); + size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELFKONLY); + std::map cacheSizeBytesPerTokenPerWindow{ + {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}}; + tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{ + tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens}; + size_t bufferSizeBytes + = CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig); auto bufferId = mTransBufferManager->assignBufferIndexForSend(); EXPECT_TRUE(bufferId.has_value()); EXPECT_EQ(bufferId.value(), 0); @@ -149,15 +155,18 @@ TEST_F(CacheTransBufferTest, TestPreAllocBufferSize2) size_t sendBufferCount = tensorrt_llm::common::getEnvParallelCacheSend() ? 
tensorrt_llm::common::getEnvKVCacheSendMaxConcurrenceNum() : 1; - size_t bufferSizeBytes = CacheTransBufferManager::preAllocBufferSize(maxNumTokens) - * kvCacheSizePerToken(4, 2, 64, CacheType::kSELF); + size_t cacheSizeBytesPerToken = kvCacheSizePerToken(4, 2, 64, CacheType::kSELF); + tensorrt_llm::executor::CacheTransceiverConfig cacheTransceiverConfig{ + tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, maxNumTokens}; + std::map cacheSizeBytesPerTokenPerWindow{ + {maxBlocksPerSeq * tokensPerBlock, cacheSizeBytesPerToken}}; + size_t bufferSizeBytes + = CacheTransBufferManager::preAllocBufferSize(cacheSizeBytesPerTokenPerWindow, cacheTransceiverConfig); auto bufferId = mTransBufferManager->assignBufferIndexForSend(); EXPECT_TRUE(bufferId.has_value()); EXPECT_EQ(bufferId.value(), 0); EXPECT_EQ(bufferSizeBytes, mTransBufferManager->getSendBuffer(bufferId)->getSizeInBytes() * (recvbufferCount + sendBufferCount)); - TLLM_LOG_INFO("bufferSizeBytes: %ld , getSizeINBytes: %ld", bufferSizeBytes, - mTransBufferManager->getSendBuffer(bufferId)->getSizeInBytes() * (recvbufferCount + sendBufferCount)); mTransBufferManager->freeBufferIndexForSend(bufferId); exit(testing::Test::HasFailure() ? 1 : 0); } diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index d29cf0350ca..18f7e6f5379 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -785,8 +785,8 @@ TEST(SerializeUtilsTest, ExecutorConfig) texec::SpeculativeDecodingConfig(true), texec::GuidedDecodingConfig( texec::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, std::initializer_list{"eos"}), - std::vector{tensorrt_llm::executor::AdditionalModelOutput{"output_name"}}, texec::CacheTransceiverConfig(1024), - true, true, true); + std::vector{tensorrt_llm::executor::AdditionalModelOutput{"output_name"}}, + texec::CacheTransceiverConfig(std::nullopt, 1024), true, true, true); auto executorConfig2 = serializeDeserialize(executorConfig); EXPECT_EQ(executorConfig.getMaxBeamWidth(), executorConfig2.getMaxBeamWidth()); @@ -862,7 +862,9 @@ TEST(SerializeUtilsTest, MethodReturnType) TEST(SerializeUtilsTest, CacheTransceiverConfig) { - texec::CacheTransceiverConfig cacheTransceiverConfig(1024); + texec::CacheTransceiverConfig cacheTransceiverConfig( + tensorrt_llm::executor::CacheTransceiverConfig::BackendType::UCX, 1024); auto cacheTransceiverConfig2 = serializeDeserialize(cacheTransceiverConfig); - EXPECT_EQ(cacheTransceiverConfig.getMaxNumTokens(), cacheTransceiverConfig2.getMaxNumTokens()); + EXPECT_EQ(cacheTransceiverConfig.getBackendType(), cacheTransceiverConfig2.getBackendType()); + EXPECT_EQ(cacheTransceiverConfig.getMaxTokensInBuffer(), cacheTransceiverConfig2.getMaxTokensInBuffer()); } diff --git a/docs/source/advanced/disaggregated-service.md b/docs/source/advanced/disaggregated-service.md index 757b1da81f4..426d327c18b 100644 --- a/docs/source/advanced/disaggregated-service.md +++ b/docs/source/advanced/disaggregated-service.md @@ -16,8 +16,6 @@ An [architectural and performance overview](../../../docs/source/blogs/tech_blog TRT-LLM uses some environment variables to control the behavior of disaggregated service. -* `TRTLLM_USE_UCX_KVCACHE`: Specifies whether to use UCX for KV cache transfer. The default value is `0`. This must be enabled when using a disaggregated service. 
- * `TRTLLM_PARALLEL_CACHE_SEND`: If set to `1`, contextExecutor will attempt to send KV cache for multiple requests in parallel. The default value is `0`. * `TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP`: If set to `1`, generationExecutor will not overlap KV cache transfer with model inference. The default value is `0`. @@ -66,55 +64,19 @@ A. Yes, it's recommended that different executor use different GPUs . We support *Q. How to handle error `Disaggregated serving is not enabled, please check the configuration?`* -A. Please set the environment variables -``` -export TRTLLM_USE_UCX_KVCACHE=1 -``` +A. please set `backendType` of `CacheTransceiverConfig`. +```cpp +ExecutorConfig executorConfig{...}; -*Q. Why do some profiling tools show that TRT-LLM's KV cache transfer does not utilize NVLink even on devices equipped with NVLink?* +executorConfig.setCacheTransceiverConfig(texec::CacheTransceiverConfig(BackendType::DEFAULT)); +``` -A. Please check version of `UCX` with `ucx_info -v`. -If the version of UCX <=1.17, set the environment variables `UCX_RNDV_FRAG_MEM_TYPE=cuda` and `UCX_MEMTYPE_CACHE=n` to enable NVLink. For BlackWell architecture GPUs, UCX version >=1.19 is required to enable NVLink. -If the version of UCX >=1.18, there are several ways to enable NVLink: -1. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B`,`UCX_CUDA_COPY_ASYNC_MEM_TYPE=cuda`, `UCX_CUDA_COPY_DMABUF=no`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`. -2. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=$Size`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`. $Size represents the size of the buffer for KV cache transfer, which is recommended to be larger than the size of the KV cache for the longest request. +When the environment variable `TRTLLM_USE_MPI_KVCACHE=1` is set, TRT-LLM will transfer the KV cache using `CUDA-aware MPI`. All executor processes involved must share the same MPI world communicator. Consequently, with `TRTLLM_USE_MPI_KVCACHE=1`, TRT-LLM only supports launching multiple executors via `MPI`. Additionally, the `CommunicationMode` for the executors must be set to `kLEADER` or `kORCHESTRATOR` with `SpawnProcesses=false` for the `disaggregated-service`. These restrictions do not apply when `TRTLLM_USE_UCX_KVCACHE=1` is set. *Q. Does TRT-LLM support using GPU direct RDMA for inter-node KV Cache transfer?* -A. Yes, TRT-LLM supports using GPU direct RDMA for inter-node KV cache transfer, but it is not enabled by default. There are several ways to enable GPU direct RDMA: -1. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B`,`UCX_RNDV_FRAG_MEM_TYPE=cuda`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`. -2. Set the environment variables `TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=$Size`, `UCX_MEMTYPE_CACHE=n` and `UCX_RNDV_PIPELINE_ERROR_HANDLING=y`, $Size represents the size of the buffer for KV cache transfer, which is recommended to be larger than the size of the KV cache for the longest request. - -*Q. Are there any guidelines for performance tuning of KV cache transfer?* - -A. Depending on the user's use case, certain sets of environment variables can help avoid poor KV cache transfer performance. - -Environment Variable Set A - -``` -export TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B -export UCX_RNDV_FRAG_MEM_TYPES=cuda -export UCX_MEMTYPE_CACHE=n -export UCX_RNDV_PIPELINE_ERROR_HANDLING=y -``` -This set allows KV cache transfers to utilize NVLink within nodes and GDRDMA between nodes. 
- -Environment Variable Set B - -``` -export TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=0B -export UCX_CUDA_COPY_ASYNC_MEM_TYPE=cuda -export UCX_CUDA_COPY_DMABUF=no -export UCX_MEMTYPE_CACHE=n -export UCX_RNDV_PIPELINE_ERROR_HANDLING=y -``` -Set B may provide slightly better performance on a single node compared to Set A. However, when transferring KV cache across multiple nodes, it may cause program instability. +A. Yes, TRT-LLM supports using GPU direct RDMA for inter-node KV cache transfer. -Environment Variable Set C +*Q. What causes the substantial bandwidth fluctuations in kvCache transfers, especially during the first few requests following service initialization?* -``` -export TRTLLM_KVCACHE_TRANSFER_BUFFER_SIZE=$Size -export UCX_MEMTYPE_CACHE=n -export UCX_RNDV_PIPELINE_ERROR_HANDLING=y -``` -Set C can achieve better performance than Sets A and B, both within and between nodes. However, if the KV cache size exceeds the specified $Size, performance may degrade. +A. Connections for kvCache transfer between executors are established dynamically. The connection establishment process incurs significant overhead, which explains the apparently lower kvCache transfer bandwidth observed during the first few requests after service startup. When conducting benchmarks, it is recommended to perform a warm-up phase to ensure accurate performance measurements. diff --git a/docs/source/scripts/disaggregated/gen_yaml.py b/docs/source/scripts/disaggregated/gen_yaml.py index 1d198a9766d..859a07310ab 100644 --- a/docs/source/scripts/disaggregated/gen_yaml.py +++ b/docs/source/scripts/disaggregated/gen_yaml.py @@ -176,7 +176,8 @@ def gen_config_file(config_path: str, 'disable_overlap_scheduler': True, 'kv_cache_dtype': 'fp8', 'cache_transceiver_config': { - 'max_num_tokens': 8320, + 'backend': 'default', + 'max_tokens_in_buffer': 8320, }, }, 'generation_servers': { @@ -199,7 +200,8 @@ def gen_config_file(config_path: str, 'backend': 'TRTLLM', }, 'cache_transceiver_config': { - 'max_num_tokens': 8320, + 'backend': 'default', + 'max_tokens_in_buffer': 8320, }, } } diff --git a/examples/disaggregated/README.md b/examples/disaggregated/README.md index 120706dd01a..13abb8c73d6 100644 --- a/examples/disaggregated/README.md +++ b/examples/disaggregated/README.md @@ -4,14 +4,25 @@ To run TRT-LLM in disaggregated mode, you must first launch context (prefill) an ## Launching context and generation servers using multiple independent `trtllm-serve` commands +We use the `cache_transceiver_config` configuration to set up disaggregated serving, which includes the following parameters: + +``` +cache_transceiver_config: + backend: + max_tokens_in_buffer: +``` + +`backend` specifies the communication backend used to transfer the kvCache. Valid options are `DEFAULT`, `UCX`, `NIXL`, and `MPI`; the default backend is UCX. + +`max_tokens_in_buffer` defines the buffer size for kvCache transfers. For optimal performance, it is recommended to set this value greater than or equal to the maximum ISL (Input Sequence Length) of all requests. + You can use multiple `trtllm-serve` commands to launch the context and generation servers that will be used for disaggregated serving.
For example, you could launch two context servers and one generation servers as follows: ``` -echo -e "disable_overlap_scheduler: True\ncache_transceiver_config:\n max_num_tokens: 2048" > context_extra-llm-api-config.yml -echo -e "cache_transceiver_config:\n max_num_tokens: 2048" > gen_extra-llm-api-config.yml +echo -e "disable_overlap_scheduler: True\ncache_transceiver_config:\n backend: UCX\n max_tokens_in_buffer: 2048" > context_extra-llm-api-config.yml +echo -e "cache_transceiver_config:\n backend: UCX\n max_tokens_in_buffer: 2048" > gen_extra-llm-api-config.yml -export TRTLLM_USE_UCX_KVCACHE=1 #Context servers CUDA_VISIBLE_DEVICES=0 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8001 --backend pytorch --extra_llm_api_options ./context_extra-llm-api-config.yml &> log_ctx_0 & CUDA_VISIBLE_DEVICES=1 trtllm-serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --host localhost --port 8002 --backend pytorch --extra_llm_api_options ./context_extra-llm-api-config.yml &> log_ctx_1 & @@ -128,6 +139,8 @@ context_servers: pipeline_parallel_size: 1 kv_cache_config: free_gpu_memory_fraction: 0.9 + cache_transceiver_config: + backend: UCX urls: - "localhost:8001" - "localhost:8002" @@ -135,6 +148,8 @@ generation_servers: num_instances: 1 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: UCX urls: - "localhost:8003" ``` @@ -143,3 +158,7 @@ Once the context and generation servers are launched, you can again launch the d ``` trtllm-serve disaggregated -c disagg_config.yaml ``` + +## Known Issues + +The MPI communication backend for kvCache transfer is deprecated and may be removed in a future release. When using the MPI backend, the environment variable `TRTLLM_USE_MPI_KVCACHE=1` should be set to avoid conflicts between mpi4py and kvCache transfer.
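For reference, the `backend` and `max_tokens_in_buffer` options above map directly onto the C++ `CacheTransceiverConfig` API introduced in this change. Below is a minimal illustrative sketch (not part of the patch; the `makeDisaggExecutorConfig` helper name and the 2048-token buffer size are placeholders mirroring the YAML examples) of the equivalent programmatic setup:

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Hypothetical helper: builds an ExecutorConfig with the cache transceiver
// configured explicitly instead of reading it from extra-llm-api-config.yml.
texec::ExecutorConfig makeDisaggExecutorConfig()
{
    texec::ExecutorConfig executorConfig; // other executor options elided
    // Select the UCX backend and size the pre-allocated KV cache transfer
    // buffer for up to 2048 tokens, matching the YAML examples above.
    executorConfig.setCacheTransceiverConfig(texec::CacheTransceiverConfig(
        texec::CacheTransceiverConfig::BackendType::UCX, /*maxNumTokens=*/2048));
    return executorConfig;
}
```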
diff --git a/examples/disaggregated/disagg_config.yaml b/examples/disaggregated/disagg_config.yaml index 6d5314f235c..ae72c1b074e 100644 --- a/examples/disaggregated/disagg_config.yaml +++ b/examples/disaggregated/disagg_config.yaml @@ -10,11 +10,15 @@ context_servers: pipeline_parallel_size: 1 kv_cache_config: free_gpu_memory_fraction: 0.2 + cache_transceiver_config: + backend: "default" urls: - "localhost:8001" generation_servers: num_instances: 1 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: "default" urls: - "localhost:8002" diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index a7db4910b78..37a82df323b 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -2,6 +2,7 @@ from os import getenv import tensorrt_llm +from tensorrt_llm import logger from tensorrt_llm.bindings import WorldConfig from tensorrt_llm.bindings.executor import CacheTransceiverConfig from tensorrt_llm.mapping import Mapping @@ -10,9 +11,9 @@ from .resource_manager import KVCacheManager CacheTransceiverCpp = tensorrt_llm.bindings.internal.batch_manager.CacheTransceiver -CommTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CommType AttentionTypeCpp = tensorrt_llm.bindings.internal.batch_manager.AttentionType CacheTransBufferManagerCpp = tensorrt_llm.bindings.internal.batch_manager.CacheTransBufferManager +BackendTypeCpp = tensorrt_llm.bindings.executor.CacheTransceiverBackendType def mapping_to_world_config(mapping: Mapping) -> WorldConfig: @@ -30,21 +31,27 @@ def create_kv_cache_transceiver( mapping: Mapping, kv_cache_manager: KVCacheManager, attention_type: AttentionTypeCpp, cache_transceiver_config: CacheTransceiverConfig): - - comm_type = None - if getenv("TRTLLM_USE_UCX_KVCACHE"): - comm_type = CommTypeCpp.UCX - elif getenv("TRTLLM_USE_NIXL_KVCACHE"): - comm_type = CommTypeCpp.NIXL - elif getenv("TRTLLM_USE_MPI_KVCACHE"): - comm_type = CommTypeCpp.MPI - - cache_transceiver = None - if comm_type is not None: - cache_transceiver = BindKvCacheTransceiver(mapping, comm_type, - kv_cache_manager, - attention_type, - cache_transceiver_config) + if cache_transceiver_config is None or (cache_transceiver_config.backend + is None): + logger.info("cache_transceiver is disabled") + return None + if (cache_transceiver_config.backend == BackendTypeCpp.DEFAULT): + + backend_type = BackendTypeCpp.UCX + if getenv("TRTLLM_USE_UCX_KVCACHE"): + backend_type = BackendTypeCpp.UCX + elif getenv("TRTLLM_USE_NIXL_KVCACHE"): + backend_type = BackendTypeCpp.NIXL + elif getenv("TRTLLM_USE_MPI_KVCACHE"): + backend_type = BackendTypeCpp.MPI + cache_transceiver_config.backend = backend_type + + if (cache_transceiver_config.backend == BackendTypeCpp.MPI): + logger.warning( + "MPI CacheTransceiver is deprecated, UCX or NIXL is recommended") + cache_transceiver = BindKvCacheTransceiver(mapping, kv_cache_manager, + attention_type, + cache_transceiver_config) return cache_transceiver @@ -78,8 +85,7 @@ def check_gen_transfer_complete(self): class BindKvCacheTransceiver(KvCacheTransceiver): - def __init__(self, mapping: Mapping, comm_type: CommTypeCpp, - kv_cache_manager: KVCacheManager, + def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager, attention_type: AttentionTypeCpp, cache_transceiver_config: CacheTransceiverConfig): world_config = mapping_to_world_config(mapping) @@ -88,7 +94,7 @@ def __init__(self, mapping: Mapping, 
comm_type: CommTypeCpp, tokens_per_block = kv_cache_manager.tokens_per_block dtype = kv_cache_manager.dtype - self.impl = CacheTransceiverCpp(kv_cache_manager.impl, comm_type, + self.impl = CacheTransceiverCpp(kv_cache_manager.impl, num_kv_heads_per_layer, head_dim, tokens_per_block, world_config, dtype, attention_type, @@ -120,7 +126,7 @@ def __init__(self, kv_cache_manager: KVCacheManager, max_num_tokens: int): max_num_tokens) @staticmethod - def pre_alloc_buffer_size(max_num_tokens: int, - kv_cache_size_per_token: int): + def pre_alloc_buffer_size(kv_cache_size_per_token: int, + cache_transceiver_config: CacheTransceiverConfig): return CacheTransBufferManagerCpp.pre_alloc_buffer_size( - max_num_tokens) * kv_cache_size_per_token + kv_cache_size_per_token, cache_transceiver_config) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index c8518c83a81..74c754651d1 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1346,6 +1346,8 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]: # In disaggregated serving, we might get either context request or # generation request. In IFB, we only get context request from request queue + # In IFB, we only get context request from request queue + if self.kv_cache_transceiver: for req_item in new_requests_cur_rank: if req_item.request.request_type == RequestType.REQUEST_TYPE_CONTEXT_ONLY: diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index ddbcba2a115..35357e658a8 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -429,7 +429,6 @@ def disaggregated_mpi_worker(config_file: Optional[str], log_level: str): disagg_cfg.server_configs) logger.set_level(log_level) - os.environ['TRTLLM_USE_MPI_KVCACHE'] = "1" set_mpi_comm(sub_comm) logger.info( f"mpi_session is provided for LLM instance. Global MPI rank: {global_mpi_rank()}, sub-comm MPI rank: {mpi_rank()}" diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index a82d0d71e5f..68fa336db89 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -406,6 +406,10 @@ def _enqueue_request(self, request: GenerationRequest) -> int: context_phase_params = None request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION if request.disaggregated_params is not None: + assert ( + not self._is_pytorch_backend + or self.engine.kv_cache_transceiver is not None + ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:` in config file for disaggregated serving" request_type = request.disaggregated_params.get_request_type() if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY: context_phase_params = request.disaggregated_params.get_context_phase_params( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 111d779ef39..27fff5ef13e 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -879,12 +879,20 @@ class CacheTransceiverConfig(BaseModel, PybindMirror): """ Configuration for the cache transceiver. 
""" - max_num_tokens: Optional[int] = Field( + + backend: Optional[Literal["default", "ucx", "nixl", "mpi"]] = Field( + default=None, + description= + "The communication backend type to use for the cache transceiver.") + + max_tokens_in_buffer: Optional[int] = Field( default=None, description="The max number of tokens the transfer buffer can fit.") def _to_pybind(self): - return _CacheTransceiverConfig(max_num_tokens=self.max_num_tokens) + return _CacheTransceiverConfig( + backend=self.backend, + max_tokens_in_buffer=self.max_tokens_in_buffer) @dataclass diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 67915d0728f..fee38e723e6 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -195,6 +195,8 @@ def test_auto_dtype(self, disable_overlap_scheduler): gen_server_config = { "disable_overlap_scheduler": disable_overlap_scheduler } + ctx_server_config["cache_transceiver_config"] = {"backend": "default"} + gen_server_config["cache_transceiver_config"] = {"backend": "default"} disaggregated_server_config = { "hostname": "localhost", "port": 8000, @@ -232,11 +234,17 @@ def test_ngram(self): ctx_server_config = { "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, + "cache_transceiver_config": { + "backend": "default" + } } gen_server_config = { "disable_overlap_scheduler": True, "speculative_config": speculative_decoding_config, "kv_cache_config": kv_cache_config, + "cache_transceiver_config": { + "backend": "default" + } } disaggregated_server_config = { "hostname": "localhost", @@ -274,13 +282,19 @@ def test_eagle3(self, overlap_scheduler): "disable_overlap_scheduler": True, "speculative_config": speculative_decoding_config, "kv_cache_config": kv_cache_config, - "max_num_tokens": 13393 * 2 + "max_num_tokens": 13393 * 2, + "cache_transceiver_config": { + "backend": "default" + } } gen_server_config = { "disable_overlap_scheduler": not overlap_scheduler, "speculative_config": speculative_decoding_config, "kv_cache_config": kv_cache_config, - "max_num_tokens": 13393 * 2 + "max_num_tokens": 13393 * 2, + "cache_transceiver_config": { + "backend": "default" + } } disaggregated_server_config = { "hostname": "localhost", @@ -312,6 +326,8 @@ class TestLlama4ScoutInstruct(LlmapiAccuracyTestHarness): def test_auto_dtype(self, overlap_scheduler): ctx_server_config = {"disable_overlap_scheduler": True} gen_server_config = {"disable_overlap_scheduler": overlap_scheduler} + ctx_server_config["cache_transceiver_config"] = {"backend": "default"} + gen_server_config["cache_transceiver_config"] = {"backend": "default"} disaggregated_server_config = { "hostname": "localhost", "port": 8000, @@ -347,6 +363,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): def test_auto_dtype(self, overlap_scheduler, mtp_nextn): ctx_server_config = {"disable_overlap_scheduler": True} gen_server_config = {"disable_overlap_scheduler": not overlap_scheduler} + ctx_server_config["cache_transceiver_config"] = {"backend": "default"} + gen_server_config["cache_transceiver_config"] = {"backend": "default"} if mtp_nextn > 0: ctx_server_config["speculative_config"] = { "decoding_type": "MTP", @@ -389,11 +407,17 @@ class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): def test_auto_dtype(self, overlap_scheduler): ctx_server_config = { "disable_overlap_scheduler": True, - "cuda_graph_config": None + "cuda_graph_config": None, + 
"cache_transceiver_config": { + "backend": "default" + } } gen_server_config = { "disable_overlap_scheduler": overlap_scheduler, - "cuda_graph_config": None + "cuda_graph_config": None, + "cache_transceiver_config": { + "backend": "default" + } } ctx_server_config["kv_cache_config"] = { "max_attention_window": [512, 512, 512, 512, 512, 32768], diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml index cb776b0f258..6db8a0f1a93 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance.yaml @@ -20,6 +20,8 @@ context_servers: enable_partial_reuse: False event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.1 + cache_transceiver_config: + backend: default urls: - "localhost:8001" - "localhost:8002" @@ -32,6 +34,8 @@ generation_servers: max_seq_len: 4096 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default kv_cache_config: enable_block_reuse: True enable_partial_reuse: False diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml index edb7d62ba00..cc275b98c7c 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_aware_balance_deepseek_v3.yaml @@ -16,6 +16,8 @@ context_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.1 + cache_transceiver_config: + backend: "default" urls: - "localhost:8001" - "localhost:8002" @@ -30,6 +32,8 @@ generation_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.1 + cache_transceiver_config: + backend: "default" urls: - "localhost:8003" - "localhost:8004" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml index 30662441dbd..86da31c42bf 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse.yaml @@ -14,6 +14,8 @@ context_servers: enable_block_reuse: True enable_partial_reuse: True event_buffer_max_size: 1024 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -27,5 +29,7 @@ generation_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.05 + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml index 4bcca2967bb..e76a253c1ae 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cache_reuse_deepseek_v3.yaml @@ -14,6 +14,8 @@ context_servers: enable_block_reuse: True enable_partial_reuse: True event_buffer_max_size: 1024 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -27,5 +29,7 @@ 
generation_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.05 + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml index daf3c286d7c..2292fe22aaf 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional.yaml @@ -17,6 +17,8 @@ context_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -30,5 +32,7 @@ generation_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml index 59e713ad91a..345a958fa5e 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_conditional_deepseek_v3.yaml @@ -17,6 +17,8 @@ context_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -30,5 +32,7 @@ generation_servers: enable_partial_reuse: True event_buffer_max_size: 1024 free_gpu_memory_fraction: 0.15 + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml index d62a9c42cd9..1f63caed57f 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite.yaml @@ -9,11 +9,15 @@ context_servers: num_instances: 1 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: num_instances: 1 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml index 4286a58eef8..97c03fbbcb1 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp.yaml @@ -13,6 +13,8 @@ context_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 enable_attention_dp: true + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -20,5 +22,7 @@ generation_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 enable_attention_dp: false + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git 
a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml index cf65a53f4ff..25612d4a784 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_attention_dp_overlap.yaml @@ -13,6 +13,8 @@ context_servers: pipeline_parallel_size: 1 enable_attention_dp: true disable_overlap_scheduler: True + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -21,5 +23,7 @@ generation_servers: pipeline_parallel_size: 1 enable_attention_dp: true disable_overlap_scheduler: False + cache_transceiver_config: + backend: default urls: - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml index eeac6135487..facc4603306 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml @@ -13,6 +13,8 @@ context_servers: tensor_parallel_size: 1 pipeline_parallel_size: 1 enable_attention_dp: true + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: @@ -22,3 +24,5 @@ generation_servers: enable_attention_dp: false urls: - "localhost:8002" + cache_transceiver_config: + backend: default diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml index e4ee818e782..729bdf2cf99 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1.yaml @@ -9,12 +9,16 @@ context_servers: num_instances: 1 tensor_parallel_size: 2 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: num_instances: 2 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default urls: - "localhost:8002" - "localhost:8003" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml index 2e64638bafe..bde3132f8a1 100644 --- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp1_trt_backend.yaml @@ -6,12 +6,16 @@ context_servers: num_instances: 1 tensor_parallel_size: 2 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default urls: - "localhost:8001" generation_servers: num_instances: 2 tensor_parallel_size: 1 pipeline_parallel_size: 1 + cache_transceiver_config: + backend: default urls: - "localhost:8002" - "localhost:8003" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml 
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml
index 5c560cb77aa..1bc20842867 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml
@@ -9,11 +9,15 @@ context_servers:
   num_instances: 1
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
   num_instances: 1
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml
index 94ac965b19a..28d4c3556e2 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp.yaml
@@ -10,6 +10,8 @@ context_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   enable_attention_dp: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -17,5 +19,7 @@ generation_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   enable_attention_dp: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml
index 0cb3ef15351..0d05bef459e 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one.yaml
@@ -10,6 +10,8 @@ context_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   enable_attention_dp: true
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -17,5 +19,7 @@ generation_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   enable_attention_dp: false
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml
index 8403a61fd6d..fa771b9e30f 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_one_mtp.yaml
@@ -13,6 +13,8 @@ context_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   enable_attention_dp: true
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -20,5 +22,8 @@ generation_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   enable_attention_dp: false
+  cache_transceiver_config:
+    backend: default
+
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml
index c893c8fff83..9398f7ddd26 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap.yaml
@@ -10,6 +10,8 @@ context_servers:
   pipeline_parallel_size: 1
   enable_attention_dp: True
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -18,5 +20,7 @@ generation_servers:
   pipeline_parallel_size: 1
   enable_attention_dp: True
   disable_overlap_scheduler: False
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml
index 1171fb4f102..f8c04735eb3 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_attention_dp_overlap_cuda_graph.yaml
@@ -9,6 +9,8 @@ context_servers:
   pipeline_parallel_size: 1
   enable_attention_dp: true
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -19,5 +21,7 @@ generation_servers:
   cuda_graph_config:
     enable_padding: False
   disable_overlap_scheduler: False
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml
new file mode 100644
index 00000000000..912178b7f62
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml
@@ -0,0 +1,22 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/fp8
+free_gpu_memory_fraction: 0.25
+backend: "pytorch"
+disable_overlap_scheduler: True
+context_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "mpi"
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "mpi"
+  urls:
+    - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml
new file mode 100644
index 00000000000..e4fd09a1ce1
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml
@@ -0,0 +1,22 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/fp8
+free_gpu_memory_fraction: 0.25
+backend: "pytorch"
+disable_overlap_scheduler: True
+context_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "nixl"
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "nixl"
+  urls:
+    - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml
index 18acc70f9ac..9ace31717ec 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_overlap_cuda_graph.yaml
@@ -8,6 +8,8 @@ context_servers:
   tensor_parallel_size: 2
   pipeline_parallel_size: 1
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -17,5 +19,7 @@ generation_servers:
   cuda_graph_config:
     enable_padding: False
   disable_overlap_scheduler: False
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml
new file mode 100644
index 00000000000..b21637529bf
--- /dev/null
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml
@@ -0,0 +1,22 @@
+hostname: localhost
+port: 8000
+model: DeepSeek-V3-Lite/fp8
+free_gpu_memory_fraction: 0.25
+backend: "pytorch"
+disable_overlap_scheduler: True
+context_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "ucx"
+  urls:
+    - "localhost:8001"
+generation_servers:
+  num_instances: 1
+  tensor_parallel_size: 2
+  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "ucx"
+  urls:
+    - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml
index 7009df9fd0f..8b992d210cc 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_cuda_graph_padding.yaml
@@ -15,6 +15,8 @@ context_servers:
   cuda_graph_config:
     batch_sizes: [1,3000]
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -31,5 +33,7 @@ generation_servers:
     enable_padding: True
     batch_sizes: [1,4,8,16,24,32]
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml
index 6777ca485d3..f42ea826c05 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only.yaml
@@ -13,6 +13,8 @@ generation_servers:
     free_gpu_memory_fraction: 0.2
     enable_block_reuse: False
     enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: default
   print_iter_log: True
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml
index a0b31eb419c..386a8fba01f 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_gen_only_trt_backend.yaml
@@ -11,6 +11,8 @@ generation_servers:
     free_gpu_memory_fraction: 0.2
     enable_block_reuse: False
     enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
     - "localhost:8003"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml
index fd42b7fdc0e..f0766a9c6d2 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_load_balance.yaml
@@ -18,6 +18,8 @@ context_servers:
     free_gpu_memory_fraction: 0.15
     enable_partial_reuse: False
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
     - "localhost:8002"
@@ -35,6 +37,8 @@ generation_servers:
     free_gpu_memory_fraction: 0.15
     enable_partial_reuse: False
   disable_overlap_scheduler: False
+  cache_transceiver_config:
+    backend: "default"
   urls:
     - "localhost:8003"
     - "localhost:8004"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml
index e3d8cdb60b9..31e429c440e 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_mixed.yaml
@@ -9,12 +9,16 @@ context_servers:
   num_instances: 1
   tensor_parallel_size: 1
  pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
   num_instances: 2
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml
index 667262df4a3..2f779f598ac 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ngram.yaml
@@ -8,12 +8,16 @@ context_servers:
   num_instances: 1
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "default"
   urls:
     - "localhost:8001"
 generation_servers:
   num_instances: 1
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: "default"
   urls:
     - "localhost:8002"
 speculative_config:
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
index ea6719cb55d..5cdafaed341 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
@@ -15,6 +15,8 @@ context_servers:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
   disable_overlap_scheduler: True
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
@@ -28,5 +30,7 @@ generation_servers:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
   disable_overlap_scheduler: False
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml
index 9b018dfcd98..fa57d987de4 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_trt_backend.yaml
@@ -8,11 +8,15 @@ context_servers:
   pipeline_parallel_size: 1
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8001"
 generation_servers:
   num_instances: 1
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
+  cache_transceiver_config:
+    backend: default
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
index 7e4f0ddec00..b7ecb48b306 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_trtllm_sampler.yaml
@@ -15,6 +15,8 @@ context_servers:
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: "default"
   disable_overlap_scheduler: True
   urls:
     - "localhost:8001"
@@ -29,6 +31,8 @@ generation_servers:
   kv_cache_config:
     free_gpu_memory_fraction: 0.2
     enable_partial_reuse: False
+  cache_transceiver_config:
+    backend: "default"
   disable_overlap_scheduler: False
   urls:
     - "localhost:8002"
diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py
index 8648f59d357..251df5bc9dc 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated.py
@@ -59,9 +59,17 @@ def get_test_config(test_desc, example_dir, test_root):
         "conditional": (2, f"{test_configs_root}/disagg_config_conditional.yaml"),
         "ngram": (2, f"{test_configs_root}/disagg_config_ngram.yaml"),
-        "deepseek_v3_lite_fp8":
         (4,
-         f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite.yaml"
+        "deepseek_v3_lite_fp8_mpi":
         (4,
+         f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml"
+         ),
+        "deepseek_v3_lite_fp8_ucx":
         (4,
+         f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_ucx.yaml"
+         ),
+        "deepseek_v3_lite_fp8_nixl":
         (4,
+         f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_nixl.yaml"
          ),
         "deepseek_v3_lite_fp8_tp1":
         (2,
@@ -129,6 +137,8 @@ def run_disaggregated_test(example_dir,
                            cwd=None):
     """Run disaggregated test with given configuration."""
     cleanup_output_files()
+    run_env = env.copy()
+    run_env["UCX_TLS"] = "^ib"
     num_ranks, config_file = get_test_config(test_desc, example_dir,
                                              os.path.dirname(__file__))
@@ -151,14 +161,14 @@ def run_disaggregated_test(example_dir,
             popen(workers_cmd,
                   stdout=output_workers,
                   stderr=subprocess.STDOUT,
-                  env=env,
+                  env=run_env,
                   cwd=cwd) as workers_proc,
             # Start server
             open('output_disagg.log', 'w') as output_disagg,
             popen(server_cmd,
                   stdout=output_disagg,
                   stderr=subprocess.STDOUT,
-                  env=env,
+                  env=run_env,
                   cwd=cwd) as server_proc):
         client_dir = f"{example_dir}/clients"
         for _ in range(num_iters):
@@ -525,9 +535,10 @@ def test_disaggregated_ngram(disaggregated_test_root, llm_venv,
 @pytest.mark.skip_less_device(4)
 @pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'],
                          indirect=True)
-def test_disaggregated_deepseek_v3_lite_fp8(disaggregated_test_root,
-                                            disaggregated_example_root,
-                                            llm_venv, deepseek_v3_model_root):
+def test_disaggregated_deepseek_v3_lite_fp8_mpi(disaggregated_test_root,
+                                                disaggregated_example_root,
+                                                llm_venv,
+                                                deepseek_v3_model_root):
     src_dst_dict = {
         deepseek_v3_model_root:
         f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/fp8",
@@ -536,10 +547,11 @@ def test_disaggregated_deepseek_v3_lite_fp8(disaggregated_test_root,
         if not os.path.islink(dst):
             os.makedirs(os.path.dirname(dst), exist_ok=True)
             os.symlink(src, dst, target_is_directory=True)
-
+    env = llm_venv._new_env.copy()
+    env["TRTLLM_USE_MPI_KVCACHE"] = "1"
     run_disaggregated_test(disaggregated_example_root,
-                           "deepseek_v3_lite_fp8",
-                           env=llm_venv._new_env,
+                           "deepseek_v3_lite_fp8_mpi",
+                           env=env,
                            cwd=llm_venv.get_working_directory())
@@ -607,7 +619,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_ucx(disaggregated_test_root,
     env["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env["UCX_TLS"] = "^ib"
     run_disaggregated_test(disaggregated_example_root,
-                           "deepseek_v3_lite_fp8",
+                           "deepseek_v3_lite_fp8_ucx",
                            env=env,
                            cwd=llm_venv.get_working_directory())
@@ -633,7 +645,7 @@ def test_disaggregated_deepseek_v3_lite_fp8_nixl(disaggregated_test_root,
     env["TRTLLM_USE_NIXL_KVCACHE"] = "1"
     env["UCX_TLS"] = "^ib"
     run_disaggregated_test(disaggregated_example_root,
-                           "deepseek_v3_lite_fp8",
+                           "deepseek_v3_lite_fp8_nixl",
                            env=env,
                            cwd=llm_venv.get_working_directory())
diff --git a/tests/integration/defs/disaggregated/test_disaggregated_etcd.py b/tests/integration/defs/disaggregated/test_disaggregated_etcd.py
index 5d200d82e73..7521ecde42f 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated_etcd.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated_etcd.py
@@ -244,14 +244,16 @@ def create_config_files(config):
     context_config_content = """pytorch_backend_config:
   disable_overlap_scheduler: True
 cache_transceiver_config:
-  max_num_tokens: 2048"""
+  backend: "default"
+  max_tokens_in_buffer: 2048"""

     with open(CONTEXT_CONFIG_FILE, 'w') as file:
         file.write(context_config_content)

     # Create generation config file
     generation_config_content = """cache_transceiver_config:
-  max_num_tokens: 2048"""
+  backend: "default"
+  max_tokens_in_buffer: 2048"""

     with open(GENERATION_CONFIG_FILE, 'w') as file:
         file.write(generation_config_content)
diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
index e0ab570ec5c..1e1859f5aa6 100644
--- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
+++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
@@ -11,7 +11,8 @@
 from tensorrt_llm import LLM, DisaggregatedParams, SamplingParams
 from tensorrt_llm._utils import set_mpi_comm
-from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, MpiCommSession
+from tensorrt_llm.llmapi import (CacheTransceiverConfig, CudaGraphConfig,
+                                 KvCacheConfig, MpiCommSession)
 from tensorrt_llm.llmapi.llm_args import EagleDecodingConfig

 cloudpickle.register_pickle_by_value(sys.modules[__name__])
@@ -43,7 +44,8 @@ def model_path(model_name):
         raise ValueError(f"Unknown model: {model_name}")

-async def run_worker(kv_cache_config, pytorch_config, model_name, rank):
+async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
+                     model_name, rank):
     assert isinstance(pytorch_config, dict)
     print(f"Running worker {rank}")
     port_name = MPI.Lookup_name('my_port')
@@ -59,7 +61,8 @@ async def run_worker(kv_cache_config, pytorch_config, model_name, rank):
             enable_chunked_prefill=False,
             **pytorch_config,
             _mpi_session=mpi_session,
-            kv_cache_config=kv_cache_config)
+            kv_cache_config=kv_cache_config,
+            cache_transceiver_config=cache_transceiver_config)
         print(f"LLM created")
     except Exception as e:
         print(f"Error creating LLM: {e}")
@@ -103,9 +106,11 @@ def send_requests_to_worker(requests, worker_rank, intercomm):
     return responses

-def worker_entry_point(kv_cache_config, pytorch_config, model_name, rank):
+def worker_entry_point(kv_cache_config, cache_transceiver_config,
+                       pytorch_config, model_name, rank):
     return asyncio.run(
-        run_worker(kv_cache_config, pytorch_config, model_name, rank))
+        run_worker(kv_cache_config, cache_transceiver_config, pytorch_config,
+                   model_name, rank))

 def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
@@ -125,16 +130,19 @@ def verify_disaggregated(model, generation_overlap, enable_cuda_graph, prompt,
             cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None))

     kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)]
+    cache_transceiver_configs = [
+        CacheTransceiverConfig(backend="default") for _ in range(2)
+    ]
     model_names = [model_path(model) for _ in range(2)]
     ranks = [0, 1]
     worker_args = list(
-        zip(kv_cache_configs, worker_pytorch_configs, model_names, ranks))
+        zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
+            model_names, ranks))

     port_name = MPI.Open_port()
     MPI.Publish_name('my_port', port_name)

-    with MPIPoolExecutor(max_workers=2, env={"TRTLLM_USE_MPI_KVCACHE":
-                                             "1"}) as executor:
+    with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor:
         futures = []
         try:
             for worker_arg in worker_args:
@@ -249,18 +257,21 @@ def test_disaggregated_llama_context_capacity(model, enable_cuda_graph,
         KvCacheConfig(max_tokens=128, enable_block_reuse=False, dtype="auto")
         for _ in range(2)
     ]
+    cache_transceiver_configs = [
+        CacheTransceiverConfig(backend="default") for _ in range(2)
+    ]
     model_names = [model_path(model) for _ in range(2)]
     ranks = [0, 1]
     worker_args = list(
-        zip(kv_cache_configs, worker_pytorch_configs, model_names, ranks))
+        zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
+            model_names, ranks))

     port_name = MPI.Open_port()
     MPI.Publish_name('my_port', port_name)

     prompt = "European Union is a political and economic union of 27 countries. The European Union is headquartered in Brussels, Belgium. The first president of the European Union was Jean-Claude Juncker. The current president is Ursula von der Leyen. The European Union is a major economic and political entity."

-    with MPIPoolExecutor(max_workers=2, env={"TRTLLM_USE_MPI_KVCACHE":
-                                             "1"}) as executor:
+    with MPIPoolExecutor(max_workers=2, env={"UCX_TLS": "^ib"}) as executor:
         futures = []
         try:
             for worker_arg in worker_args:
diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt
index 0cf65a29aed..0b7a3d7384a 100644
--- a/tests/integration/test_lists/qa/examples_test_list.txt
+++ b/tests/integration/test_lists/qa/examples_test_list.txt
@@ -589,7 +589,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[T
 disaggregated/test_disaggregated.py::test_disaggregated_multi_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
+disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8]
diff --git a/tests/integration/test_lists/qa/llm_sanity_test.txt b/tests/integration/test_lists/qa/llm_sanity_test.txt
index 19bf09b8b5e..5630dd47312 100644
--- a/tests/integration/test_lists/qa/llm_sanity_test.txt
+++ b/tests/integration/test_lists/qa/llm_sanity_test.txt
@@ -60,7 +60,7 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_att
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
+disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi[DeepSeek-V3-Lite-fp8]
 disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0]
 disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0]
diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
index 1599b73a44b..e5a6b700786 100644
--- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml
@@ -89,7 +89,7 @@ l0_dgx_h100:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=0]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_cuda_graph_padding_4gpus[attention_dp=True-mtp_nextn=2]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb
-  - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8]
+  - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_mpi[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8]
diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt
index 5380afccf86..e9f4ed4401e 100644
--- a/tests/integration/test_lists/waives.txt
+++ b/tests/integration/test_lists/waives.txt
@@ -417,9 +417,6 @@ test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] SKI
 examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5374145)
 examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5373451)
 examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5373962)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5373962)
-disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5373962)
 stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test] SKIP (https://nvbugs/5375646)
 examples/test_gemma.py::test_hf_gemma_fp8_base_bf16_multi_lora[gemma-2-9b-it] SKIP (https://nvbugs/5376087)
 full:GH200/disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_attention_dp_one[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5375966)
diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py
index 5d9460ffef0..935c4c9bfc3 100644
--- a/tests/unittest/bindings/test_executor_bindings.py
+++ b/tests/unittest/bindings/test_executor_bindings.py
@@ -2463,9 +2463,11 @@ def test_guided_decoding_config_pickle():

 def test_cache_transceiver_config_pickle():
-    config = trtllm.CacheTransceiverConfig(max_num_tokens=1024)
+    config = trtllm.CacheTransceiverConfig(backend="UCX",
+                                           max_tokens_in_buffer=1024)
     config_copy = pickle.loads(pickle.dumps(config))
-    assert config_copy.max_num_tokens == config.max_num_tokens
+    assert config_copy.backend == config.backend
+    assert config_copy.max_tokens_in_buffer == config.max_tokens_in_buffer

 def test_executor_config_pickle():
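Illustrative note (not part of the diff): a minimal LLM-API sketch of the configuration style these tests now exercise, where the KV-cache transfer backend is chosen through CacheTransceiverConfig rather than the old TRTLLM_USE_*_KVCACHE environment variables. Only CacheTransceiverConfig(backend=...), the cache_transceiver_config keyword, and KvCacheConfig are taken from the changes above; the model path and prompt are placeholders, and passing max_tokens_in_buffer to the LLM-API class is assumed from the equivalent YAML knob.

# Sketch under the assumptions stated above.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CacheTransceiverConfig, KvCacheConfig

# "default" lets the runtime pick a transfer backend; the new test configs
# also use "mpi", "ucx" and "nixl" to force a specific one.
cache_transceiver_config = CacheTransceiverConfig(
    backend="default",
    max_tokens_in_buffer=2048)  # assumed to mirror the YAML max_tokens_in_buffer

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
          kv_cache_config=KvCacheConfig(max_tokens=2048 * 8),
          cache_transceiver_config=cache_transceiver_config)

for output in llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)):
    print(output.outputs[0].text)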