Commit 24180fb

[FDConfig]Remove splitwise_role and engine_worker_queue_port in FDConfig (#4147)
* remove splitwise_role and engine_worker_queue_port
* fix xpu
* fix xpu
* fix xpu
* fix unittest
* resolve conflict
1 parent ee9d8a8 commit 24180fb

23 files changed (+129, -89 lines)

fastdeploy/config.py

Lines changed: 15 additions & 21 deletions
```diff
@@ -296,8 +296,6 @@ def __init__(
         # Do profile or not
         self.do_profile: bool = False
 
-        # splitwise role
-        self.splitwise_role: str = "mixed"
         # guided decoding backend
         self.guided_decoding_backend: str = None
         # disable any whitespace for guided decoding
@@ -319,14 +317,6 @@ def __init__(
         else:
             self.expert_parallel_size = 1
         self.use_ep = self.expert_parallel_size > 1
-        if self.splitwise_role == "mixed":
-            self.moe_phase = MoEPhase(phase="prefill")
-        elif self.splitwise_role == "prefill":
-            self.moe_phase = MoEPhase(phase="prefill")
-        elif self.splitwise_role == "decode":
-            self.moe_phase = MoEPhase(phase="decode")
-        else:
-            raise NotImplementedError
 
         # pd_disaggregation
         use_pd_disaggregation: int = int(os.getenv("FLAGS_use_pd_disaggregation", 0))
@@ -1116,10 +1106,8 @@ def __init__(
         max_model_len: int = 8192,
         ips: str = None,
         use_warmup: bool = False,
-        engine_worker_queue_port: str = "8002",
         limit_mm_per_prompt: Optional[Dict[str, Any]] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        splitwise_role: str = "mixed",
         innode_prefill_ports: Optional[List[int]] = None,
         max_num_partial_prefills: int = 1,
         max_long_partial_prefills: int = 1,
@@ -1182,19 +1170,14 @@ def __init__(
         self.limit_mm_per_prompt = limit_mm_per_prompt
         self.mm_processor_kwargs = mm_processor_kwargs
         self.use_warmup = use_warmup
-        self.splitwise_role = splitwise_role
         self.innode_prefill_ports = innode_prefill_ports
         self.max_num_partial_prefills = max_num_partial_prefills
         self.max_long_partial_prefills = max_long_partial_prefills
         self.long_prefill_token_threshold = long_prefill_token_threshold
         self.reasoning_parser = reasoning_parser
         self.guided_decoding_backend = guided_decoding_backend
         self.disable_any_whitespace = disable_any_whitespace
-        self.engine_worker_queue_port = engine_worker_queue_port
         self._str_to_list("innode_prefill_ports", int)
-        if isinstance(engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(engine_worker_queue_port)
-        self._str_to_list("engine_worker_queue_port", str)
 
         if envs.FD_FOR_TORCH_MODEL_FORMAT:
             self.model_config.model_format = "torch"
@@ -1267,6 +1250,15 @@ def postprocess(self):
         else:
             self.guided_decoding_backend = "xgrammar"
 
+        if self.scheduler_config.splitwise_role == "mixed":
+            self.model_config.moe_phase = MoEPhase(phase="prefill")
+        elif self.scheduler_config.splitwise_role == "prefill":
+            self.model_config.moe_phase = MoEPhase(phase="prefill")
+        elif self.scheduler_config.splitwise_role == "decode":
+            self.model_config.moe_phase = MoEPhase(phase="decode")
+        else:
+            raise NotImplementedError
+
     def check(self):
         """
         check the legality of config
@@ -1301,7 +1293,7 @@ def check(self):
             f"max_long_partial_prefills: {self.max_long_partial_prefills} should "
             f"be less than or equal to max_num_partial_prefills: {self.max_num_partial_prefills}"
         )
-        assert self.splitwise_role in ["mixed", "prefill", "decode"]
+        assert self.scheduler_config.splitwise_role in ["mixed", "prefill", "decode"]
         # TODO(@wufeisheng): TP and EP need to be supported simultaneously.
         assert (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.expert_parallel_size >= 1) or (
             self.parallel_config.tensor_parallel_size >= 1 and self.parallel_config.expert_parallel_size == 1
@@ -1387,16 +1379,18 @@ def init_cache_info(self):
         initialize cache info
         """
         disaggregate_info = {}
-        if self.splitwise_role != "mixed":
-            disaggregate_info["role"] = self.splitwise_role
+        if self.scheduler_config.splitwise_role != "mixed":
+            disaggregate_info["role"] = self.scheduler_config.splitwise_role
             disaggregate_info["cache_info"] = dict()
             current_protocol = self.cache_config.cache_transfer_protocol.split(",")
             disaggregate_info["transfer_protocol"] = current_protocol
             for protocol in current_protocol:
                 if protocol == "ipc":
                     disaggregate_info["cache_info"][protocol] = {
                         "ip": self.host_ip,
-                        "port": self.engine_worker_queue_port[self.parallel_config.local_data_parallel_id],
+                        "port": self.parallel_config.engine_worker_queue_port[
+                            self.parallel_config.local_data_parallel_id
+                        ],
                         "device_ids": self.local_device_ids,
                     }
                 elif protocol == "rdma":
```
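The substantive change in this file is where the MoE phase gets decided: the role-to-phase mapping moves out of `ParallelConfig.__init__` (which read the now-removed `self.splitwise_role`) into `FDConfig.postprocess`, which reads `scheduler_config.splitwise_role` and writes `model_config.moe_phase`. Below is a minimal standalone sketch of that mapping; `MoEPhase` is reduced to a plain dataclass stand-in rather than the FastDeploy definition.

```python
# Standalone sketch of the role-to-phase mapping this commit moves into
# FDConfig.postprocess. MoEPhase here is a simplified stand-in, not the
# FastDeploy class.
from dataclasses import dataclass


@dataclass
class MoEPhase:
    phase: str  # "prefill" or "decode"


def resolve_moe_phase(splitwise_role: str) -> MoEPhase:
    """Map a splitwise role onto the MoE phase the model runner should use.

    Both "mixed" and "prefill" roles run the prefill phase; only a dedicated
    "decode" instance runs the decode phase.
    """
    if splitwise_role in ("mixed", "prefill"):
        return MoEPhase(phase="prefill")
    if splitwise_role == "decode":
        return MoEPhase(phase="decode")
    raise NotImplementedError(f"unknown splitwise_role: {splitwise_role!r}")


assert resolve_moe_phase("mixed").phase == "prefill"
assert resolve_moe_phase("decode").phase == "decode"
```

Running the mapping in `postprocess`, after all sub-configs exist, derives the phase from a single source of truth (the scheduler's role) instead of keeping a duplicate `splitwise_role` field on `FDConfig`.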

fastdeploy/engine/args_utils.py

Lines changed: 5 additions & 7 deletions
```diff
@@ -1019,6 +1019,11 @@ def create_engine_config(self) -> FDConfig:
         else:
             self.max_num_batched_tokens = self.max_model_len
 
+        if isinstance(self.engine_worker_queue_port, int):
+            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
+        if isinstance(self.engine_worker_queue_port, str):
+            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
+
         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
         cache_cfg = CacheConfig(all_dict)
@@ -1032,11 +1037,6 @@ def create_engine_config(self) -> FDConfig:
         early_stop_cfg = self.create_early_stop_config()
         early_stop_cfg.update_enable_early_stop(self.enable_early_stop)
 
-        if isinstance(self.engine_worker_queue_port, int):
-            self.engine_worker_queue_port = str(self.engine_worker_queue_port)
-        if isinstance(self.engine_worker_queue_port, str):
-            self.engine_worker_queue_port = self.engine_worker_queue_port.split(",")
-
         assert is_port_available(
             "0.0.0.0", int(self.engine_worker_queue_port[parallel_cfg.local_data_parallel_id])
         ), f"The parameter `engine_worker_queue_port`:{self.engine_worker_queue_port} is already in use."
@@ -1052,12 +1052,10 @@ def create_engine_config(self) -> FDConfig:
             speculative_config=speculative_cfg,
             ips=self.ips,
             use_warmup=self.use_warmup,
-            engine_worker_queue_port=self.engine_worker_queue_port,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             mm_processor_kwargs=self.mm_processor_kwargs,
             reasoning_parser=self.reasoning_parser,
             tool_parser=self.tool_call_parser,
-            splitwise_role=self.splitwise_role,
             innode_prefill_ports=self.innode_prefill_ports,
             max_num_partial_prefills=self.max_num_partial_prefills,
             max_long_partial_prefills=self.max_long_partial_prefills,
```
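The normalization logic itself is unchanged; it is hoisted above `asdict(self)` so the dict handed to the sub-config constructors already carries the list form of `engine_worker_queue_port`. A sketch of that normalization as a standalone helper (`normalize_ports` is a hypothetical name, not a FastDeploy API):

```python
# Sketch of the port normalization create_engine_config now performs before
# asdict(self): an int or comma-separated string becomes a list of port
# strings, one entry per data-parallel rank.
from typing import List, Union


def normalize_ports(engine_worker_queue_port: Union[int, str, List[str]]) -> List[str]:
    if isinstance(engine_worker_queue_port, int):
        engine_worker_queue_port = str(engine_worker_queue_port)
    if isinstance(engine_worker_queue_port, str):
        engine_worker_queue_port = engine_worker_queue_port.split(",")
    return engine_worker_queue_port


assert normalize_ports(8002) == ["8002"]
assert normalize_ports("8002,8003") == ["8002", "8003"]
```

After normalization, `engine_worker_queue_port[parallel_cfg.local_data_parallel_id]` selects one port per data-parallel rank, which is what the `is_port_available` assertion relies on.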

fastdeploy/engine/common_engine.py

Lines changed: 19 additions & 11 deletions
```diff
@@ -76,10 +76,10 @@ def __init__(self, cfg, start_queue=True):
                 cfg.scheduler_config.max_num_seqs,
                 cfg,
                 cfg.parallel_config.tensor_parallel_size,
-                cfg.splitwise_role,
+                cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )
-            if cfg.splitwise_role != "mixed":
+            if cfg.scheduler_config.splitwise_role != "mixed":
                 raise NotImplementedError(
                     "Currently ENABLE_V1_KVCACHE_SCHEDULER=1 only supported in mixed sampling now."
                 )
@@ -88,13 +88,13 @@ def __init__(self, cfg, start_queue=True):
                 cfg.scheduler_config.max_num_seqs,
                 cfg,
                 cfg.parallel_config.tensor_parallel_size,
-                cfg.splitwise_role,
+                cfg.scheduler_config.splitwise_role,
                 cfg.parallel_config.local_data_parallel_id,
             )
 
         self.start_worker_queue_service(start_queue)
 
-        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.engine_worker_queue_port[
+        os.environ["INFERENCE_MSG_QUEUE_ID"] = self.cfg.parallel_config.engine_worker_queue_port[
             self.cfg.parallel_config.local_data_parallel_id
         ]
 
@@ -137,7 +137,9 @@ def start(self):
             self.token_processor.run()
 
     def _init_worker_monitor_signals(self):  # exist_task_signal 用于各worker进程感知是否有新Task需要处理
-        current_suffix = int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id])
+        current_suffix = int(
+            self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+        )
         llm_logger.info(f"current_suffix: {current_suffix}")
         exist_task_signal_data = np.zeros([1], dtype=np.int32)
         self.exist_task_signal = IPCSignal(
@@ -195,7 +197,7 @@ def start_worker_queue_service(self, start_queue):
         """
         address = (
             self.cfg.master_ip,
-            int(self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
+            int(self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
         )
 
         if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
@@ -209,7 +211,7 @@ def start_worker_queue_service(self, start_queue):
 
         if (
             self.cfg.cache_config.enable_prefix_caching
-            or self.cfg.splitwise_role != "mixed"
+            or self.cfg.scheduler_config.splitwise_role != "mixed"
             and self.cfg.parallel_config.local_data_parallel_id == 0
         ):
             self.cache_task_queue = EngineCacheQueue(
@@ -253,7 +255,10 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False):
                     del self.resource_manager.req_dict[task.request_id]
                     cur_task = self.resource_manager.tasks_list[cur_task_idx]
                     cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
-                    if self.cfg.speculative_config.method in ["mtp"] and self.cfg.splitwise_role == "decode":
+                    if (
+                        self.cfg.speculative_config.method in ["mtp"]
+                        and self.cfg.scheduler_config.splitwise_role == "decode"
+                    ):
                         cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
                     if task.error_code != 200:
                         self.resource_manager.stop_flags[cur_task_idx] = True
@@ -478,7 +483,10 @@ def _insert_task_to_worker(self):
                     time.sleep(0.001)
                     continue
                 if hasattr(self, "exist_prefill_task_signal") and self.exist_prefill_task_signal.value[0] > 0:
-                    if self.cfg.splitwise_role == "mixed" or self.split_connector.has_splitwise_tasks():
+                    if (
+                        self.cfg.scheduler_config.splitwise_role == "mixed"
+                        or self.split_connector.has_splitwise_tasks()
+                    ):
                         time.sleep(0.005)
                         continue
                 if self.engine_worker_queue.num_cache_infos() > 0:
@@ -507,7 +515,7 @@ def _insert_task_to_worker(self):
                     continue
 
                 current_id = (current_id + 1) % 100003
-                if self.cfg.splitwise_role != "mixed":
+                if self.cfg.scheduler_config.splitwise_role != "mixed":
                     llm_logger.info("Inserting splitwise tasks")
                     self.split_connector.send_splitwise_tasks(tasks, current_id)
 
@@ -759,7 +767,7 @@ def start_cache_service(self, device_ids, ipc_signal_suffix):
             device_ids=device_ids,
             pod_ip=self.cfg.master_ip,
             engine_worker_queue_port=int(
-                self.cfg.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
+                self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
            ),
            pid_suffix=ipc_signal_suffix,
        )
```
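Every read in this file now follows one pattern: `parallel_config.engine_worker_queue_port` is a list with one entry per data-parallel rank, indexed by `parallel_config.local_data_parallel_id`. A reduced sketch of that lookup, with a simplified stand-in for the config object:

```python
# Sketch of the per-rank port lookup repeated throughout common_engine.py.
# ParallelConfig here is a simplified stand-in for the FastDeploy config.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ParallelConfig:
    engine_worker_queue_port: List[str] = field(default_factory=lambda: ["8002", "8003"])
    local_data_parallel_id: int = 0


def current_queue_port(parallel_config: ParallelConfig) -> int:
    """Return the worker-queue port owned by this data-parallel rank."""
    return int(parallel_config.engine_worker_queue_port[parallel_config.local_data_parallel_id])


assert current_queue_port(ParallelConfig(local_data_parallel_id=1)) == 8003
```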

fastdeploy/engine/engine.py

Lines changed: 13 additions & 13 deletions
```diff
@@ -115,7 +115,7 @@ def start(self, api_server_pid=None):
         start_time = time.time()
 
         self.api_server_pid = api_server_pid
-        self.ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
+        self.ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
         self._init_worker_signals()
 
         self.data_processor = self.input_processor.create_processor()
@@ -127,7 +127,7 @@ def start(self, api_server_pid=None):
             self.engine.start_zmq_service(api_server_pid)
 
         if self.do_profile == 0 and (
-            self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed"
+            self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed"
         ):
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
@@ -161,7 +161,7 @@ def check_worker_initialize_status_func(res: dict):
                 self._stop_profile()
             # Launch components: scheduler, cache_manager, expert_service et.al.
             self.launch_components()
-            if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+            if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
                 self.launched_cache_manager_signal.value[0] = 1
 
         # Worker launched
@@ -311,7 +311,7 @@ def _init_worker_signals(self):
         )
 
         # launched_cache_manager_signal 用于感知engine是否启动了cache_manager
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             launched_cache_manager_signal_data = np.zeros([1], dtype=np.int32)
             self.launched_cache_manager_signal = IPCSignal(
                 name="launched_cache_manager_signal",
@@ -426,10 +426,10 @@ def _setting_environ_variables(self):
             }
         )
 
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             variables["FLAGS_use_pd_disaggregation"] = 1
             # TODO dynamic load environment variable
-            if self.cfg.splitwise_role == "prefill":
+            if self.cfg.scheduler_config.splitwise_role == "prefill":
                 variables["FLAGS_fmt_write_cache_completed_signal"] = 1
 
         if self.cfg.model_config.enable_mm:
@@ -463,7 +463,7 @@ def _start_worker_service(self):
             else len(self.data_processor.tokenizer.vocab)
         )
 
-        ports = ",".join(self.cfg.engine_worker_queue_port)
+        ports = ",".join(self.cfg.parallel_config.engine_worker_queue_port)
         ips = None
         if self.cfg.ips is not None:
             ips = ",".join(self.cfg.ips)
@@ -481,9 +481,9 @@ def _start_worker_service(self):
             f" --enc_dec_block_num {self.cfg.cache_config.enc_dec_block_num}"
             f" --eos_tokens_lens {self.data_processor.eos_token_id_len}"
             f" --pad_token_id {self.data_processor.pad_token_id}"
-            f" --engine_pid {self.cfg.engine_worker_queue_port[0]}"
+            f" --engine_pid {self.cfg.parallel_config.engine_worker_queue_port[0]}"
             f" --max_num_batched_tokens {self.cfg.scheduler_config.max_num_batched_tokens}"
-            f" --splitwise_role {self.cfg.splitwise_role}"
+            f" --splitwise_role {self.cfg.scheduler_config.splitwise_role}"
             f" --kv_cache_ratio {self.cfg.cache_config.kv_cache_ratio}"
             f" --expert_parallel_size {self.cfg.parallel_config.expert_parallel_size}"
             f" --data_parallel_size {self.cfg.parallel_config.data_parallel_size}"
@@ -602,7 +602,7 @@ def _stop_profile(self):
             num_gpu_blocks = self.get_profile_block_num_signal.value[0]
         self.cfg.cache_config.reset(num_gpu_blocks)
         self.engine.resource_manager.reset_cache_config(self.cfg.cache_config)
-        if self.cfg.cache_config.enable_prefix_caching or self.cfg.splitwise_role != "mixed":
+        if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
             device_ids = self.cfg.device_ids.split(",")
             self.cache_manager_processes = self.engine.start_cache_service(device_ids, self.ipc_signal_suffix)
 
@@ -619,7 +619,7 @@ def check_health(self, time_interval_threashold=30):
         return True, ""
 
     def launch_components(self):
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
            # 单机逻辑
            self.engine.engine_worker_queue.available_prefill_instances.put(1)
            self.engine.split_mode_get_tasks()
@@ -632,7 +632,7 @@ def launch_components(self):
 
             self.cfg.init_cache_info()
 
-            role = self.cfg.splitwise_role
+            role = self.cfg.scheduler_config.splitwise_role
             host_ip = self.cfg.host_ip
             disaggregate = self.cfg.disaggregate_info
             if self.cfg.scheduler_config.name == "splitwise":
@@ -649,7 +649,7 @@ def launch_components(self):
                 ):
                     address = (
                         self.cfg.master_ip,
-                        int(self.cfg.engine_worker_queue_port[i]),
+                        int(self.cfg.parallel_config.engine_worker_queue_port[i]),
                     )
                     llm_logger.info(f"dp start queue service {address}")
                     self.dp_engine_worker_queue_server.append(
```
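The worker launch path shows both relocations side by side: the port list is joined from `parallel_config`, while the `--splitwise_role` flag now reads `scheduler_config`. Below is a sketch of that flag assembly under simplified stand-in configs; `build_worker_flags` and the `--engine_worker_queue_port` flag name are illustrative assumptions, and only `--engine_pid` and `--splitwise_role` appear in the diff above.

```python
# Sketch of _start_worker_service-style flag assembly from the relocated
# config fields. The config classes are simplified stand-ins.
from dataclasses import dataclass, field
from typing import List


@dataclass
class SchedulerConfig:
    splitwise_role: str = "mixed"


@dataclass
class ParallelConfig:
    engine_worker_queue_port: List[str] = field(default_factory=lambda: ["8002"])


@dataclass
class FDConfig:
    scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig)
    parallel_config: ParallelConfig = field(default_factory=ParallelConfig)


def build_worker_flags(cfg: FDConfig) -> str:
    # Comma-join the per-rank ports; the first port doubles as the engine pid
    # suffix, matching the f-string in the diff above.
    ports = ",".join(cfg.parallel_config.engine_worker_queue_port)
    return (
        f" --engine_worker_queue_port {ports}"
        f" --engine_pid {cfg.parallel_config.engine_worker_queue_port[0]}"
        f" --splitwise_role {cfg.scheduler_config.splitwise_role}"
    )


print(build_worker_flags(FDConfig()))
# -> " --engine_worker_queue_port 8002 --engine_pid 8002 --splitwise_role mixed"
```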

fastdeploy/engine/expert_service.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -50,13 +50,13 @@ def __init__(self, cfg, local_data_parallel_id, start_queue=True):
         self.cfg = cfg
         start_pos = (local_data_parallel_id * self.cfg.parallel_config.tensor_parallel_size) % cfg.worker_num_per_node
         end_pos = start_pos + self.cfg.parallel_config.tensor_parallel_size
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
         self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
         llm_logger.info(f"local_data_parallel_id: {local_data_parallel_id}")
         self.cfg.disaggregate_info = None
 
-        if cfg.splitwise_role != "mixed":
+        if cfg.scheduler_config.splitwise_role != "mixed":
             if len(self.cfg.cache_config.pd_comm_port) == 1:
                 self.cfg.cache_config.pd_comm_port[0] = (
                     int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
@@ -84,21 +84,21 @@ def start(self, ipc_signal_suffix, local_data_parallel_id):
             self.api_server_pid = ipc_signal_suffix
             self.engine.start_zmq_service(ipc_signal_suffix)
         else:
-            ipc_signal_suffix = self.cfg.engine_worker_queue_port[0]
+            ipc_signal_suffix = self.cfg.parallel_config.engine_worker_queue_port[0]
 
         llm_logger.info(f"start expert service {local_data_parallel_id}")
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.engine.start_cache_service(self.cfg.local_device_ids, ipc_signal_suffix)
             self.engine.split_mode_get_tasks()
 
         if self.cfg.scheduler_config.name == "splitwise":
             self.cfg.init_cache_info()
-            role = self.cfg.splitwise_role
+            role = self.cfg.scheduler_config.splitwise_role
             host_ip = self.cfg.host_ip
             disaggregate = self.cfg.disaggregate_info
             self.engine.scheduler.start(role, host_ip, disaggregate)
 
-        if self.cfg.splitwise_role != "mixed":
+        if self.cfg.scheduler_config.splitwise_role != "mixed":
             self.splitwise_receive_thread = threading.Thread(
                 target=self.engine.split_connector.start_receiver, args=()
             )
```
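The expert service's per-rank bookkeeping is untouched apart from the attribute paths: each data-parallel rank still claims a contiguous window of `tensor_parallel_size` devices (and RDMA ports in disaggregated mode) starting at `start_pos`. A sketch of that slice arithmetic with illustrative values:

```python
# Sketch of the per-rank slicing done in ExpertService.__init__: rank i owns
# devices [start_pos, start_pos + tensor_parallel_size), with start_pos
# wrapping at worker_num_per_node. Values below are illustrative, not
# FastDeploy defaults.

def rank_slice(local_data_parallel_id: int, tensor_parallel_size: int, worker_num_per_node: int) -> slice:
    start_pos = (local_data_parallel_id * tensor_parallel_size) % worker_num_per_node
    return slice(start_pos, start_pos + tensor_parallel_size)


device_ids = "0,1,2,3,4,5,6,7".split(",")
window = rank_slice(local_data_parallel_id=1, tensor_parallel_size=2, worker_num_per_node=8)
assert device_ids[window] == ["2", "3"]
```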
