Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions python/sglang/srt/managers/cache_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,26 @@
logger = logging.getLogger(__name__)


class LayerDoneCounter:
    """Thread-safe counter tracking how many KV-cache layers have finished
    transferring to device memory.

    The loader thread calls ``reset()`` before starting a layer-by-layer
    transfer and ``increment()`` after each layer completes; consumers call
    ``wait_until(layer_idx)`` to block until the layer they need is resident.
    The counter starts at ``num_layers`` so waiters do not block before any
    transfer has been issued.
    """

    def __init__(self, num_layers):
        # Start "fully loaded" so wait_until() passes immediately until the
        # first reset() marks a transfer in progress.
        self.counter = num_layers
        self.condition = threading.Condition()

    def increment(self):
        """Mark one more layer as done and wake all waiting threads."""
        with self.condition:
            self.counter += 1
            self.condition.notify_all()

    def wait_until(self, threshold):
        """Block until the counter is strictly greater than ``threshold``."""
        with self.condition:
            # wait_for re-checks the predicate on every wakeup, which guards
            # against spurious wakeups (equivalent to the manual while/wait
            # loop, but the idiomatic stdlib form).
            self.condition.wait_for(lambda: self.counter > threshold)

    def reset(self):
        """Restart counting for a new layer-by-layer transfer."""
        with self.condition:
            self.counter = 0


class CacheOperation:

counter = 0
Expand Down Expand Up @@ -132,13 +152,18 @@ def __init__(
self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host
self.write_policy = write_policy

self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)

