Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/configs/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class LoadConfig:
ignore_patterns: Optional[Union[List[str], str]] = None
decryption_key_file: Optional[str] = None
decrypt_max_concurrency: int = -1
+ tp_rank: Optional[int] = None
+ remote_instance_weight_loader_seed_instance_ip: Optional[str] = None
+ remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
+ remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None

def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
Expand Down
21 changes: 0 additions & 21 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,6 @@ def __init__(
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
- tp_rank: Optional[int] = None,
- remote_instance_weight_loader_seed_instance_ip: Optional[str] = None,
- remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None,
- remote_instance_weight_loader_send_weights_group_ports: Optional[
- List[int]
- ] = None,
) -> None:
# Parse args
self.model_path = model_path
Expand All @@ -78,18 +72,6 @@ def __init__(
self.is_draft_model = is_draft_model
self.model_impl = model_impl

- # TODO: remove these fields
- self.tp_rank = tp_rank
- self.remote_instance_weight_loader_seed_instance_ip = (
- remote_instance_weight_loader_seed_instance_ip
- )
- self.remote_instance_weight_loader_seed_instance_service_port = (
- remote_instance_weight_loader_seed_instance_service_port
- )
- self.remote_instance_weight_loader_send_weights_group_ports = (
- remote_instance_weight_loader_send_weights_group_ports
- )
-
# Get hf config
self._maybe_pull_model_tokenizer_from_remote()
self.model_override_args = json.loads(model_override_args)
Expand Down Expand Up @@ -204,9 +186,6 @@ def from_server_args(
quantization=server_args.quantization,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
- remote_instance_weight_loader_seed_instance_ip=server_args.remote_instance_weight_loader_seed_instance_ip,
- remote_instance_weight_loader_seed_instance_service_port=server_args.remote_instance_weight_loader_seed_instance_service_port,
- remote_instance_weight_loader_send_weights_group_ports=server_args.remote_instance_weight_loader_send_weights_group_ports,
**kwargs,
)

Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def __init__(
else server_args.speculative_draft_model_revision
),
is_draft_model=is_draft_worker,
- tp_rank=tp_rank,
)

self.model_runner = ModelRunner(
Expand Down
10 changes: 7 additions & 3 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+ trigger_init_weights_send_group_for_remote_instance_request,
+ )
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
Expand All @@ -112,9 +115,6 @@
set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
- from sglang.srt.remote_instance_weight_loader_utils import (
- trigger_init_weights_send_group_for_remote_instance_request,
- )
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
Expand Down Expand Up @@ -743,6 +743,10 @@ def load_model(self):
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
model_loader_extra_config=self.server_args.model_loader_extra_config,
+ tp_rank=self.tp_rank,
+ remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+ remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+ remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
)
if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp(
Expand Down
19 changes: 10 additions & 9 deletions python/sglang/srt/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
+ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+ trigger_transferring_weights_request,
+ )
from sglang.srt.model_loader.utils import (
get_model_architecture,
post_load_weights,
Expand All @@ -77,9 +80,6 @@
safetensors_weights_iterator,
set_runai_streamer_env,
)
- from sglang.srt.remote_instance_weight_loader_utils import (
- trigger_transferring_weights_request,
- )
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
Expand Down Expand Up @@ -1420,7 +1420,7 @@ def load_model(
f"load format {load_config.load_format}"
)

- model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
+ model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"

with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
Expand All @@ -1442,11 +1442,12 @@ def load_model(
def load_model_from_remote_instance(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
+ load_config = self.load_config
instance_ip = socket.gethostbyname(socket.gethostname())
start_build_group_tic = time.time()
client.build_group(
gpu_id=device_config.gpu_id,
- tp_rank=model_config.tp_rank,
+ tp_rank=load_config.tp_rank,
instance_ip=instance_ip,
)
torch.cuda.synchronize()
Expand All @@ -1455,13 +1456,13 @@ def load_model_from_remote_instance(
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
)

- if model_config.tp_rank == 0:
+ if load_config.tp_rank == 0:
t = threading.Thread(
target=trigger_transferring_weights_request,
args=(
- model_config.remote_instance_weight_loader_seed_instance_ip,
- model_config.remote_instance_weight_loader_seed_instance_service_port,
- model_config.remote_instance_weight_loader_send_weights_group_ports,
+ load_config.remote_instance_weight_loader_seed_instance_ip,
+ load_config.remote_instance_weight_loader_seed_instance_service_port,
+ load_config.remote_instance_weight_loader_send_weights_group_ports,
instance_ip,
),
)
Expand Down
Loading