Skip to content
Merged
9 changes: 8 additions & 1 deletion python/sglang/srt/disaggregation/base/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy.typing as npt

from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs


class KVArgs:
Expand All @@ -16,6 +17,7 @@ class KVArgs:
aux_data_lens: list[int]
aux_item_lens: list[int]
ib_device: str
gpu_id: int


class KVPoll:
Expand All @@ -30,7 +32,12 @@ class BaseKVManager(ABC):
"""Base class for managing transfers states"""

@abstractmethod
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): ...
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
): ...


class BaseKVSender(ABC):
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def _init_kv_manager(self) -> BaseKVManager:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
kv_manager = kv_manager_class(kv_args, DisaggregationMode.DECODE)
kv_manager = kv_manager_class(
kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
)
return kv_manager

def add(self, req: Req) -> None:
Expand Down
Loading
Loading