if write_policy not in [
"write_through",
"write_through_selective",
Expand All @@ -165,7 +190,7 @@ def __init__(
target=self.write_thread_func_buffer, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.write_thread.start()
self.load_thread.start()
Expand All @@ -186,7 +211,7 @@ def reset(self):
target=self.write_thread_func_buffer, daemon=True
)
self.load_thread = threading.Thread(
target=self.load_thread_func_buffer, daemon=True
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.stop_event.clear()
self.write_thread.start()
Expand Down Expand Up @@ -273,6 +298,42 @@ def load_thread_func_direct(self):
except Exception as e:
logger.error(e)

def load_thread_func_layer_by_layer(self):
    """
    Load KV caches from host memory to device memory layer by layer.

    Runs as a daemon thread: it sleeps on ``load_cache_event``, merges all
    queued load operations into one batch, then copies the batch one layer
    at a time, bumping ``layer_done_counter`` after each layer so consumers
    can start using early layers before the full transfer finishes.
    """
    # All copies are issued on the dedicated load stream so they can overlap
    # with compute running on the default stream.
    with torch.cuda.stream(self.load_stream):
        while not self.stop_event.is_set():
            # The 1s timeout lets the loop periodically re-check stop_event
            # even when no load is ever signaled.
            self.load_cache_event.wait(timeout=1)
            if not self.load_cache_event.is_set():
                continue
            self.load_cache_event.clear()

            # Drain every pending operation and merge them into one batched
            # transfer. This thread is the only consumer of load_queue, so
            # the qsize()-based drain is not racing other getters.
            batch_operation = None
            while self.load_queue.qsize() > 0:
                op = self.load_queue.get(block=True)
                if batch_operation is None:
                    batch_operation = op
                else:
                    batch_operation.merge(op)
            if batch_operation is None:
                continue

            # Transfer one layer at a time; increment the counter after each
            # layer so threads blocked in layer_done_counter.wait_until(i)
            # wake as soon as layer i is resident on device.
            self.layer_done_counter.reset()
            for i in range(self.mem_pool_host.layer_num):
                flat_data = self.mem_pool_host.get_flat_data_by_layer(
                    batch_operation.host_indices, i
                )
                self.mem_pool_device.transfer_per_layer(
                    batch_operation.device_indices, flat_data, i
                )
                self.layer_done_counter.increment()

            # Release the host-side I/O reservation, then ack each node so
            # the cache controller can finalize the load.
            self.mem_pool_host.complete_io(batch_operation.host_indices)
            for node_id in batch_operation.node_ids:
                # NOTE(review): node_id 0 appears to be a sentinel meaning
                # "no ack required" — confirm against CacheOperation.
                if node_id != 0:
                    self.ack_load_queue.put(node_id)

def write_aux_func(self, no_wait=False):
"""
Auxiliary function to prepare the buffer for write operations.
Expand Down
20 changes: 16 additions & 4 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def __init__(
# The relative logprob_start_len in an extend batch
self.extend_logprob_start_len = 0
self.last_node = None
self.last_node_global = None

# Whether or not if it is chunked. It increments whenever
# it is chunked, and decrement whenever chunked request is
Expand Down Expand Up @@ -389,13 +390,24 @@ def finished(self) -> bool:
# Whether request reached finished condition
return self.finished_reason is not None

def init_next_round_input(
    self,
    tree_cache: Optional[BasePrefixCache] = None,
    enable_hierarchical_cache: bool = False,
):
    """Prepare this request for the next prefill round.

    Rebuilds ``fill_ids`` from the original input plus generated output,
    re-matches the radix-tree prefix, and recomputes ``extend_input_len``
    (the number of tokens that still need a forward pass).

    Args:
        tree_cache: Prefix cache to match against. ``None`` means the prefix
            was not computed with a tree cache, and the existing
            ``prefix_indices`` are kept as-is.
        enable_hierarchical_cache: When True, also match against evicted
            (host-resident) nodes and record the deepest such node in
            ``last_node_global`` so it can be loaded back later.
    """
    self.fill_ids = self.origin_input_ids + self.output_ids
    if tree_cache is not None:
        if enable_hierarchical_cache:
            # include_evicted returns a third value: the deepest matching
            # node even if its KV cache currently lives only on the host.
            self.prefix_indices, self.last_node, self.last_node_global = (
                tree_cache.match_prefix(
                    key=self.adjust_max_prefix_ids(), include_evicted=True
                )
            )
        else:
            self.prefix_indices, self.last_node = tree_cache.match_prefix(
                rid=self.rid, key=self.adjust_max_prefix_ids()
            )
    self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

def adjust_max_prefix_ids(self):
Expand Down
35 changes: 30 additions & 5 deletions python/sglang/srt/managers/schedule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,15 @@ class CacheAgnosticPolicy(Enum):
class SchedulePolicy:
Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

def __init__(self, policy: str, tree_cache: BasePrefixCache):
def __init__(
self,
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool = False,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
self.enable_hierarchical_cache = enable_hierarchical_cache

# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
Expand Down Expand Up @@ -149,9 +155,14 @@ def _compute_prefix_matches(
prefix_ids = r.adjust_max_prefix_ids()

# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)
if self.enable_hierarchical_cache:
r.prefix_indices, r.last_node, r.last_node_global = (
self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
)
else:
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=prefix_ids
)

# NOTE(sang): This logic is for in-batch prefix caching;
# If there are more than 1 request that have small matching prefix from
Expand Down Expand Up @@ -428,7 +439,9 @@ def add_req_state(r, insert_sort=False):

return self.budget_state()

def add_one_req(self, req: Req, has_chunked_req: bool):
def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
return self.add_one_req_ignore_eos(req, has_chunked_req)

Expand All @@ -448,6 +461,18 @@ def add_one_req(self, req: Req, has_chunked_req: bool):
if total_tokens > self.rem_total_tokens:
return AddReqResult.NO_TOKEN

if (
enable_hierarchical_cache
and req.last_node_global is not None
and req.last_node_global.evicted
):
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node_global, req.prefix_indices
)
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)

if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
# Non-chunked prefill
self.can_run_list.append(req)
Expand Down
47 changes: 19 additions & 28 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,10 @@ def __init__(
f"context_len={self.model_config.context_len}"
)

# Init memory pool and cache
self.init_memory_pool_and_cache()

# Init running status
self.waiting_queue: List[Req] = []
self.staging_reqs = {}
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
# The current forward batch
Expand Down Expand Up @@ -308,7 +306,9 @@ def __init__(
self.grammar_backend = None

# Init schedule policy and new token estimation
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
)
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
Expand Down Expand Up @@ -431,6 +431,7 @@ def init_memory_pool_and_cache(self):
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
)
else:
self.tree_cache = RadixCache(
Expand Down Expand Up @@ -1005,6 +1006,11 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.batch_is_full = True
return None

if self.enable_hierarchical_cache:
# check for completion of hierarchical cache activities to release memory
self.tree_cache.writing_check()
self.tree_cache.loading_check()

# Get priority queue
prefix_computed = self.policy.calc_priority(self.waiting_queue)

Expand Down Expand Up @@ -1048,32 +1054,14 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.batch_is_full = True
break

req.init_next_round_input(None if prefix_computed else self.tree_cache)
req.init_next_round_input(
None if prefix_computed else self.tree_cache,
self.enable_hierarchical_cache,
)

if self.enable_hierarchical_cache and req.last_node is not None:
if req.last_node.evicted:
# loading KV cache for the request
req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
req.last_node,
req.prefix_indices,
adder.rem_total_tokens,
)
if req.last_node.loading:
# to prevent frequent cache invalidation
if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
self.tree_cache.inc_lock_ref(req.last_node)
self.staging_reqs[req.rid] = req.last_node
continue
elif req.last_node.loading:
if not self.tree_cache.loading_complete(req.last_node):
continue

if req.rid in self.staging_reqs:
self.tree_cache.dec_lock_ref(self.staging_reqs[req.rid])
del self.staging_reqs[req.rid]

res = adder.add_one_req(req, self.chunked_req)
res = adder.add_one_req(
req, self.chunked_req, self.enable_hierarchical_cache
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
Expand All @@ -1094,6 +1082,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
x for x in self.waiting_queue if x not in set(can_run_list)
]

if self.enable_hierarchical_cache:
self.tree_cache.read_to_load_cache()

if adder.new_chunked_req is not None:
assert self.chunked_req is None
self.chunked_req = adder.new_chunked_req
Expand Down
Loading
Loading