[Feature] overlap LoRA weight loading with compute #15512
Merged
Changes from all commits (31 commits):
- 2c9353d overlap lora weight loading with compute (glenliu21)
- 1aebbc7 move logic to LoRAPrefetcher (glenliu21)
- bff918e add lora_prefetcher.py (glenliu21)
- bc9cf43 merge main (glenliu21)
- b442c8b fix weight sync issue (glenliu21)
- 8a674b5 fix (glenliu21)
- ee707d6 Merge branch 'main' into lora_pipeline (glenliu21)
- 071465f Merge branch 'main' into lora_pipeline (glenliu21)
- 28f8fe9 fix (glenliu21)
- 58edc3a precommit (glenliu21)
- 5351cb2 add server arg and test (glenliu21)
- 8209420 register test for ci (glenliu21)
- 81072e6 adjust test (glenliu21)
- ffeeacb Merge branch 'main' into lora_pipeline (glenliu21)
- 044d789 improve test (glenliu21)
- fb31cf7 Merge branch 'main' into lora_pipeline (glenliu21)
- db1a67e add tp test (glenliu21)
- daa8403 Merge branch 'main' into lora_pipeline (glenliu21)
- ea43f6d rename lora_prefetcher to lora_overlap_loader; reorganize and reforma… (glenliu21)
- 30930f5 Merge branch 'main' into lora_pipeline (glenliu21)
- 2892f63 Merge branch 'main' into lora_pipeline (glenliu21)
- 1500be6 Merge branch 'main' into lora_pipeline (Fridge003)
- 3293996 Merge branch 'main' into lora_pipeline (glenliu21)
- f87bdf6 max_loaded_loras fix (glenliu21)
- 354ff9c Merge branch 'main' into lora_pipeline (glenliu21)
- 77609c0 Merge branch 'main' into lora_pipeline (glenliu21)
- 8b69562 fix server arg description (glenliu21)
- 0aa9226 max_loaded_loras arg (glenliu21)
- 0409f43 Merge branch 'main' into lora_pipeline (glenliu21)
- ec6cecf Merge branch 'main' into lora_pipeline (glenliu21)
- ae79420 Merge branch 'main' into lora_pipeline (glenliu21)
New file added in this PR (`lora_overlap_loader.py`, renamed from `lora_prefetcher.py` in commit ea43f6d):

```python
import logging
from enum import Enum, auto
from typing import Dict, Optional

import torch
from torch.cuda import Event as CudaEvent
from torch.cuda import Stream as CudaStream
from torch.cuda import StreamContext as CudaStreamContext

from sglang.srt.lora.lora_manager import LoRAManager

logger = logging.getLogger(__name__)


class LoRAOverlapLoadStatus(Enum):
    LOADED = auto()
    LOADING = auto()
    NOT_LOADED = auto()


class LoRAOverlapLoader:
    def __init__(self, lora_manager):
        self.lora_manager: LoRAManager = lora_manager
        self.device_module = torch.get_device_module(self.lora_manager.device)
        self.load_stream: CudaStream = self.device_module.Stream()
        self.load_stream_context: CudaStreamContext = self.device_module.stream(
            self.load_stream
        )
        self.lora_to_overlap_load_event: Dict[Optional[str], CudaEvent] = {}

    def try_overlap_load_lora(
        self, lora_id: Optional[str], running_loras: set[Optional[str]]
    ) -> bool:
        """
        Check a LoRA adapter's asynchronous load status, and try to load it if
        there is capacity in the memory pool. Returns whether the adapter has
        been loaded.
        """
        lora_pipeline_load_status = self._check_overlap_load_status(lora_id)
        if lora_pipeline_load_status == LoRAOverlapLoadStatus.LOADING:
            return False
        elif lora_pipeline_load_status == LoRAOverlapLoadStatus.NOT_LOADED:
            res = self._try_start_overlap_load(lora_id, running_loras)
            if res:
                logger.debug(f"Loading LoRA adapter {lora_id} asynchronously")
            return False
        else:
            assert lora_pipeline_load_status == LoRAOverlapLoadStatus.LOADED
            return True

    def _check_overlap_load_status(
        self, lora_id: Optional[str]
    ) -> LoRAOverlapLoadStatus:
        if lora_id not in self.lora_to_overlap_load_event:
            return LoRAOverlapLoadStatus.NOT_LOADED

        event = self.lora_to_overlap_load_event[lora_id]

        if not event.query():
            return LoRAOverlapLoadStatus.LOADING

        torch.cuda.current_stream().wait_event(event)
        del self.lora_to_overlap_load_event[lora_id]

        return LoRAOverlapLoadStatus.LOADED

    def _try_start_overlap_load(
        self, lora_id: Optional[str], running_loras: set[Optional[str]]
    ) -> bool:
        loras_to_be_loaded = running_loras | self.lora_to_overlap_load_event.keys()

        new_lora_set = {lora_id} | loras_to_be_loaded
        if not self.lora_manager.validate_lora_batch(new_lora_set):
            return False

        with self.load_stream_context:
            self.lora_manager.fetch_new_loras({lora_id}, loras_to_be_loaded)
            event = self.device_module.Event()
            event.record(self.load_stream)

        self.lora_to_overlap_load_event[lora_id] = event
        return True
```
Can we add an example for LoRA overlap loading in the section below? (This can be updated in a follow-up PR.)
Please see #17464.