@@ -1303,6 +1303,24 @@ def print(self):
13031303 logger .info ("=============================================================" )
13041304
13051305
1306+ class RouterConfig :
1307+ """
1308+ Configuration for router
1309+ Attributes:
1310+ router: the url of router, such as http://127.0.0.1:8000
1311+ api_server_host: the host ip of model server
1312+ api_server_port: the http port of model server
1313+ """
1314+
1315+ def __init__ (self , args : dict ):
1316+ self .router = args ["router" ]
1317+ if self .router is not None and not self .router .startswith (("http://" , "https://" )):
1318+ self .router = f"http://{ self .router } "
1319+
1320+ self .api_server_host = get_host_ip ()
1321+ self .api_server_port = args ["port" ]
1322+
1323+
13061324class CommitConfig :
13071325 """
13081326 Configuration for tracking version information from version.txt
@@ -1404,6 +1422,7 @@ def __init__(
14041422 speculative_config : SpeculativeConfig = None ,
14051423 eplb_config : EPLBConfig = None ,
14061424 structured_outputs_config : StructuredOutputsConfig = None ,
1425+ router_config : RouterConfig = None ,
14071426 tokenizer : str = None ,
14081427 ips : str = None ,
14091428 use_warmup : bool = False ,
@@ -1416,7 +1435,6 @@ def __init__(
14161435 early_stop_config : Optional [Dict [str , Any ]] = None ,
14171436 tool_parser : str = None ,
14181437 test_mode = False ,
1419- port = None ,
14201438 ):
14211439 self .model_config : ModelConfig = model_config # type: ignore
14221440 self .cache_config : CacheConfig = cache_config # type: ignore
@@ -1432,6 +1450,7 @@ def __init__(
14321450 self .cache_config : CacheConfig = cache_config # type: ignore
14331451 self .plas_attention_config : Optional [PlasAttentionConfig ] = plas_attention_config
14341452 self .structured_outputs_config : StructuredOutputsConfig = structured_outputs_config
1453+ self .router_config : RouterConfig = router_config
14351454
14361455 # Initialize cuda graph capture list
14371456 max_capture_shape = self .scheduler_config .max_num_seqs
@@ -1459,7 +1478,6 @@ def __init__(
14591478 self .ips = self .ips .split ("," )
14601479
14611480 self .host_ip = get_host_ip ()
1462- self .port = port
14631481
14641482 if self .ips is None :
14651483 self .nnode = 1
@@ -1730,39 +1748,39 @@ def init_cache_info(self):
17301748 """
17311749 initialize cache info
17321750 """
1733- # TODO: group the splitiwse params
1751+ # TODO: group the splitiwse params, remove code of v0
17341752 # v0 requires prefill and decode in one node and it uses local scheduler
17351753 # v1 supports prefill and decode in multi node and it uses splitwise or dp scheduler
17361754 # v2 supports prefill and decode in multi node and it uses router and local scheduler
17371755 self .splitwise_version = None
1738- if self .scheduler_config .name == "local" and self .scheduler_config . router is None :
1756+ if self .scheduler_config .name == "local" and ( self .router_config is None or self . router_config . router is None ) :
17391757 self .splitwise_version = "v0"
17401758 elif self .scheduler_config .name in ("splitwise" , "dp" ):
17411759 self .splitwise_version = "v1"
1742- elif self .scheduler_config .name == "local" and self .scheduler_config .router :
1760+ elif self .scheduler_config .name == "local" and self .router_config and self . router_config .router :
17431761 self .splitwise_version = "v2"
17441762 else :
17451763 raise ValueError (
17461764 f"Unsupported scheduler mode, scheduler_name: { self .scheduler_config .name } , "
1747- f"router : { self .scheduler_config . router } "
1765+ f"router_config : { self .router_config } "
17481766 )
17491767 logger .info (f"splitwise_version: { self .splitwise_version } " )
17501768
1769+ if isinstance (self .parallel_config .engine_worker_queue_port , (int , str )):
1770+ engine_worker_queue_port = self .parallel_config .engine_worker_queue_port
1771+ else :
1772+ engine_worker_queue_port = self .parallel_config .engine_worker_queue_port [
1773+ self .parallel_config .local_data_parallel_id
1774+ ]
1775+ connector_port = self .cache_config .pd_comm_port [0 ] if self .cache_config .pd_comm_port else None
1776+
17511777 self .disaggregate_info = {}
17521778 if self .scheduler_config .splitwise_role != "mixed" :
17531779 self .disaggregate_info ["role" ] = self .scheduler_config .splitwise_role
17541780 self .disaggregate_info ["cache_info" ] = dict ()
17551781 current_protocol = self .cache_config .cache_transfer_protocol .split ("," )
17561782 self .disaggregate_info ["transfer_protocol" ] = current_protocol
17571783
1758- if isinstance (self .parallel_config .engine_worker_queue_port , (int , str )):
1759- engine_worker_queue_port = self .parallel_config .engine_worker_queue_port
1760- else :
1761- engine_worker_queue_port = self .parallel_config .engine_worker_queue_port [
1762- self .parallel_config .local_data_parallel_id
1763- ]
1764- connector_port = self .cache_config .pd_comm_port [0 ] if self .cache_config .pd_comm_port else None
1765-
17661784 for protocol in current_protocol :
17671785 if protocol == "ipc" :
17681786 self .disaggregate_info ["cache_info" ][protocol ] = {
@@ -1778,17 +1796,18 @@ def init_cache_info(self):
17781796 }
17791797 logger .info (f"disaggregate_info: { self .disaggregate_info } " )
17801798
1781- self .splitwise_instance_info = {
1799+ if self .router_config :
1800+ self .register_info = {
17821801 "role" : self .scheduler_config .splitwise_role ,
17831802 "host_ip" : self .host_ip ,
1784- "port" : self .port ,
1803+ "port" : self .router_config . api_server_port ,
17851804 "connector_port" : connector_port ,
17861805 "rdma_ports" : self .cache_config .rdma_comm_ports ,
17871806 "engine_worker_queue_port" : engine_worker_queue_port ,
17881807 "device_ids" : self .local_device_ids ,
17891808 "transfer_protocol" : self .cache_config .cache_transfer_protocol .split ("," ),
17901809 }
1791- logger .info (f"splitwise_instance_info : { self .splitwise_instance_info } " )
1810+ logger .info (f"register_info : { self .register_info } " )
17921811
17931812 def read_from_config (self ):
17941813 """
0 commit comments