From 557ba69c9fa3429c89d8aaa372c987e3a5cebc7a Mon Sep 17 00:00:00 2001
From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Date: Tue, 21 Oct 2025 01:32:34 -0700
Subject: [PATCH 01/10] Fix cublas handle insufficient-memory bug in the multi-GPU case.

Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 cpp/tensorrt_llm/common/opUtils.cpp           | 48 ++++++++++++++++++-
 docker/Makefile                               |  3 +-
 .../defs/accuracy/test_llm_api_pytorch.py     | 10 +++-
 3 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp
index ae3810a255f..06484c5617a 100644
--- a/cpp/tensorrt_llm/common/opUtils.cpp
+++ b/cpp/tensorrt_llm/common/opUtils.cpp
@@ -259,8 +259,30 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
         []() -> auto
         {
+            size_t free_mem = 0, total_mem = 0;
+            cudaMemGetInfo(&free_mem, &total_mem);
+
+            CUcontext ctx;
+            cuCtxGetCurrent(&ctx);
+
+            TLLM_LOG_DEBUG("Creating cublas handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", ctx,
+                free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+
             auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
-            TLLM_CUDA_CHECK(cublasCreate(handle.get()));
+
+            cublasStatus_t status = cublasCreate(handle.get());
+
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                cudaMemGetInfo(&free_mem, &total_mem);
+                TLLM_THROW(
+                    "Failed to create cublas handle. "
+                    "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
+                    "Consider reducing kv_cache_config.max_tokens or free_gpu_memory_fraction.",
+                    status, ctx, free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0,
+                    total_mem / (1024 * 1024));
+            }
+
             return handle;
         },
         [](cublasHandle_t* handle)
@@ -277,8 +299,30 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
         []() -> auto
         {
+            size_t free_mem = 0, total_mem = 0;
+            cudaMemGetInfo(&free_mem, &total_mem);
+
+            CUcontext ctx;
+            cuCtxGetCurrent(&ctx);
+
+            TLLM_LOG_DEBUG("Creating cublasLt handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", ctx,
+                free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+
             auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
-            TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
+
+            cublasStatus_t status = cublasLtCreate(handle.get());
+
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                cudaMemGetInfo(&free_mem, &total_mem);
+                TLLM_THROW(
+                    "Failed to create cublasLt handle. "
+                    "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
" + "Consider reducing kv_cache_config.max_tokens or free_gpu_memory_fraction.", + status, ctx, free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, + total_mem / (1024 * 1024)); + } + return handle; }, [](cublasLtHandle_t* handle) diff --git a/docker/Makefile b/docker/Makefile index b51ae8dfc25..4bfb8a687a8 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -121,7 +121,7 @@ endef @echo "Pulling docker image: $(IMAGE_WITH_TAG)" docker pull $(IMAGE_WITH_TAG) -DOCKER_RUN_OPTS ?= --rm -it --ipc=host --ulimit stack=67108864 $(if $(filter 0,$(IS_ROOTLESS)),--ulimit memlock=-1) +DOCKER_RUN_OPTS ?= --rm -it --ipc=host --ulimit stack=67108864 $(if $(filter 0,$(IS_ROOTLESS)),--ulimit memlock=-1) --privileged DOCKER_RUN_ARGS ?= # Check if NVIDIA_VISIBLE_DEVICES is set and not empty NVIDIA_VISIBLE_DEVICES_VAL = $(shell echo $$NVIDIA_VISIBLE_DEVICES) @@ -156,6 +156,7 @@ endif docker run $(DOCKER_RUN_OPTS) $(DOCKER_RUN_ARGS) \ $(GPU_OPTS) \ --volume $(SOURCE_DIR):$(CODE_DIR) \ + --volume /home/scratch.trt_llm_data/:/scratch.trt_llm_data/ \ $(EXTRA_VOLUMES) \ $(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \ --env "CCACHE_DIR=$(CCACHE_DIR)" \ diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 2d1d7710b13..9c40c1226c5 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -155,7 +155,10 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile): disable_overlap_scheduler=torch_compile, ) if fp8kv: - pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8") + pytorch_config["kv_cache_config"] = KvCacheConfig( + dtype="fp8", + # max_tokens=100000, # Limit tokens to prevent no room for CUBLAS handles + ) with LLM( f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8", **pytorch_config) as llm: @@ -189,7 +192,10 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend, disable_overlap_scheduler=torch_compile, ) if fp8kv: - pytorch_config["kv_cache_config"] = KvCacheConfig(dtype="fp8") + pytorch_config["kv_cache_config"] = KvCacheConfig( + dtype="fp8", + # max_tokens=100000, # Limit tokens to prevent no room for CUBLAS handles + ) with LLM( f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8", tensor_parallel_size=tp_size, From 4acd64982045ad281546eb95826dd533556926f1 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Tue, 21 Oct 2025 02:57:43 -0700 Subject: [PATCH 02/10] Fix cublas handle not sufficient memory bug in multi gpu case and reduce max tokens in kv cache config. 
From 4acd64982045ad281546eb95826dd533556926f1 Mon Sep 17 00:00:00 2001
From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Date: Tue, 21 Oct 2025 02:57:43 -0700
Subject: [PATCH 02/10] Fix cublas handle insufficient-memory bug in the multi-GPU case and reduce max tokens in the KV cache config.

Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 docker/Makefile                                          | 3 +--
 tests/integration/defs/accuracy/test_llm_api_pytorch.py | 6 ++++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/docker/Makefile b/docker/Makefile
index 4bfb8a687a8..b51ae8dfc25 100644
--- a/docker/Makefile
+++ b/docker/Makefile
@@ -121,7 +121,7 @@ endef
 	@echo "Pulling docker image: $(IMAGE_WITH_TAG)"
 	docker pull $(IMAGE_WITH_TAG)
 
-DOCKER_RUN_OPTS ?= --rm -it --ipc=host --ulimit stack=67108864 $(if $(filter 0,$(IS_ROOTLESS)),--ulimit memlock=-1) --privileged
+DOCKER_RUN_OPTS ?= --rm -it --ipc=host --ulimit stack=67108864 $(if $(filter 0,$(IS_ROOTLESS)),--ulimit memlock=-1)
 DOCKER_RUN_ARGS ?=
 # Check if NVIDIA_VISIBLE_DEVICES is set and not empty
 NVIDIA_VISIBLE_DEVICES_VAL = $(shell echo $$NVIDIA_VISIBLE_DEVICES)
@@ -156,7 +156,6 @@ endif
 	docker run $(DOCKER_RUN_OPTS) $(DOCKER_RUN_ARGS) \
 		$(GPU_OPTS) \
 		--volume $(SOURCE_DIR):$(CODE_DIR) \
-		--volume /home/scratch.trt_llm_data/:/scratch.trt_llm_data/ \
 		$(EXTRA_VOLUMES) \
 		$(if $(and $(filter 1,$(LOCAL_USER)),$(shell [ -w "$(USER_CACHE_DIR)" ] && echo 1)),--volume $(USER_CACHE_DIR):/home/$(USER_NAME)/.cache:rw) \
 		--env "CCACHE_DIR=$(CCACHE_DIR)" \
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 9c40c1226c5..117ea7994d6 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -157,7 +157,8 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile):
         if fp8kv:
             pytorch_config["kv_cache_config"] = KvCacheConfig(
                 dtype="fp8",
-                # max_tokens=100000,  # Limit tokens to leave room for CUBLAS handles
+                max_tokens=
+                100000,  # Limit tokens to leave room for cublas/cublasLt handles
             )
         with LLM(
                 f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",
@@ -194,7 +195,8 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend,
         if fp8kv:
             pytorch_config["kv_cache_config"] = KvCacheConfig(
                 dtype="fp8",
-                # max_tokens=100000,  # Limit tokens to leave room for CUBLAS handles
+                max_tokens=
+                100000,  # Limit tokens to leave room for cublas/cublasLt handles
             )
         with LLM(
                 f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8",

From 1b118283841e2b6eab8fddd73b218f49a81901aa Mon Sep 17 00:00:00 2001
From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Date: Tue, 21 Oct 2025 23:42:58 -0700
Subject: [PATCH 03/10] Fix cublas handle insufficient-memory bug in the multi-GPU case and reduce the memory fraction in the KV cache config.

Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 cpp/tensorrt_llm/common/opUtils.cpp           | 50 +++++++++++--------
 .../defs/accuracy/test_llm_api_pytorch.py     |  8 +--
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp
index 06484c5617a..719d5ad0419 100644
--- a/cpp/tensorrt_llm/common/opUtils.cpp
+++ b/cpp/tensorrt_llm/common/opUtils.cpp
@@ -252,6 +252,18 @@ class PerCudaCtxPerThreadSingletonCreator
     std::unordered_map, hash>* mObservers;
 };
 
+// Helper function to log memory usage - returns the memory values for potential error handling
+static std::pair<size_t, size_t> logMemoryUsage(char const* operation, CUcontext ctx)
+{
+    size_t free_mem = 0, total_mem = 0;
+    TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
+
+    TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx,
+        free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+
+    return {free_mem, total_mem};
+}
+
 } // namespace
 
 std::shared_ptr<cublasHandle_t> getCublasHandle()
@@ -259,21 +271,16 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasHandle_t> creator(
         []() -> auto
         {
-            size_t free_mem = 0, total_mem = 0;
-            cudaMemGetInfo(&free_mem, &total_mem);
-
-            CUcontext ctx;
-            cuCtxGetCurrent(&ctx);
-
-            TLLM_LOG_DEBUG("Creating cublas handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", ctx,
-                free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+            CUcontext ctx = getCurrentCudaCtx();
+            auto [free_mem, total_mem] = logMemoryUsage("Creating cublas handle", ctx);
 
-            auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
+            auto handle = std::make_unique<cublasHandle_t>();
 
             cublasStatus_t status = cublasCreate(handle.get());
 
             if (status != CUBLAS_STATUS_SUCCESS)
             {
+                // Re-fetch memory info for error message (memory state might have changed)
                 cudaMemGetInfo(&free_mem, &total_mem);
                 TLLM_THROW(
                     "Failed to create cublas handle. "
@@ -287,7 +294,11 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
         },
         [](cublasHandle_t* handle)
         {
-            TLLM_CUDA_CHECK(cublasDestroy(*handle));
+            cublasStatus_t status = cublasDestroy(*handle);
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
+            }
             delete handle;
             handle = nullptr;
         });
@@ -299,21 +310,16 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
     static PerCudaCtxPerThreadSingletonCreator<cublasLtHandle_t> creator(
         []() -> auto
         {
-            size_t free_mem = 0, total_mem = 0;
-            cudaMemGetInfo(&free_mem, &total_mem);
-
-            CUcontext ctx;
-            cuCtxGetCurrent(&ctx);
-
-            TLLM_LOG_DEBUG("Creating cublasLt handle: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", ctx,
-                free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024));
+            CUcontext ctx = getCurrentCudaCtx();
+            auto [free_mem, total_mem] = logMemoryUsage("Creating cublasLt handle", ctx);
 
-            auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
+            auto handle = std::make_unique<cublasLtHandle_t>();
 
             cublasStatus_t status = cublasLtCreate(handle.get());
 
             if (status != CUBLAS_STATUS_SUCCESS)
             {
+                // Re-fetch memory info for error message (memory state might have changed)
                 cudaMemGetInfo(&free_mem, &total_mem);
                 TLLM_THROW(
                     "Failed to create cublasLt handle. "
@@ -327,7 +333,11 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
         },
         [](cublasLtHandle_t* handle)
         {
-            TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
+            cublasStatus_t status = cublasLtDestroy(*handle);
+            if (status != CUBLAS_STATUS_SUCCESS)
+            {
+                TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
Status: %d", status); + } delete handle; handle = nullptr; }); diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 117ea7994d6..31baa6d5734 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -157,8 +157,8 @@ def test_fp8(self, fp8kv, attn_backend, torch_compile): if fp8kv: pytorch_config["kv_cache_config"] = KvCacheConfig( dtype="fp8", - max_tokens= - 100000, # Limit tokens to prevent no room for cublas/cublasLt handles + free_gpu_memory_fraction= + 0.8, # Prevent cublas/cublasLt handle allocation memory insufficient errors ) with LLM( f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8", @@ -195,8 +195,8 @@ def test_fp8_4gpus(self, tp_size, pp_size, fp8kv, attn_backend, if fp8kv: pytorch_config["kv_cache_config"] = KvCacheConfig( dtype="fp8", - max_tokens= - 100000, # Limit tokens to prevent no room for cublas/cublasLt handles + free_gpu_memory_fraction= + 0.8, # Prevent cublas/cublasLt handle allocation memory insufficient errors ) with LLM( f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8", From f866e72fcf25fb8ce79a968e33a9930cbe8184f6 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:10:03 -0700 Subject: [PATCH 04/10] Refactor. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- cpp/tensorrt_llm/common/opUtils.cpp | 33 +++++++++++++++-------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index 719d5ad0419..c6d92ab296a 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -264,6 +264,21 @@ static std::pair logMemoryUsage(char const* operation, CUcontext return {free_mem, total_mem}; } +// Helper function to handle memory-related errors for cublas handle creation. +static void throwCublasMemoryError(char const* handleType, cublasStatus_t status, CUcontext ctx) +{ + // Re-fetch and log current memory state for debugging + auto [free_mem, total_mem] = logMemoryUsage( + (std::string("Failed to create ") + handleType + " - logging current memory state").c_str(), ctx); + + TLLM_THROW( + "Failed to create %s. " + "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. " + "Consider reducing kv_cache_config.free_gpu_memory_fraction.", + handleType, status, ctx, free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, + total_mem / (1024 * 1024)); +} + } // namespace std::shared_ptr getCublasHandle() @@ -280,14 +295,7 @@ std::shared_ptr getCublasHandle() if (status != CUBLAS_STATUS_SUCCESS) { - // Re-fetch memory info for error message (memory state might have changed) - cudaMemGetInfo(&free_mem, &total_mem); - TLLM_THROW( - "Failed to create cublas handle. " - "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. " - "Consider reducing kv_cache_config.max_tokens or free_gpu_memory_fraction.", - status, ctx, free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, - total_mem / (1024 * 1024)); + throwCublasMemoryError("cublas handle", status, ctx); } return handle; @@ -319,14 +327,7 @@ std::shared_ptr getCublasLtHandle() if (status != CUBLAS_STATUS_SUCCESS) { - // Re-fetch memory info for error message (memory state might have changed) - cudaMemGetInfo(&free_mem, &total_mem); - TLLM_THROW( - "Failed to create cublasLt handle. 
" - "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. " - "Consider reducing kv_cache_config.max_tokens or free_gpu_memory_fraction.", - status, ctx, free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, - total_mem / (1024 * 1024)); + throwCublasMemoryError("cublasLt handle", status, ctx); } return handle; From 99e87dd911e4886b4e37499225f463cafc9d1b72 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 22 Oct 2025 19:37:15 -0700 Subject: [PATCH 05/10] Refactor two helper function into one. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- cpp/tensorrt_llm/common/opUtils.cpp | 46 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index c6d92ab296a..27d7872061e 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -252,31 +252,31 @@ class PerCudaCtxPerThreadSingletonCreator std::unordered_map, hash>* mObservers; }; -// Helper function to log memory usage - returns the memory values for potential error handling -static std::pair logMemoryUsage(char const* operation, CUcontext ctx) +// Unified helper function for memory logging and error handling +// If status is CUBLAS_STATUS_SUCCESS, it just logs memory usage +// Otherwise, it logs and throws an error with memory diagnostics +static void logMemoryAndHandleError(char const* operation, CUcontext ctx, cublasStatus_t status = CUBLAS_STATUS_SUCCESS) { size_t free_mem = 0, total_mem = 0; TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem)); - TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, - free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, total_mem / (1024 * 1024)); + size_t free_mb = free_mem / (1024 * 1024); + size_t total_mb = total_mem / (1024 * 1024); + float free_percent = (total_mem > 0) ? ((float) free_mem / total_mem * 100.0f) : 0.0f; - return {free_mem, total_mem}; -} + // Always log the memory state + TLLM_LOG_DEBUG( + "%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, free_mb, free_percent, total_mb); -// Helper function to handle memory-related errors for cublas handle creation. -static void throwCublasMemoryError(char const* handleType, cublasStatus_t status, CUcontext ctx) -{ - // Re-fetch and log current memory state for debugging - auto [free_mem, total_mem] = logMemoryUsage( - (std::string("Failed to create ") + handleType + " - logging current memory state").c_str(), ctx); - - TLLM_THROW( - "Failed to create %s. " - "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. " - "Consider reducing kv_cache_config.free_gpu_memory_fraction.", - handleType, status, ctx, free_mem / (1024 * 1024), (float) free_mem / total_mem * 100.0, - total_mem / (1024 * 1024)); + // If there's an error, throw with details + if (status != CUBLAS_STATUS_SUCCESS) + { + TLLM_THROW( + "Failed to create %s. " + "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. 
" + "Consider reducing kv_cache_config.free_gpu_memory_fraction.", + operation, status, ctx, free_mb, free_percent, total_mb); + } } } // namespace @@ -287,7 +287,7 @@ std::shared_ptr getCublasHandle() []() -> auto { CUcontext ctx = getCurrentCudaCtx(); - auto [free_mem, total_mem] = logMemoryUsage("Creating cublas handle", ctx); + logMemoryAndHandleError("Creating cublas handle", ctx); auto handle = std::make_unique(); @@ -295,7 +295,7 @@ std::shared_ptr getCublasHandle() if (status != CUBLAS_STATUS_SUCCESS) { - throwCublasMemoryError("cublas handle", status, ctx); + logMemoryAndHandleError("cublas handle", ctx, status); } return handle; @@ -319,7 +319,7 @@ std::shared_ptr getCublasLtHandle() []() -> auto { CUcontext ctx = getCurrentCudaCtx(); - auto [free_mem, total_mem] = logMemoryUsage("Creating cublasLt handle", ctx); + logMemoryAndHandleError("Creating cublasLt handle", ctx); auto handle = std::make_unique(); @@ -327,7 +327,7 @@ std::shared_ptr getCublasLtHandle() if (status != CUBLAS_STATUS_SUCCESS) { - throwCublasMemoryError("cublasLt handle", status, ctx); + logMemoryAndHandleError("cublasLt handle", ctx, status); } return handle; From 0768b9cc7d847af3a9c5fba19761f18cc74db6c4 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 22 Oct 2025 20:16:57 -0700 Subject: [PATCH 06/10] Remove redundant error handling. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- cpp/tensorrt_llm/common/opUtils.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index 27d7872061e..5d5123d68df 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -290,13 +290,10 @@ std::shared_ptr getCublasHandle() logMemoryAndHandleError("Creating cublas handle", ctx); auto handle = std::make_unique(); - cublasStatus_t status = cublasCreate(handle.get()); - if (status != CUBLAS_STATUS_SUCCESS) - { - logMemoryAndHandleError("cublas handle", ctx, status); - } + // This will log memory state and throw if status != SUCCESS + logMemoryAndHandleError("cublas handle", ctx, status); return handle; }, @@ -322,13 +319,10 @@ std::shared_ptr getCublasLtHandle() logMemoryAndHandleError("Creating cublasLt handle", ctx); auto handle = std::make_unique(); - cublasStatus_t status = cublasLtCreate(handle.get()); - if (status != CUBLAS_STATUS_SUCCESS) - { - logMemoryAndHandleError("cublasLt handle", ctx, status); - } + // This will log memory state and throw if status != SUCCESS + logMemoryAndHandleError("cublasLt handle", ctx, status); return handle; }, From a276b605554cfd7495ad1be715db6e1d1722fe11 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 22 Oct 2025 20:47:46 -0700 Subject: [PATCH 07/10] Change back to two separate helper functions Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- cpp/tensorrt_llm/common/opUtils.cpp | 70 ++++++++++++++++++----------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index 5d5123d68df..7725e673308 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -252,31 +252,47 @@ class PerCudaCtxPerThreadSingletonCreator std::unordered_map, hash>* mObservers; }; -// Unified helper function for memory logging and error handling -// If status is 
-// Otherwise, it logs and throws an error with memory diagnostics
-static void logMemoryAndHandleError(char const* operation, CUcontext ctx, cublasStatus_t status = CUBLAS_STATUS_SUCCESS)
+// Structure to hold memory information
+struct MemoryInfo
+{
+    size_t free_mb;
+    size_t total_mb;
+    float free_percent;
+};
+
+// Helper function to get current memory information
+inline MemoryInfo getMemoryInfo()
 {
     size_t free_mem = 0, total_mem = 0;
     TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem));
 
-    size_t free_mb = free_mem / (1024 * 1024);
-    size_t total_mb = total_mem / (1024 * 1024);
-    float free_percent = (total_mem > 0) ? ((float) free_mem / total_mem * 100.0f) : 0.0f;
+    const size_t free_mb = free_mem / (1024 * 1024);
+    const size_t total_mb = total_mem / (1024 * 1024);
+    float const free_percent = (total_mem > 0) ? (static_cast<float>(free_mem) / total_mem * 100.0f) : 0.0f;
 
-    // Always log the memory state
-    TLLM_LOG_DEBUG(
-        "%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, free_mb, free_percent, total_mb);
+    return {free_mb, total_mb, free_percent};
+}
 
-    // If there's an error, throw with details
-    if (status != CUBLAS_STATUS_SUCCESS)
-    {
-        TLLM_THROW(
-            "Failed to create %s. "
-            "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
-            "Consider reducing kv_cache_config.free_gpu_memory_fraction.",
-            operation, status, ctx, free_mb, free_percent, total_mb);
-    }
+// Helper function to log current memory usage
+inline void logMemoryUsage(char const* operation, CUcontext ctx)
+{
+    auto const mem = getMemoryInfo();
+    TLLM_LOG_DEBUG("%s: Context=%p, Free Memory=%zu MB (%.1f%%), Total=%zu MB", operation, ctx, mem.free_mb,
+        mem.free_percent, mem.total_mb);
+}
+
+// Helper function to throw a cublas creation error with memory diagnostics
+inline void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status)
+{
+
+    logMemoryUsage(operation, ctx);
+
+    auto const mem = getMemoryInfo();
+    TLLM_THROW(
+        "Failed to create %s. "
+        "Status: %d, Context: %p, Free Memory: %zu MB (%.1f%%), Total: %zu MB. "
" + "Consider reducing kv_cache_config.free_gpu_memory_fraction.", + operation, status, ctx, mem.free_mb, mem.free_percent, mem.total_mb); } } // namespace @@ -287,13 +303,15 @@ std::shared_ptr getCublasHandle() []() -> auto { CUcontext ctx = getCurrentCudaCtx(); - logMemoryAndHandleError("Creating cublas handle", ctx); + logMemoryUsage("Creating cublas handle", ctx); auto handle = std::make_unique(); cublasStatus_t status = cublasCreate(handle.get()); - // This will log memory state and throw if status != SUCCESS - logMemoryAndHandleError("cublas handle", ctx, status); + if (status != CUBLAS_STATUS_SUCCESS) + { + throwCublasErrorWithMemInfo("cublas handle", ctx, status); + } return handle; }, @@ -316,13 +334,15 @@ std::shared_ptr getCublasLtHandle() []() -> auto { CUcontext ctx = getCurrentCudaCtx(); - logMemoryAndHandleError("Creating cublasLt handle", ctx); + logMemoryUsage("Creating cublasLt handle", ctx); auto handle = std::make_unique(); cublasStatus_t status = cublasLtCreate(handle.get()); - // This will log memory state and throw if status != SUCCESS - logMemoryAndHandleError("cublasLt handle", ctx, status); + if (status != CUBLAS_STATUS_SUCCESS) + { + throwCublasErrorWithMemInfo("cublasLt handle", ctx, status); + } return handle; }, From 119febbb43a857b510f3b7b16acd43badb921bd1 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Thu, 23 Oct 2025 00:49:21 -0700 Subject: [PATCH 08/10] Final change. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- cpp/tensorrt_llm/common/opUtils.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index 7725e673308..a909a8579ff 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -284,9 +284,6 @@ inline void logMemoryUsage(char const* operation, CUcontext ctx) // Helper function to throw inline void throwCublasErrorWithMemInfo(char const* operation, CUcontext ctx, cublasStatus_t status) { - - logMemoryUsage(operation, ctx); - auto const mem = getMemoryInfo(); TLLM_THROW( "Failed to create %s. " From 5ff3ba2163aebab913531c0aa44f3245fc910982 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Thu, 23 Oct 2025 01:27:13 -0700 Subject: [PATCH 09/10] Final minor change. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- cpp/tensorrt_llm/common/opUtils.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp index a909a8579ff..8f03291b54a 100644 --- a/cpp/tensorrt_llm/common/opUtils.cpp +++ b/cpp/tensorrt_llm/common/opUtils.cpp @@ -261,20 +261,20 @@ struct MemoryInfo }; // Helper function to get current memory information -inline MemoryInfo getMemoryInfo() +MemoryInfo getMemoryInfo() { size_t free_mem = 0, total_mem = 0; TLLM_CUDA_CHECK(cudaMemGetInfo(&free_mem, &total_mem)); - const size_t free_mb = free_mem / (1024 * 1024); - const size_t total_mb = total_mem / (1024 * 1024); + size_t const free_mb = free_mem / (1024 * 1024); + size_t const total_mb = total_mem / (1024 * 1024); float const free_percent = (total_mem > 0) ? 
From 18cfb69e9d2246aac2b91d6b7307ef9ba5f6bf16 Mon Sep 17 00:00:00 2001
From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Date: Tue, 4 Nov 2025 19:12:41 -0800
Subject: [PATCH 10/10] Change cublasStatus_t to auto.

Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 cpp/tensorrt_llm/common/opUtils.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/cpp/tensorrt_llm/common/opUtils.cpp b/cpp/tensorrt_llm/common/opUtils.cpp
index 8f03291b54a..736cd1c48d0 100644
--- a/cpp/tensorrt_llm/common/opUtils.cpp
+++ b/cpp/tensorrt_llm/common/opUtils.cpp
@@ -303,7 +303,7 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
             logMemoryUsage("Creating cublas handle", ctx);
 
             auto handle = std::make_unique<cublasHandle_t>();
-            cublasStatus_t status = cublasCreate(handle.get());
+            auto status = cublasCreate(handle.get());
 
             if (status != CUBLAS_STATUS_SUCCESS)
             {
@@ -314,7 +314,7 @@ std::shared_ptr<cublasHandle_t> getCublasHandle()
         },
         [](cublasHandle_t* handle)
         {
-            cublasStatus_t status = cublasDestroy(*handle);
+            auto status = cublasDestroy(*handle);
             if (status != CUBLAS_STATUS_SUCCESS)
             {
                 TLLM_LOG_WARNING("Failed to destroy cublas handle. Status: %d", status);
@@ -334,7 +334,7 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
             logMemoryUsage("Creating cublasLt handle", ctx);
 
             auto handle = std::make_unique<cublasLtHandle_t>();
-            cublasStatus_t status = cublasLtCreate(handle.get());
+            auto status = cublasLtCreate(handle.get());
 
             if (status != CUBLAS_STATUS_SUCCESS)
             {
@@ -345,7 +345,7 @@ std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
         },
         [](cublasLtHandle_t* handle)
         {
-            cublasStatus_t status = cublasLtDestroy(*handle);
+            auto status = cublasLtDestroy(*handle);
             if (status != CUBLAS_STATUS_SUCCESS)
             {
                 TLLM_LOG_WARNING("Failed to destroy cublasLt handle. Status: %d", status);
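
After patch 10, the series leaves getCublasHandle() and getCublasLtHandle() returning shared handles cached per CUDA context and per thread, with memory diagnostics on creation and a logged warning on destruction failure. Below is a hedged sketch of a call site; the namespace, header layout, and the surrounding function are assumptions for illustration, not code quoted from the repository.

#include <cublas_v2.h>
#include <memory>

// Assumed declaration mirroring the signature seen in the patch; the real
// header and namespace may differ.
namespace tensorrt_llm::common
{
std::shared_ptr<cublasHandle_t> getCublasHandle();
}

void runGemmOnStream(cudaStream_t stream)
{
    // One shared handle per (CUDA context, thread); creation logs free memory
    // and throws with diagnostics if cublasCreate fails.
    auto cublas = tensorrt_llm::common::getCublasHandle();
    cublasSetStream(*cublas, stream);
    // ... enqueue cublas calls on `stream`; when the last shared_ptr owner
    // releases the handle, the deleter runs cublasDestroy and logs a warning
    // on failure.
}
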