@@ -436,13 +436,13 @@ class TrainingArguments:
             fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.

             A List of config and its options:
-                - fsdp_min_num_params (`int`, *optional*, defaults to `0`):
+                - min_num_params (`int`, *optional*, defaults to `0`):
                     FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
                     passed).
-                - fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
+                - transformer_layer_cls_to_wrap (`List[str]`, *optional*):
                     List of transformer layer class names (case-sensitive) to wrap, e.g., `BertLayer`, `GPTJBlock`,
                     `T5Block` .... (useful only when `fsdp` flag is passed).
-                - fsdp_backward_prefetch (`str`, *optional*)
+                - backward_prefetch (`str`, *optional*)
                     FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
                     `fsdp` field is passed).

@@ -454,14 +454,22 @@ class TrainingArguments:
                     - `"backward_post"` : This prefetches the next set of parameters after the current set of
                       parameter's gradient computation.
-                - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
+                - forward_prefetch (`bool`, *optional*, defaults to `False`)
                     FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
                     If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
                     forward pass.
                 - limit_all_gathers (`bool`, *optional*, defaults to `False`)
                     FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
                     If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
                     all-gathers.
+                - use_orig_params (`bool`, *optional*, defaults to `False`)
+                    If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
+                    frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
+                    refer to this
+                    [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
+                - sync_module_states (`bool`, *optional*, defaults to `True`)
+                    If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
+                    ensure they are the same across all ranks after initialization.
                 - xla (`bool`, *optional*, defaults to `False`):
                     Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
                     and its API may evolve in the future.
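
For orientation, here is a minimal sketch (not part of the commit) of how the prefix-free keys documented above could be passed to `TrainingArguments` once this change is applied; the output directory, the `fsdp` options, and every value in the dict are illustrative placeholders:

```python
from transformers import TrainingArguments

# Hypothetical config using the renamed, prefix-free keys from the docstring above.
# min_num_params is left at 0 because a positive value is mutually exclusive with
# transformer_layer_cls_to_wrap (see the checks in __post_init__ below).
fsdp_config = {
    "min_num_params": 0,
    "transformer_layer_cls_to_wrap": ["BertLayer"],
    "backward_prefetch": "backward_pre",
    "forward_prefetch": False,
    "limit_all_gathers": False,
    "use_orig_params": False,
    "sync_module_states": True,
    "xla": False,
}

args = TrainingArguments(
    output_dir="./out",               # placeholder path
    fsdp="full_shard auto_wrap",      # fsdp must be set for fsdp_config to take effect
    fsdp_config=fsdp_config,          # a dict, or a path to a JSON file with these keys
)
```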
@@ -1520,44 +1528,44 @@ def __post_init__(self):
             self.fsdp_config = {}

         if isinstance(self.fsdp_config, str):
+            if len(self.fsdp) == 0:
+                warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
             with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
                 self.fsdp_config = json.load(f)
+                # iterate over a snapshot of the items so keys can be renamed in place
+                for k, v in list(self.fsdp_config.items()):
+                    if k.startswith("fsdp_"):
+                        self.fsdp_config[k.replace("fsdp_", "")] = v
+                        del self.fsdp_config[k]

         if self.fsdp_min_num_params > 0:
             warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

-        self.fsdp_config["fsdp_min_num_params"] = max(
-            self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
-        )
+        self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)

-        # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
-        if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
-                self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
-            ]
+        # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
+        if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]

         if self.fsdp_transformer_layer_cls_to_wrap is not None:
             warnings.warn(
                 "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
             )
-            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
-                "fsdp_transformer_layer_cls_to_wrap", []
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", []
             ) + [self.fsdp_transformer_layer_cls_to_wrap]

-        if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
-            warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
+            warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")

-        if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-            warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+        if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+            warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

         if (
             len(self.fsdp) > 0
-            and self.fsdp_config["fsdp_min_num_params"] > 0
-            and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
+            and self.fsdp_config["min_num_params"] > 0
+            and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
         ):
-            raise ValueError(
-                "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
-            )
+            raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
         self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
         self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
         if self.fsdp_config["xla"]:
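
As an aside, the backward-compatibility path above (stripping the old `fsdp_` prefix from keys loaded out of a JSON config) can be pictured with a small standalone sketch; the config contents are made up, and only the renaming idea comes from the diff:

```python
import json

# An existing config written with the old, prefixed key names.
old_style = '{"fsdp_min_num_params": 2000, "fsdp_backward_prefetch": "backward_pre"}'
config = json.loads(old_style)

# Same renaming idea as in __post_init__: iterate over a snapshot of the items
# so the dict can be mutated safely while old keys are replaced by new ones.
for k, v in list(config.items()):
    if k.startswith("fsdp_"):
        config[k.replace("fsdp_", "")] = v
        del config[k]

print(config)  # {'min_num_params': 2000, 'backward_prefetch': 'backward_pre'}
```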
@@ -1583,23 +1591,29 @@ def __post_init__(self):
                 FSDP_SHARDING_STRATEGY,
             )

+            prefix = "FSDP_"
             for fsdp_option in self.fsdp:
                 if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
                     # set environment variable for FSDP sharding strategy
-                    os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
+                    os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
+                        FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+                    )
                 elif fsdp_option == FSDPOption.OFFLOAD:
-                    os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
+                    os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
                 elif fsdp_option == FSDPOption.AUTO_WRAP:
-                    os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
-                    if self.fsdp_config["fsdp_min_num_params"] > 0:
-                        os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
-                        os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
-                    elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
-                        os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
-                            self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
+                    os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
+                    if self.fsdp_config["min_num_params"] > 0:
+                        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
+                        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
+                    elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+                        os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
+                            self.fsdp_config["transformer_layer_cls_to_wrap"]
                         )
             prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
-            os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false"))
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = str(self.fsdp_config.get("sync_module_states", "true"))
+            os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "false"))

         if self.tpu_metrics_debug:
             warnings.warn(
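
Finally, a self-contained sketch of the environment-variable mapping performed in the last hunk. The two constant lists are illustrative stand-ins for the ones imported from `accelerate.utils.constants`, and the `fsdp`/`fsdp_config` values are invented; only the `FSDP_*` variable names and the 1-based sharding-strategy index come from the code above:

```python
import os

# Stand-ins for the accelerate constants (values here are illustrative).
FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]
FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]

fsdp = ["full_shard", "auto_wrap"]
fsdp_config = {"min_num_params": 2000, "backward_prefetch": "backward_pre"}

prefix = "FSDP_"
for option in fsdp:
    if option.upper() in FSDP_SHARDING_STRATEGY:
        # sharding strategy is exported as a 1-based index, as in the diff
        os.environ[f"{prefix}SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(option.upper()) + 1)
    elif option == "auto_wrap":
        # a positive min_num_params selects size-based auto wrapping
        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(fsdp_config["min_num_params"])

os.environ[f"{prefix}BACKWARD_PREFETCH"] = fsdp_config["backward_prefetch"].upper()
print({k: v for k, v in os.environ.items() if k.startswith(prefix)})
```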