Add simple CPU offloading support. #2081

Merged 7 commits on Nov 23, 2024
5 changes: 4 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -61,6 +61,7 @@
    is_hip,
    monkey_patch_vllm_model_config,
    monkey_patch_vllm_p2p_access_check,
    set_cpu_offload_max_bytes,
)

logger = logging.getLogger(__name__)
@@ -145,7 +146,9 @@ def __init__(
            }
        )

        # Init componnets
        set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

        # Init components
        min_per_gpu_memory = self.init_torch_distributed()
        self.sampler = Sampler()
        self.load_model()
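Note on the runner change above: the offload budget is converted from GiB to bytes and registered before the model is loaded, so the per-parameter accounting in maybe_offload_to_cpu (added to utils.py further down) can consume it as layers are constructed. A minimal sketch of the same ordering, not part of this diff; the 4 GiB value is purely illustrative:

from sglang.srt.utils import set_cpu_offload_max_bytes

# Must run before any decoder layers are built; make_layers() stops
# offloading parameters once this many bytes already live on the CPU.
set_cpu_offload_max_bytes(int(4 * 1024**3))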
15 changes: 10 additions & 5 deletions python/sglang/srt/models/gemma2.py
@@ -38,6 +38,7 @@
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers


# Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -267,11 +268,15 @@ def __init__(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
                for layer_id in range(config.num_hidden_layers)
            ]
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma2DecoderLayer(
                layer_id=idx,
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
            ),
            prefix="",
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

15 changes: 8 additions & 7 deletions python/sglang/srt/models/llama.py
@@ -43,6 +43,7 @@
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers


class LlamaMLP(nn.Module):
@@ -255,14 +256,14 @@ def __init__(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
                )
                for i in range(config.num_hidden_layers)
            ]
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: LlamaDecoderLayer(
                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
            ),
            prefix="model.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
13 changes: 8 additions & 5 deletions python/sglang/srt/models/olmo.py
@@ -38,6 +38,7 @@
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers


class OlmoAttention(nn.Module):
@@ -220,11 +221,13 @@ def __init__(
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.layers = nn.ModuleList(
            [
                OlmoDecoderLayer(config, layer_id, quant_config)
                for layer_id in range(config.num_hidden_layers)
            ]
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: OlmoDecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
            ),
        )
        self.norm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
13 changes: 8 additions & 5 deletions python/sglang/srt/models/olmoe.py
@@ -48,6 +48,7 @@
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers


class OlmoeMoE(nn.Module):
@@ -261,11 +262,13 @@ def __init__(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
                for layer_id in range(config.num_hidden_layers)
            ]
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: OlmoeDecoderLayer(
                config=config,
                quant_config=quant_config,
                layer_id=idx,
            ),
        )
        self.norm = RMSNorm(config.hidden_size, eps=1e-5)

13 changes: 8 additions & 5 deletions python/sglang/srt/models/qwen2.py
@@ -40,6 +40,7 @@
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import make_layers

Qwen2Config = None

@@ -230,11 +231,13 @@ def __init__(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                Qwen2DecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Qwen2DecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
            ),
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

8 changes: 8 additions & 0 deletions python/sglang/srt/server_args.py
@@ -62,6 +62,7 @@ class ServerArgs:
    max_prefill_tokens: int = 16384
    schedule_policy: str = "lpm"
    schedule_conservativeness: float = 1.0
    cpu_offload_gb: int = 0

    # Other runtime options
    tp_size: int = 1
@@ -373,6 +374,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
        )

        parser.add_argument(
            "--cpu-offload-gb",
            type=int,
            default=ServerArgs.cpu_offload_gb,
            help="How many GBs of RAM to reserve for CPU offloading",
        )

        # Other runtime options
        parser.add_argument(
            "--tensor-parallel-size",
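The same option is exposed as a keyword argument on the offline engine, which is how the test added at the bottom of this PR exercises it. A minimal usage sketch, not part of this diff; the model path is a hypothetical placeholder:

import sglang as sgl

# cpu_offload_gb=3 reserves roughly 3 GiB of host RAM for offloaded layer weights;
# the default of 0 disables offloading entirely.
engine = sgl.Engine(model_path="Qwen/Qwen2-1.5B-Instruct", cpu_offload_gb=3)
print(engine.generate("Today is a sunny day and I like", {"temperature": 0, "max_new_tokens": 8})["text"])
engine.shutdown()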
91 changes: 90 additions & 1 deletion python/sglang/srt/utils.py
@@ -31,7 +31,7 @@
import warnings
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union

import numpy as np
import psutil
@@ -44,6 +44,7 @@
from packaging import version as pkg_version
from starlette.routing import Mount
from torch import nn
from torch.func import functional_call
from torch.profiler import ProfilerActivity, profile, record_function
from triton.runtime.cache import (
    FileCacheManager,
@@ -190,6 +191,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
    return free_gpu_memory / (1 << 30)


def is_pin_memory_available() -> bool:
    return torch.cuda.is_available()


_CPU_OFFLOAD_BYTES = 0
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    _CPU_OFFLOAD_BYTES = 0
    _CPU_OFFLOAD_MAX_BYTES = max_bytes


def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    device = next(module.parameters()).device

    if device == torch.device("cpu"):
        return module

    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()
    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    offloaded_parameters = False
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty_strided(
            size=p.data.size(),
            stride=p.data.stride(),
            dtype=p.data.dtype,
            layout=p.data.layout,
            device="cpu",
            pin_memory=pin_memory,
        )
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
        offloaded_parameters = True

    if offloaded_parameters:
        original_forward = module.forward

        def forward(*args, **kwargs):
            module.forward = original_forward
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in module.state_dict().items()
            }
            output = functional_call(module, device_state, args=args, kwargs=kwargs)
            module.forward = forward
            return output

        module.forward = forward

    return module


class LayerFn(Protocol):

    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...


def make_layers(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str = "",
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function"""
    modules = torch.nn.ModuleList(
        [
            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
            for idx in range(num_hidden_layers)
        ]
    )
    return modules


def set_random_seed(seed: int) -> None:
    """Set the random seed for all libraries."""
    random.seed(seed)
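To make the forward-time behavior of the wrapper above concrete: an offloaded module keeps its parameters in (optionally pinned) CPU memory, and each call copies a device-resident state dict in and runs the module through torch.func.functional_call. A self-contained toy sketch, not sglang code; the Linear layer and tensor shapes are illustrative:

import torch
from torch.func import functional_call

layer = torch.nn.Linear(8, 8)  # stand-in for one offloaded decoder layer
for p in layer.parameters():
    # Park the weights on the CPU, as maybe_offload_to_cpu does per parameter.
    p.data = p.data.to("cpu").pin_memory() if torch.cuda.is_available() else p.data.to("cpu")

def forward_with_reload(x):
    # Copy the weights to the compute device just for this call, mirroring
    # the wrapped forward() installed by maybe_offload_to_cpu.
    device_state = {k: v.to(x.device, non_blocking=True) for k, v in layer.state_dict().items()}
    return functional_call(layer, device_state, (x,))

device = "cuda" if torch.cuda.is_available() else "cpu"
print(forward_with_reload(torch.randn(2, 8, device=device)).shape)  # torch.Size([2, 8])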
30 changes: 30 additions & 0 deletions test/srt/test_srt_engine.py
@@ -160,6 +160,36 @@ def test_7_engine_offline_throughput(self):
        result = throughput_test(server_args=server_args, bench_args=bench_args)
        self.assertGreater(result["total_throughput"], 3500)

    def test_8_engine_cpu_offload(self):
        prompt = "Today is a sunny day and I like"
        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

        sampling_params = {"temperature": 0, "max_new_tokens": 8}

        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            max_total_tokens=128,
        )
        out1 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        engine = sgl.Engine(
            model_path=model_path,
            random_seed=42,
            max_total_tokens=128,
            cpu_offload_gb=3,
        )
        out2 = engine.generate(prompt, sampling_params)["text"]
        engine.shutdown()

        print("==== Answer 1 ====")
        print(out1)

        print("==== Answer 2 ====")
        print(out2)
        self.assertEqual(out1, out2)


if __name__ == "__main__":
    unittest.main()