(WIP) Async LoRA prefetch - add scheduler logic for LoRA prefetch #13828
glenliu21 wants to merge 1 commit into sgl-project:main
Conversation
Summary of Changes

Hello @glenliu21, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request lays the groundwork for asynchronous LoRA prefetching, a strategy designed to improve the efficiency of serving requests that utilize LoRA adapters. By proactively identifying and loading LoRAs for requests waiting in the queue, the system aims to reduce latency and enhance overall throughput. The changes primarily involve integrating prefetching logic into the scheduler and LoRA management components, along with adding profiling capabilities to monitor the performance impact of these operations. This is the first step in a multi-part implementation.
Code Review
This pull request introduces the scheduler-side logic for asynchronous LoRA prefetching, which is a good step towards improving performance for multi-LoRA workloads. The addition of profiling code for LoRA loading and batch execution is also very helpful for performance analysis.
My review focuses on improving code maintainability by reducing duplication, fixing a potential bug in ForwardBatch creation, and addressing some minor code quality issues.
```python
def prefetch_lora_adapters(self, prefetch_lora_batch: ModelWorkerBatch):
    prefetch_fwd = ForwardBatch(
        forward_mode=prefetch_lora_batch.forward_mode,
        batch_size=len(prefetch_lora_batch.seq_lens),
        input_ids=prefetch_lora_batch.input_ids,
        req_pool_indices=prefetch_lora_batch.req_pool_indices,
        seq_lens=prefetch_lora_batch.seq_lens,
        out_cache_loc=prefetch_lora_batch.out_cache_loc,
        seq_lens_sum=prefetch_lora_batch.seq_lens_sum,
        seq_lens_cpu=prefetch_lora_batch.seq_lens_cpu,
        orig_seq_lens=prefetch_lora_batch.orig_seq_lens,
        lora_ids=prefetch_lora_batch.lora_ids,
    )
    assert isinstance(prefetch_lora_batch.extend_seq_lens, list)
    prefetch_fwd.extend_seq_lens = torch.tensor(
        prefetch_lora_batch.extend_seq_lens, dtype=torch.int32
    ).to(self.model_runner.device, non_blocking=True)
    prefetch_fwd.extend_seq_lens_cpu = prefetch_lora_batch.extend_seq_lens

    result = self.model_runner.prefetch_lora_batch(prefetch_fwd)
    return result
```
The manual creation of ForwardBatch here is brittle and incomplete. For instance, it's missing the positions tensor, which is required by some attention backends and is normally computed in ForwardBatch.init_new.
To make this more robust, I suggest using ForwardBatch.init_new. To prevent lora_manager.prepare_lora_batch from being called twice, you could add a prepare_lora: bool = True parameter to ForwardBatch.init_new and call it with prepare_lora=False here.
Here's how you could modify ForwardBatch.init_new in python/sglang/srt/model_executor/forward_batch_info.py:

```python
# In ForwardBatch.init_new
def init_new(cls, batch: ModelWorkerBatch, model_runner: ModelRunner, prepare_lora: bool = True):
    # ...
    # Init lora information
    if model_runner.server_args.enable_lora and prepare_lora:
        model_runner.lora_manager.prepare_lora_batch(ret)
    # ...
```

Then, you can simplify prefetch_lora_adapters as suggested:

```python
def prefetch_lora_adapters(self, prefetch_lora_batch: ModelWorkerBatch):
    prefetch_fwd = ForwardBatch.init_new(
        prefetch_lora_batch, self.model_runner, prepare_lora=False
    )
    result = self.model_runner.prefetch_lora_batch(prefetch_fwd)
    return result
```

```python
def prepare_for_lora_prefetch(self):
    """Taken mainly from prepare_for_extend()"""
    self.forward_mode = ForwardMode.EXTEND

    # Init tensors
    reqs = self.reqs
    input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
    extend_num_tokens = sum(len(ids) for ids in input_ids)
    seq_lens = [len(r.fill_ids) for r in reqs]
    orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
    prefix_lens = [len(r.prefix_indices) for r in reqs]
    extend_lens = [r.extend_input_len for r in reqs]

    input_ids_tensor = torch.tensor(
        list(chain.from_iterable(input_ids)), dtype=torch.int64
    ).to(self.device, non_blocking=True)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
        self.device, non_blocking=True
    )
    seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
    orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
        self.device, non_blocking=True
    )

    # Set batch fields needed by alloc_for_extend
    self.prefix_lens = prefix_lens
    self.extend_lens = extend_lens
    self.seq_lens = seq_lens_tensor
    self.seq_lens_cpu = seq_lens_cpu
    self.extend_num_tokens = extend_num_tokens

    self.input_ids = input_ids_tensor
    self.orig_seq_lens = orig_seq_lens_tensor
    self.seq_lens_sum = sum(seq_lens)
```
The new method prepare_for_lora_prefetch duplicates a significant amount of code from prepare_for_extend. To improve maintainability and avoid code duplication, consider refactoring the common tensor initialization logic into a private helper method. This helper could then be called by both prepare_for_lora_prefetch and prepare_for_extend.
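The suggested refactor could look like the following sketch. All class and method names here are illustrative stand-ins for the actual sglang types (`ScheduleBatchSketch` and `_init_extend_tensors` are hypothetical), and plain Python lists are used in place of device tensors to keep the example self-contained:

```python
# Hypothetical sketch of the suggested refactor: the initialization logic
# shared by prepare_for_extend and prepare_for_lora_prefetch moves into a
# private helper that both methods call.
from dataclasses import dataclass
from typing import List


@dataclass
class Req:
    fill_ids: List[int]
    prefix_indices: List[int]
    origin_input_ids: List[int]


class ScheduleBatchSketch:
    def __init__(self, reqs: List[Req]):
        self.reqs = reqs

    def _init_extend_tensors(self):
        """Common batch prep shared by both prepare_* methods."""
        reqs = self.reqs
        self.input_ids = [r.fill_ids[len(r.prefix_indices):] for r in reqs]
        self.extend_num_tokens = sum(len(ids) for ids in self.input_ids)
        self.seq_lens = [len(r.fill_ids) for r in reqs]
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.seq_lens_sum = sum(self.seq_lens)

    def prepare_for_lora_prefetch(self):
        # Prefetch only needs the shared tensor setup.
        self._init_extend_tensors()

    def prepare_for_extend(self):
        self._init_extend_tensors()
        # ... extend-specific logic (cache allocation, sampling info) ...
```

With this shape, a future change to the shared initialization only has to be made in one place.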
```python
        prefetch_batch.get_model_worker_batch()
    )

    print(f"current batch lora ids: {running_batch_lora_ids}")
```
This print statement appears to be for debugging. It's better to use the logging module (e.g., logger.debug(...)) for such messages. This allows for better control over log levels and output streams in different environments.
```diff
-print(f"current batch lora ids: {running_batch_lora_ids}")
+logger.debug(f"current batch lora ids: {running_batch_lora_ids}")
```
```python
has_lora = hasattr(batch, "lora_ids") and batch.lora_ids
lora_info = (
    f", lora_ids={len(set(batch.lora_ids)) if has_lora else 0}"
    if has_lora
    else ""
)
```
The batch object here is a ScheduleBatch, which does not have a lora_ids attribute. This causes has_lora to always be False, and the LoRA information is never logged. You can get the LoRA IDs by iterating through batch.reqs.
```python
lora_ids = [req.lora_id for req in batch.reqs if req.lora_id]
lora_info = f", lora_ids={len(set(lora_ids))}" if lora_ids else ""
```
Moved to #14190.
Motivation
This is the first PR for #8712. In this PR, we adopt the prefetch policy used in S-LoRA, where we prefetch LoRA adapters based on which requests are in the Scheduler's waiting queue.
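The waiting-queue policy described above can be sketched as a small selection function: walk the waiting queue in arrival order and collect the adapters that upcoming requests need but that are not yet resident, bounded by the number of free adapter slots. This is an illustrative sketch only; `select_loras_to_prefetch` and its parameters are hypothetical names, not sglang's actual API:

```python
# Hypothetical sketch of an S-LoRA-style prefetch selection policy.
from typing import List, Optional, Set


def select_loras_to_prefetch(
    waiting_lora_ids: List[Optional[str]],
    loaded: Set[str],
    free_slots: int,
) -> List[str]:
    """Return the LoRA ids to prefetch for the waiting queue.

    waiting_lora_ids: lora_id of each waiting request, in queue order
                      (None for requests that use the base model only).
    loaded:           adapters already resident in GPU memory.
    free_slots:       how many more adapters the memory pool can hold.
    """
    to_prefetch: List[str] = []
    for lora_id in waiting_lora_ids:
        # Skip base-model requests and adapters already loaded or selected.
        if lora_id is None or lora_id in loaded or lora_id in to_prefetch:
            continue
        # Stop once the pool's free capacity is exhausted.
        if len(to_prefetch) >= free_slots:
            break
        to_prefetch.append(lora_id)
    return to_prefetch
```

Processing the queue in order means the adapters needed soonest are prefetched first, so the loads overlap with the current batch's execution instead of stalling it.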
Modifications
- Construct a ForwardBatch as a LoRA prefetch batch, consisting of the requests that are next to be run from the waiting queue
- Route the prefetch batch through the LoRAManager, the memory pool, and the LoRA backend

Accuracy Tests
Benchmarking and Profiling
Checklist