
Conversation

hebiao064 (Contributor) commented Jun 11, 2025

Motivation

Closes: sgl-project/sglang#7009

In RL ecosystems that use a colocated design, such as verl, we frequently need to offload the training model and load the serving model & KV cache.

Background

  • Currently, SGLang uses torch_memory_saver to pause and resume GPU memory.
  • torch_memory_saver is an open-source library that provides an easy-to-use API to hook cudaMalloc and cudaFree, so that virtual addresses stay consistent across pause and resume, which is critical for CUDA Graphs to keep working (see the usage sketch below).
  • CUDA Graphs are critical for making SGLang fast in the decoding phase.
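
For concreteness, here is a minimal sketch of how the tag-based API proposed in this PR is meant to be used. Treat the exact signatures and tag names as illustrative; they may differ slightly from the final implementation.

```python
import torch
from torch_memory_saver import torch_memory_saver

# Allocate tensors inside tagged regions; the hooked cudaMalloc records
# which virtual-address ranges belong to which tag.
with torch_memory_saver.region(tag="weights"):
    weights = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")

with torch_memory_saver.region(tag="kv_cache"):
    kv_cache = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")

# Release physical memory per tag; virtual addresses stay valid, so any
# captured CUDA Graph keeps working after resume.
torch_memory_saver.pause("kv_cache")
torch_memory_saver.pause("weights")

# Later, bring the groups back independently (e.g., weights first, KV cache later).
torch_memory_saver.resume("weights")
torch_memory_saver.resume("kv_cache")
```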

Here is the current behavior of VERL + SGLang

Image

  1. During training, the training model and optimizer state live in GPU memory. Once training is done, we offload the optimizer state to CPU and keep the model weights on the GPU, since they are needed for the weight update.
  2. During the weight update, we wake the SGLang engine, so the paused memory for model weights and KV cache comes back. We then update the serving model from the training model on the fly via the update_weights_in_tensor API.
  3. After the model is updated, we delete the training model from GPU memory (sketched below).
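
In rough pseudocode, the current single-stage flow looks like the following. Method and variable names follow the description above and are illustrative; `engine`, `training_model`, and `named_tensors` are hypothetical handles.

```python
import torch

def sync_weights_single_stage(engine, training_model, named_tensors):
    """Current flow: wake weights + KV cache together, then free the trainer copy."""
    # Update Weight: everything comes back at once, while the training
    # weights are still resident -- this is the peak-memory moment.
    engine.resume_memory_occupation()
    engine.update_weights_in_tensor(named_tensors)

    # Only after the sync can we drop our reference to the training copy.
    del training_model
    torch.cuda.empty_cache()
```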

The design above has worked well so far; however, it wastes a large chunk of GPU memory during rollout, which causes a few issues we have already observed:

  • Small KV Cache: We have to use a relatively low mem-fraction ratio (e.g., 0.6), so the KV cache holds fewer tokens. With fewer tokens available, we hit RuntimeError: Prefill out of memory. Try to lower your batch size. when prefilling a large number of requests.
  • Out of Memory: With a mem-fraction ratio of 0.8, running RL for a 32B model on 8 H100s OOMs during the weight update.

Proposal

Image

  1. During training, everything stays the same.
  2. During Update Weight Stage 1, we wake only SGLang's model weights and then update the weights.
  3. During Update Weight Stage 2, we delete the training model weights from GPU memory.
  4. Finally, we wake SGLang's KV cache (see the sketch below).
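
A sketch of the proposed two-stage counterpart, assuming the SGLang-side resume call grows a tag argument analogous to the torch_memory_saver sketch above (names remain illustrative):

```python
import torch

def sync_weights_two_stage(engine, training_model, named_tensors):
    """Proposed flow: wake weights first, free the trainer copy, then wake KV cache."""
    # Stage 1: wake only the serving model weights and sync them.
    engine.resume_memory_occupation(tags=["weights"])
    engine.update_weights_in_tensor(named_tensors)

    # Stage 2: drop our reference to the training weights and release the memory...
    del training_model
    torch.cuda.empty_cache()

    # ...so the KV cache can be resumed into the space that was just freed.
    engine.resume_memory_occupation(tags=["kv_cache"])
```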

Image

Benefit

With this feature, we can train larger models on the same GPUs, and we can make training and rollout more efficient because we can allocate a larger KV cache.

Execution Plan: Keep using a singleton and provide tag-based pause/resume

@hebiao064 hebiao064 changed the title from "[Alternative of Multi Instance Solution] Singleton with tag based resume" to "Multi-Stage Awake: Support tag-based Resume and Pause" on Jun 15, 2025
@fzyzcjy fzyzcjy mentioned this pull request Jun 15, 2025
fzyzcjy (Owner) commented Jun 15, 2025

btw I see CUDA graph in your figure, thus #21

fzyzcjy (Owner) left a review comment:

Good job! Some nits

fzyzcjy (Owner) left a review comment:

LGTM, only a bit of nits!

fzyzcjy (Owner) left a review comment:

Only some tiny nits (if you have time, maybe spend 3 minutes to change them and I am ready to merge; if in a hurry, I am also OK with the current code)

@fzyzcjy fzyzcjy merged commit f515992 into fzyzcjy:master Jun 17, 2025
zhaochenyang20 added a commit to volcengine/verl that referenced this pull request Jun 23, 2025
Co-authored with: MrAta ([email protected])

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

### Motivation

In RL ecosystems that use a colocated design, such as
[verl](https://github.com/volcengine/verl/tree/main), we frequently need to
offload the training model and load the serving model & KV cache.


#### Background
- Currently, SGLang uses
[torch_memory_saver](https://github.com/fzyzcjy/torch_memory_saver) to
pause and resume GPU memory.
- [torch_memory_saver](https://github.com/fzyzcjy/torch_memory_saver) is
an open-source library that provides an easy-to-use API to hook **cudaMalloc**
and **cudaFree**, so that virtual addresses stay consistent across pause and
resume, which is critical for CUDA Graphs to keep working.
- CUDA Graphs are critical for making SGLang fast in the decoding phase.


#### Here is the current behavior of VERL + SGLang


![Image](https://github.com/user-attachments/assets/e87e7dd6-f223-4de6-8f07-915eb2030ea8)

1. During training, the training model and optimizer state live in GPU
memory. Once training is done, we offload the optimizer state to CPU and
keep the model weights on the GPU, since they are needed for the weight
update.
2. During the weight update, we wake the SGLang engine, so the paused
memory for model weights and KV cache comes back. We then update the
serving model from the training model on the fly via the
`update_weights_in_tensor` API.
3. After the model is updated, we delete the training model from GPU
memory.


The design above has worked well so far; however, it wastes a large chunk
of GPU memory during rollout, which causes a few issues we have already
observed:
- **Small KV Cache**: We have to use a relatively low mem-fraction ratio
(e.g., 0.6), so the KV cache holds fewer tokens. With fewer tokens
available, we hit `RuntimeError: Prefill out of memory. Try to lower your
batch size.` when prefilling a large number of requests.
- **Out of Memory**: With a mem-fraction ratio of 0.8, running RL for a 32B
model on 8 H100s OOMs during the weight update.


#### Challenge
- `torch_memory_saver` currently only supports a singleton, so SGLang
pauses and resumes the KV cache and the weights together; they are treated
as one group of memory controlled by the singleton `torch_memory_saver`
instance.

#### Proposal

![Image](https://github.com/user-attachments/assets/7fda9638-0dc2-4c14-bc64-cd20616f350f)

1. During training, everything stays the same.
2. During Update Weight Stage 1, we wake only SGLang's model weights and
then update the weights.
3. During Update Weight Stage 2, we delete the training model weights from
GPU memory.
4. Finally, we wake SGLang's KV cache.



![Image](https://github.com/user-attachments/assets/f3dab327-dc2e-4ed8-88d7-15e383f77d25)


### Benefit
With this feature, we can train larger models on the same GPUs, and we can
make training and rollout more efficient because we can allocate a larger
KV cache.

### Solution: Keep using a singleton and provide tag-based pause/resume

- [x] Support tag-based resume/pause:
fzyzcjy/torch_memory_saver#20
- [x] Support Multiple Stage Awake in SGLang:
sgl-project/sglang#7099
- [ ] Support Multiple Stage Awake in verl:
#1911

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

### API

> Demonstrate how the API changes if any.

### Usage Example

> Provide usage example(s) for easier usage.

```python
# Add code snippet or script demonstrating how to use this 
```

### Test

![Screenshot 2025-06-19 at 12 16
19 PM](https://github.com/user-attachments/assets/a95dd57e-43e1-4f28-8a84-003ec5c043fc)
![Screenshot 2025-06-19 at 12 13
14 PM](https://github.com/user-attachments/assets/f1f4a8a8-1845-4fad-9424-5526d4154dd0)


### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron,
both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang,
both, or none]

### Checklist Before Submitting

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the
[docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] New CI unit test(s) are added to cover the code path.
- [ ] Rely on existing unit tests on CI that covers the code path.

---------

Co-authored-by: Chayenne <[email protected]>
Sirius-L1 pushed a commit to Sirius-L1/verl that referenced this pull request Jun 24, 2025
Tyizhanshen pushed a commit to HyperdriveHustle/verl that referenced this pull request Jul 1, 2025
fzyzcjy (Owner) commented Jul 8, 2025

I realize that we may need one separate memory pool per tag in this mode. This is because, suppose:

  • Under tag a, we allocate a 1 KB tensor x: torch actually cudaMallocs a 2 MB block M and slices it to satisfy x, so we tag M with a.
  • Under tag b, we allocate a 1 KB tensor y: torch reuses block M, so tensor y effectively ends up under tag a.

I will also make a pluggable-allocator version, which may not have this issue.
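
A small, purely conceptual illustration of that scenario, reusing the tagged-region API from the sketch in the PR description (the ~2 MB granularity is just the caching allocator's usual block size for small allocations):

```python
import torch
from torch_memory_saver import torch_memory_saver

with torch_memory_saver.region(tag="a"):
    # 1 KB request -> the caching allocator issues one real cudaMalloc for a
    # ~2 MB block M and slices it; the hook sees that cudaMalloc and tags M "a".
    x = torch.empty(1024, dtype=torch.uint8, device="cuda")

with torch_memory_saver.region(tag="b"):
    # Another 1 KB request fits in the leftover space of block M, so no
    # cudaMalloc fires and the hook never gets a chance to record tag "b":
    # pausing tag "b" does nothing for y, while pausing tag "a" would take
    # y's memory away with it.
    y = torch.empty(1024, dtype=torch.uint8, device="cuda")
```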

hebiao064 (Contributor, Author) commented Jul 8, 2025

> I realize that we may need one separate memory pool per tag in this mode. This is because, suppose:
>
>   • Under tag a, we allocate a 1 KB tensor x: torch actually cudaMallocs a 2 MB block M and slices it to satisfy x, so we tag M with a.
>   • Under tag b, we allocate a 1 KB tensor y: torch reuses block M, so tensor y effectively ends up under tag a.
>
> I will also make a pluggable-allocator version, which may not have this issue.

Ah, I think I understand the mechanism you mentioned in your comment (I read this blog a few days ago: https://zhuanlan.zhihu.com/p/493646010).

But I wonder how you realized this problem, and what kind of issues are you foreseeing right now? Will it lead to the illegal-memory issue we recently hit?

Context: the illegal-memory issue was mitigated by these two PRs.

fzyzcjy (Owner) commented Jul 8, 2025

> But I wonder how you realized this problem?

Because today I am refactoring and adding features to torch_memory_saver, and this came to mind.

> And what kind of issues are you foreseeing right now?

I expect issues like tensors being wrongly tagged.

hebiao064 (Contributor, Author) replied:

> But I wonder how you realized this problem?
>
> Because today I am refactoring and adding features to torch_memory_saver, and this came to mind.
>
> And what kind of issues are you foreseeing right now?
>
> I expect issues like tensors being wrongly tagged.

I'm happy to fix the pool-arrangement issue, but to be honest I'm not quite clear on the next step.

Should I create a separate pool when we init those tensors (e.g., the KV cache and the weights) and make sure we malloc memory from the pool passed in?

fzyzcjy (Owner) commented Jul 9, 2025

My personal guess is to change

```python
with torch.cuda.use_mem_pool(self._mem_pool):
```

to something like

```python
# init
self._mem_pools = defaultdict(lambda: torch.cuda.MemPool(...))

# region
mem_pool = self._mem_pools[(tag, enable_cpu_backup)]
# ... others unchanged
```
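
Expanding that guess into a minimal, self-contained sketch. The class and method names here are hypothetical; the assumed PyTorch primitives are `torch.cuda.MemPool` and `torch.cuda.use_mem_pool`, which require a reasonably recent PyTorch:

```python
from collections import defaultdict
from contextlib import contextmanager

import torch


class TagAwarePools:
    """One CUDA memory pool per (tag, enable_cpu_backup) pair, so allocations
    made under different tags can never be carved out of the same cached block."""

    def __init__(self):
        # Lazily create a dedicated pool the first time a key is seen.
        self._mem_pools = defaultdict(torch.cuda.MemPool)

    @contextmanager
    def region(self, tag: str, enable_cpu_backup: bool = False):
        # Everything allocated inside this block is served from the tag's own
        # pool, so the hook's tagging matches where tensors actually live.
        with torch.cuda.use_mem_pool(self._mem_pools[(tag, enable_cpu_backup)]):
            yield
```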

oseyosey pushed a commit to oseyosey/verl that referenced this pull request Jul 28, 2025
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025