1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ build
*.so
.idea
dist
**/__pycache__/
67 changes: 56 additions & 11 deletions README.md
@@ -1,15 +1,28 @@
# torch_memory_saver
# Torch Memory Saver

Allow torch tensor memory to be released and resumed later.
A PyTorch library that allows tensor memory to be temporarily released and resumed later.

API:
During the pause:
- Physical memory is released
- Virtual address is preserved

When resumed:
- Virtual address is restored to the original one

Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.

## Examples

### Basic Example

```python
memory_saver = TorchMemorySaver()
import torch_memory_saver

memory_saver = torch_memory_saver.memory_saver

# 1. For tensors that should be pauseable, create them within `region`
with memory_saver.region():
    x = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
    pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')

# 2. After `pause`, CUDA memory is released for those tensors.
# For example, check `nvidia-smi`'s memory usage to verify.
@@ -19,11 +32,43 @@ memory_saver.pause()
memory_saver.resume()
```

Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.
### Multiple Tags Example

Please refer to https://github.com/sgl-project/sglang/issues/7009 for details.

```python
from torch_memory_saver import torch_memory_saver

# 1. Create tensors with different tags
with torch_memory_saver.region(tag="type1"):
    tensor1 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')

with torch_memory_saver.region(tag="type2"):
    tensor2 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')

# 2. Pause and resume with different tags selectively
torch_memory_saver.pause("type1")
torch_memory_saver.pause("type2")


torch_memory_saver.resume("type2")
torch_memory_saver.resume("type1")

torch_memory_saver.pause("type1")
torch_memory_saver.resume("type1")
```
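
The tag filter added in `csrc/torch_memory_saver.cpp` treats an empty tag as matching every allocation, so pausing or resuming without a tag should affect all regions at once. A minimal sketch, assuming the Python wrapper forwards a missing tag as NULL (not verified against the package's actual signature):

```python
# Hypothetical: no tag given, so both "type1" and "type2" tensors are affected.
torch_memory_saver.pause()
torch_memory_saver.resume()
```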

## Development

```bash
pip install -e .
```

A `torch_memory_saver_cpp.abi3.so` will be built under the `{your_workspace}/torch_memory_saver/` folder.

TODO:
You can use these commands for local testing:
```bash
LD_PRELOAD={your_workspace}/torch_memory_saver/torch_memory_saver_cpp.abi3.so python examples/simple.py

- [x] Implementation
- [x] Publish to pypi
- [ ] More tests and infra
- [ ] Documentation
LD_PRELOAD={your_workspace}/torch_memory_saver/torch_memory_saver_cpp.abi3.so python examples/rl_with_cuda_graph.py
```
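
`LD_PRELOAD` is what lets the dynamic linker resolve `cudaMalloc`/`cudaFree` to the overrides defined in `csrc/torch_memory_saver.cpp` ahead of the CUDA runtime's own symbols, so allocations made inside `region()` are routed through `TorchMemorySaver`.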
63 changes: 47 additions & 16 deletions csrc/torch_memory_saver.cpp
@@ -6,6 +6,7 @@
#include <dlfcn.h>
#include <unordered_map>
#include <mutex>
#include <string>

// #define TMS_DEBUG_LOG

@@ -118,32 +119,32 @@ struct _AllocationMetadata {
size_t size;
CUdevice device;
CUmemGenericAllocationHandle allocHandle;
std::string tag;
};

class TorchMemorySaver {
public:
TorchMemorySaver() {}

cudaError_t malloc(void **ptr, size_t size) {
cudaError_t malloc(void **ptr, size_t size, const std::string& tag) {
CUdevice device;
CURESULT_CHECK(cuCtxGetDevice(&device));

CUmemGenericAllocationHandle allocHandle;
CUDAUtils::cu_mem_create(&allocHandle, size, device);

CURESULT_CHECK(cuMemAddressReserve((CUdeviceptr *) ptr, size, 0, 0, 0));
CURESULT_CHECK(cuMemMap((CUdeviceptr) *ptr, size, 0, allocHandle, 0));
CUDAUtils::cu_mem_set_access(*ptr, size, device);

{
const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);
allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle});
const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
allocation_metadata_.emplace(*ptr, _AllocationMetadata{size, device, allocHandle, tag});
}

#ifdef TMS_DEBUG_LOG
std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
<< " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
<< " allocHandle=" << allocHandle
<< " allocHandle=" << allocHandle << " tag=" << tag
<< std::endl;
#endif

@@ -166,39 +167,47 @@ class TorchMemorySaver {
#ifdef TMS_DEBUG_LOG
std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
<< " ptr=" << ptr << " metadata.size=" << metadata.size
<< " metadata.allocHandle=" << metadata.allocHandle
<< " metadata.allocHandle=" << metadata.allocHandle << " tag=" << metadata.tag
<< std::endl;
#endif

return cudaSuccess;
}

void pause() {
void pause(const std::string& tag) {
const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);

for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
void *ptr = it->first;
_AllocationMetadata metadata = it->second;

if (!tag.empty() && metadata.tag != tag) {
continue;
}

CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
CURESULT_CHECK(cuMemRelease(metadata.allocHandle));

#ifdef TMS_DEBUG_LOG
std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
<< " ptr=" << ptr << " metadata.size=" << metadata.size << " metadata.allocHandle="
<< metadata.allocHandle
<< metadata.allocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
<< std::endl;
#endif
}
}

void resume() {
void resume(const std::string& tag) {
const std::lock_guard <std::mutex> lock(allocator_metadata_mutex_);

for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
void *ptr = it->first;
_AllocationMetadata &metadata = it->second;

if (!tag.empty() && metadata.tag != tag) {
continue;
}

CUmemGenericAllocationHandle newAllocHandle;
CUDAUtils::cu_mem_create(&newAllocHandle, metadata.size, metadata.device);

@@ -210,7 +219,7 @@ class TorchMemorySaver {
std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
<< " ptr=" << ptr << " metadata.size=" << metadata.size << " (old)metadata.allocHandle="
<< metadata.allocHandle
<< " (new)newAllocHandle=" << newAllocHandle
<< " (new)newAllocHandle=" << newAllocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
<< std::endl;
#endif

@@ -223,14 +232,18 @@
return instance;
}


private:
// Similar to torch's CUDACachingAllocator and CUDAPluggableAllocator
std::mutex allocator_metadata_mutex_;
std::unordered_map<void *, _AllocationMetadata> allocation_metadata_;
};


// ----------------------------------------------- region manager --------------------------------------------------

namespace RegionManager {
static thread_local bool is_interesting_region_ = false;
static thread_local std::string current_tag_ = "default";

void enter() {
#ifdef TMS_DEBUG_LOG
@@ -249,13 +262,21 @@ namespace RegionManager {
bool is_interesting_region() {
return is_interesting_region_;
}

void set_current_tag(const std::string& tag) {
current_tag_ = tag;
}

const std::string& get_current_tag() {
return current_tag_;
}
}

// ------------------------------------------------- entrypoints ------------------------------------------------

cudaError_t cudaMalloc(void **ptr, size_t size) {
if (RegionManager::is_interesting_region()) {
return TorchMemorySaver::instance().malloc(ptr, size);
return TorchMemorySaver::instance().malloc(ptr, size, RegionManager::get_current_tag());
} else {
return APIForwarder::call_real_cuda_malloc(ptr, size);
}
@@ -278,11 +299,21 @@ void tms_region_leave() {
RegionManager::leave();
}

void tms_pause() {
TorchMemorySaver::instance().pause();
void tms_set_current_tag(const char* tag) {
if (tag == nullptr) {
std::cerr << "[torch_memory_saver.cpp] FATAL: NULL tag passed to tms_set_current_tag" << std::endl;
exit(1);
}
RegionManager::set_current_tag(std::string(tag));
}

void tms_resume() {
TorchMemorySaver::instance().resume();
void tms_pause(const char* tag) {
std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
TorchMemorySaver::instance().pause(tag_str);
}

void tms_resume(const char* tag) {
std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
TorchMemorySaver::instance().resume(tag_str);
}
}
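
The `extern "C"` entrypoints above are what the Python side of the package drives once the shared library is preloaded. As a rough illustration (not the package's actual binding; the library path, context-manager shape, and default tag are assumptions), a `ctypes` wrapper over these symbols could look like this:

```python
import ctypes
from contextlib import contextmanager
from typing import Optional

# Assumption: the preloaded library exports the tms_* symbols shown above.
_lib = ctypes.CDLL("torch_memory_saver_cpp.abi3.so", mode=ctypes.RTLD_GLOBAL)
for _name in ("tms_set_current_tag", "tms_pause", "tms_resume"):
    getattr(_lib, _name).argtypes = [ctypes.c_char_p]


@contextmanager
def region(tag: str = "default"):
    # Tag subsequent allocations, then mark the region as "interesting" so
    # cudaMalloc is routed to TorchMemorySaver::malloc instead of the real API.
    _lib.tms_set_current_tag(tag.encode())
    _lib.tms_region_enter()
    try:
        yield
    finally:
        _lib.tms_region_leave()


def pause(tag: Optional[str] = None):
    # Passing NULL pauses every tagged allocation, per the empty-tag filter above.
    _lib.tms_pause(None if tag is None else tag.encode())


def resume(tag: Optional[str] = None):
    _lib.tms_resume(None if tag is None else tag.encode())
```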
30 changes: 11 additions & 19 deletions examples/cuda_graph.py
@@ -1,14 +1,14 @@
import logging
import sys
import time
import os
from typing import Callable

import torch
from torch_memory_saver import TorchMemorySaver
from torch_memory_saver import torch_memory_saver

logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)

memory_saver = TorchMemorySaver()
dummy_tensor_size = (5, 100_000_000,)


@@ -22,7 +22,7 @@ def __init__(self):
self.create_buffers(1)

def create_buffers(self, value):
with memory_saver.region():
with torch_memory_saver.region(tag="kv_cache"):
# or model weights, etc
self.kv_buffer = torch.full(dummy_tensor_size, value, dtype=torch.float32, device='cuda')
print(f'create_buffers {_ptr(self.kv_buffer)=}')
@@ -72,20 +72,14 @@ def fn():
print(f'{static_output=}')
assert static_output == 101, f'{static_output=}'

# cache.clear_buffers()

# with with_pauseable_mode():
# big_tensor = torch.zeros((2_000_000_000,), dtype=torch.uint8, device='cuda')
# print(f'{big_tensor=}')

print('torch.cuda.empty_cache()')
torch.cuda.empty_cache()

print('sleep...')
time.sleep(3)

print('call memory_saver.pause')
memory_saver.pause()
print('call memory_saver.pause("kv_cache")')
torch_memory_saver.pause("kv_cache")

print('sleep...')
time.sleep(3)
@@ -100,17 +94,12 @@ def fn():
print('sleep...')
time.sleep(3)

# this should fail
# print(f'{cache.kv_buffer=}')

print('call memory_saver.resume')
memory_saver.resume()
print('call memory_saver.resume("kv_cache")')
torch_memory_saver.resume("kv_cache")

dummy = torch.zeros((3,), device='cuda')
print(f'{_ptr(dummy)=}')

# cache.create_buffers(2)

cache.kv_buffer[...] = 2

print('replay #2')
@@ -122,9 +111,12 @@ def fn():
print('sleep...')
time.sleep(3)

# print(f'{big_tensor=}')
print(f'{dummy=}')

# exit this process gracefully, bypassing CUDA cleanup
# See https://github.com/fzyzcjy/torch_memory_saver/pull/18 for more details
os._exit(0)


if __name__ == '__main__':
run()