@@ -340,11 +340,14 @@ def event_loop_normal(self) -> None:
         mmap_infos = create_mmap(
             [MODEL_MAIN_NAME], self.local_rank, self.ranks, shm_uuid=os.getenv("SHM_UUID", ""), logger=logger
         )
+
+        tp_size = self.parallel_config.tensor_parallel_size
         # Currently, only support single node
-        self.nnode = int((self.parallel_config.tensor_parallel_size + 7) // 8)
+        self.nnode = int((tp_size + 7) // 8)
         req_ids = []
         num_running_requests = 0
-        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
+        tp_rank = self.local_rank % tp_size
+
         self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.eplb_config.enable_redundant_experts:
@@ -385,35 +388,34 @@ def event_loop_normal(self) -> None:
                 if self.local_rank == 0:
                     rearrange_experts_status_array[0] = RearrangeExpertState.done.value
                     logger.info("redundant_expert: done")
-            if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
+            if tp_rank == 0:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
                 self.model_weights_signal[0] = self._broadcast_model_weights_signal(
                     src=0, group=self.parallel_config.ep_group
                 )
-            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+            if self.fd_config.load_config.dynamic_load_weight and tp_size > 1:
                 self.model_weights_signal[0] = self._broadcast_model_weights_signal(
                     src=0, group=self.parallel_config.tp_group
                 )

             self.insert_step = False
             req_dicts = None
-            local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-            self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
+            self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())

             # The first worker detects whether there are tasks in the task queue
-            if local_rank == 0:
+            if tp_rank == 0:
                 if self.task_queue.num_tasks() > 0:
                     if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
                         self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
                     ):
-                        if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
+                        if self.nnode > 1 and tp_size > self.max_chips_per_node:
                             self.task_queue.read_finish_flag.set(1)
                         else:
                             self.exist_task_signal.value[0] = ExistTaskStatus.EXIST

-            if self.parallel_config.tensor_parallel_size > 1:
+            if tp_size > 1:
                 # Synchronize the signal for other workers
                 self._tp_barrier_wait()