[Hardware] Replace memory related torch.cuda APIs #37031
hmellor merged 7 commits into vllm-project:main
Conversation
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
Code Review
This pull request continues the effort to abstract away hardware-specific APIs by replacing several torch.cuda memory management functions with their torch.accelerator counterparts. The changes are consistently applied across various parts of the codebase, including benchmarks, tests, and utility modules. A corresponding update to the pre-commit hook ensures these new standards are maintained. My review identifies one area where a CUDA-specific conditional check remains, limiting the hardware-agnostic benefit of the refactoring on other platforms like ROCm. I've provided a suggestion to address this.
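A minimal sketch of the replacement pattern described above (assuming a PyTorch build where `torch.accelerator` exposes these memory counterparts, which this PR relies on; the helper name is illustrative, not from the diff):

```python
import torch

def measure_peak_memory(fn) -> int:
    # Device-agnostic peak-memory measurement; before this PR the
    # equivalent code paths called torch.cuda.* directly.
    torch.accelerator.empty_cache()
    torch.accelerator.reset_peak_memory_stats()
    fn()
    torch.accelerator.synchronize()
    # Dispatches to whichever accelerator backend is active
    # (CUDA, ROCm, XPU, ...), with the same semantics as
    # torch.cuda.max_memory_allocated() on CUDA.
    return torch.accelerator.max_memory_allocated()
```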
```diff
 # to have test coverage on peak memory for online quantization.
 if current_platform.is_cuda():
-    peak_memory = torch.cuda.max_memory_allocated()
+    peak_memory = torch.accelerator.max_memory_allocated()
```
The change to `torch.accelerator.max_memory_allocated()` is correct, but it sits inside an `if current_platform.is_cuda():` block on line 66. Since `torch.accelerator` is designed to be device-agnostic (working on CUDA, ROCm, etc.), this condition is now too restrictive and will prevent peak-memory logging on other GPU platforms such as ROCm.
To ensure this logging works on all supported GPU-like devices, consider broadening the condition. For example:

```python
if current_platform.is_cuda_alike():
```
I think we should use `if not current_platform.is_cpu()` here.
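A minimal sketch of that reply's suggestion applied to the hunk above (surrounding context abbreviated):

```python
# Log peak memory on any non-CPU platform; torch.accelerator dispatches
# to whichever accelerator backend is active (CUDA, ROCm, XPU, ...).
if not current_platform.is_cpu():
    peak_memory = torch.accelerator.max_memory_allocated()
```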
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
### What this PR does / why we need it?
1. Fix "TypeError: get_attn_backend() remove variable": [Refactor `check_and_update_config`](vllm-project/vllm#35122)
2. Fix [Rename `compile_ranges_split_points` to `compile_ranges_endpoints`](vllm-project/vllm#36027)
3. Fix "RuntimeError: device_allocator not a DeviceAllocator": [Replace memory related torch.cuda APIs](vllm-project/vllm#37031)
4. Fix [Support multiple KV groups in OffloadingSpec](vllm-project/vllm#36610): removed self.offloaded_block_size and changed self.gpu_block_size from a scalar to a tuple of per-group block sizes, adding block_size_factor.
5. Fix [Consolidate SupportsEagle](vllm-project/vllm#36063): renamed get_eagle3_aux_hidden_state_layers() to get_eagle3_default_aux_hidden_state_layers() and added a supports_eagle3() guard before calling it (sketched below).

### Does this PR introduce _any_ user-facing change?
NA

### How was this patch tested?
E2E
- vLLM version: v0.17.0
- vLLM main: vllm-project/vllm@8a68046

---------
Signed-off-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: Claude Code <noreply@anthropic.com>
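For item 5, a minimal sketch of the guard pattern it describes (the function names come from the message above; the import path and call site are assumptions for illustration):

```python
# Assumed import path; the actual location may differ.
from vllm.model_executor.models.interfaces import supports_eagle3

def eagle3_aux_layers(model) -> tuple[int, ...]:
    # Only query the renamed accessor when the model actually implements
    # EAGLE3 support; otherwise report no aux hidden-state layers.
    if supports_eagle3(model):
        return model.get_eagle3_default_aux_hidden_state_layers()
    return ()
```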
Purpose
part of #30679
This PR replaces the APIs below with their torch.accelerator counterparts:
Test Plan
CI
Test Result