     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    get_default_supported_precision,
     init_out_dir,
     num_parameters,
     parse_devices,
@@ -42,6 +43,7 @@ def setup(
     model_name: Optional[str] = None,
     model_config: Optional[Config] = None,
     out_dir: Path = Path("out/pretrain"),
+    precision: Literal["bf16-true", "bf16-mixed", "32-true", None] = None,
     initial_checkpoint_dir: Optional[Path] = None,
     resume: Union[bool, Path] = False,
     data: Optional[DataModule] = None,
@@ -75,6 +77,7 @@ def setup(
             ``model_config``.
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
+        precision: The precision to use for pretraining. Determines a compatible precision setting by default.
         initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
             Useful for continued pretraining. Mutually exclusive with ``resume``.
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
@@ -96,6 +99,7 @@ def setup(
         available_models = "\n".join(sorted(name_to_config))
         raise ValueError(f"Please specify --model_name <model_name>. Available values:\n{available_models}")
     config = Config.from_name(model_name) if model_config is None else model_config
+    precision = precision or get_default_supported_precision(training=True)
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)
     # in case the dataset requires the Tokenizer
@@ -109,7 +113,7 @@ def setup(
         strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
     else:
         strategy = "auto"
-    fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-mixed", loggers=[logger])
+    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger])
     fabric.launch()

     fabric.print(pprint.pformat(hparams))
@@ -169,12 +173,13 @@ def main(

     model = torch.compile(model)
     model = fabric.setup(model)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=train.learning_rate,
         weight_decay=train.weight_decay,
         betas=(train.beta1, train.beta2),
-        fused=True,
+        fused=fabric.device.type == "cuda",
     )
     optimizer = fabric.setup_optimizers(optimizer)
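
A minimal sketch of the two behavioural changes above, not part of the diff itself. It assumes `get_default_supported_precision` is importable from `litgpt.utils` (the import block this diff extends does not show the module path) and that litgpt and PyTorch are installed. The first part mirrors the `precision or get_default_supported_precision(training=True)` fallback that replaces the hard-coded `"bf16-mixed"`; the second mirrors the new `fused` flag, which follows the device type because PyTorch's fused AdamW kernel expects CUDA tensors, so the previous `fused=True` would fail on CPU-only runs.

```python
# Sketch only; assumes the litgpt.utils import path, which this diff does not show.
import torch
from litgpt.utils import get_default_supported_precision

# 1) Precision resolution: an explicit value wins, otherwise fall back to a
#    hardware-compatible default instead of always using "bf16-mixed".
precision = None  # what a caller passes (or omits)
precision = precision or get_default_supported_precision(training=True)
print(f"resolved precision: {precision}")

# 2) Fused AdamW: enable the fused kernel only when running on CUDA, so the
#    same code path still works on CPU-only machines.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 8).to(device_type)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, fused=(device_type == "cuda"))
```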