@@ -186,6 +186,8 @@ def _apply(
         )
         shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only
         shared_config.sharding_config.support_partial_config = self.config.support_partial_config
+        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
+
         shared_config.sharding_config.use_sharding_from_factory = (
             self.config.use_sharding_from_factory
         )
@@ -201,8 +203,6 @@ def _apply(
             factory_info = detect_sharding_from_factory_config(gm, sharding_config)
             return gm, factory_info
 
-        shared_config.sharding_config.sharding_dims = self.config.sharding_dims
-
         ad_logger.info(
             f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}"
         )
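Note on the two hunks above: `sharding_dims` is now copied into the shared sharding config before the `use_sharding_from_factory` branch returns early, so the factory-config path can read it when it falls back to the EP/BMM heuristics (see the last hunk). A minimal sketch of the ordering issue; `ShardingCfg` and `apply_sketch` are hypothetical stand-ins, not the real transform classes:

```python
# Minimal sketch of the ordering fix, with hypothetical names.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ShardingCfg:
    sharding_dims: List[str] = field(default_factory=list)


def apply_sketch(use_factory: bool, dims: List[str]) -> ShardingCfg:
    cfg = ShardingCfg()
    cfg.sharding_dims = dims  # assigned before the early return, as in the moved line
    if use_factory:
        # factory-config path returns early; anything assigned after this
        # return would never be visible to the factory detection step
        return cfg
    # heuristic path would run here
    return cfg


assert apply_sketch(True, ["ep", "bmm"]).sharding_dims == ["ep", "bmm"]
```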
@@ -339,8 +339,39 @@ def detect_sharding_from_factory_config(
             # TODO: Sequence parallelism is not supported yet.
             ad_logger.warning("Sequence parallelism is not supported yet. Skipping.")
         elif "local" in config:
-            # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
-            ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+            # Check if this applies to shared experts in EP parallelism.
+            # If yes, apply the TP col-row shard.
+            if "shared" in module_name:
+                col_row_action = config.replace("local_", "")
+                if col_row_action == "colwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.COLUMN,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op=None,
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                elif col_row_action == "rowwise":
+                    sharding_config.tp_transforms.append(
+                        TPShardingInfo(
+                            target_node=lin_node.name,
+                            split_dim=SplitDimension.ROW,
+                            rank=rank,
+                            world_size=world_size,
+                            dist_op="all_reduce",
+                            min_local_shape=min_local_shape,
+                        )
+                    )
+                    num_row_col_shards += 1
+                else:
+                    ad_logger.warning("Invalid sharding config. Skipping.")
+            else:
+                # TODO: local refers to hybrid EP+TP parallelism. Not supported yet.
+                ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.")
+
         elif "gather" in config:
             # Simple shard (row + all_gather)
             sharding_config.tp_transforms.append(
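The colwise/rowwise pair added above for shared experts follows the usual column-then-row split: the column-sharded linear keeps its partial activations local (`dist_op=None`), and the row-sharded linear produces partial sums that must be combined with an all-reduce. A small numerical sketch of why that works (plain NumPy, not the graph transform; all names are local to the example):

```python
# Column-shard the first weight, row-shard the second; summing the per-rank
# partial outputs (the "all_reduce") reproduces the unsharded result.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))     # activations
w1 = rng.standard_normal((8, 16))   # first linear, split column-wise
w2 = rng.standard_normal((16, 8))   # second linear, split row-wise
world_size = 2

full = x @ w1 @ w2                  # unsharded reference

partials = []
for rank in range(world_size):
    cols = slice(rank * 8, (rank + 1) * 8)   # this rank's column shard of w1
    rows = slice(rank * 8, (rank + 1) * 8)   # matching row shard of w2
    partials.append((x @ w1[:, cols]) @ w2[rows, :])

assert np.allclose(sum(partials), full)
```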
@@ -363,9 +394,35 @@ def detect_sharding_from_factory_config(
         f"Applied {num_shards} TP shards (simple: {num_simple_shards}, "
         f"row-col pattern: {num_row_col_shards})"
     )
+
+    num_matches = len(sharding_config.tp_transforms)
+
+    if sharding_config.support_partial_config:
+        ad_logger.info(
+            f"Partial factory config applied only for TP. "
+            f"Applying heuristics for {sharding_config.sharding_dims}."
+        )
+
+        # run EP sharding across ranks
+        if "ep" in sharding_config.sharding_dims:
+            ep_info = detect_ep_shard(gm, sharding_config)
+        else:
+            ep_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        # run BMM sharding across ranks
+        if "bmm" in sharding_config.sharding_dims:
+            dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config)
+        else:
+            dp_bmm_info = TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+        num_matches += ep_info.num_matches + dp_bmm_info.num_matches
+
     return TransformInfo(
         skipped=False,
-        num_matches=len(sharding_config.tp_transforms),
+        num_matches=num_matches,
         is_clean=False,
         has_valid_shapes=False,
     )
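With `support_partial_config` enabled, the factory config now contributes only the TP transforms, while EP and BMM sharding fall back to the heuristic detectors when their dimensions are listed in `sharding_dims`, and the returned `num_matches` sums all three. A condensed sketch of that gating and aggregation, with hypothetical stubs standing in for `detect_ep_shard` / `detect_dp_bmm_shard`:

```python
# Hypothetical stubs; only the gating/aggregation pattern mirrors the change above.
from dataclasses import dataclass
from typing import List


@dataclass
class TransformInfoSketch:
    skipped: bool
    num_matches: int


def detect_ep_stub() -> TransformInfoSketch:
    return TransformInfoSketch(skipped=False, num_matches=3)   # pretend 3 MoE layers matched


def detect_bmm_stub() -> TransformInfoSketch:
    return TransformInfoSketch(skipped=False, num_matches=1)   # pretend 1 BMM matched


def total_matches(tp_matches: int, sharding_dims: List[str]) -> int:
    ep = detect_ep_stub() if "ep" in sharding_dims else TransformInfoSketch(True, 0)
    bmm = detect_bmm_stub() if "bmm" in sharding_dims else TransformInfoSketch(True, 0)
    return tp_matches + ep.num_matches + bmm.num_matches


assert total_matches(5, ["ep", "bmm"]) == 9
assert total_matches(5, []) == 5
```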