fix bug when loading dist ckpt in peft
Signed-off-by: Hongbin Liu <[email protected]>
Hongbin Liu committed Sep 18, 2023
1 parent 33f5b9f commit 4720696
Showing 1 changed file with 57 additions and 8 deletions.
nemo/collections/nlp/parts/nlp_overrides.py (57 additions, 8 deletions)
@@ -740,7 +740,8 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
             peft_state_dict = torch.load(model_weights_path, map_location)['state_dict']
         else:
             peft_state_dict = {}
-        base_model_state_dict.update(peft_state_dict)  # add the PEFT state_dict into the base model's state_dict
+        if base_model_state_dict:
+            base_model_state_dict.update(peft_state_dict)  # add the PEFT state_dict into the base model's state_dict
         return base_model_state_dict
 
     def restore_from(
@@ -765,13 +766,61 @@ def restore_from(
             return loaded_params
         conf, instance, state_dict = loaded_params
 
-        if (
-            self.peft_model_nemo_path is None and self.peft_model_ckpt_dir is None
-        ):  # we have this check only for training PEFT from scratch
-            peft_state_dict = instance.get_peft_state_dict()
-            state_dict.update(peft_state_dict)
-        state_dict = self.modify_state_dict(conf, state_dict)
-        self.load_instance_with_state_dict(instance, state_dict, strict)
+        # if we're using dist checkpointing then state_dict will be None
+        if state_dict is None:
+            # dist checkpointing needs torch.distributed to load the checkpoint
+            if parallel_state.is_unitialized():
+
+                def dummy():
+                    return
+
+                if trainer.strategy.launcher is not None:
+                    trainer.strategy.launcher.launch(dummy, trainer=trainer)
+                trainer.strategy.setup_environment()
+
+            with tempfile.TemporaryDirectory() as tmpdir:
+                # Check if self.model_extracted_dir is set, and is a valid path
+                if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
+                    # Log that NeMo will use the provided `model_extracted_dir`
+                    logging.info(
+                        f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
+                    )
+
+                    # Override `tmpdir` above with the pre-extracted `model_extracted_dir`
+                    tmpdir = self.model_extracted_dir
+
+                else:
+                    # Extract the nemo file into the temporary directory
+                    self._unpack_nemo_file(
+                        path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True
+                    )
+                checkpoint = {}
+                sharded_state_dict = instance.sharded_state_dict()
+                peft_state_dict = instance.get_peft_state_dict()
+                for k in peft_state_dict.keys():
+                    sharded_state_dict.pop(k)
+                checkpoint['state_dict'] = sharded_state_dict
+                # remove model weights extension
+                tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt)
+                tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
+                assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
+                checkpoint = dist_checkpointing.load(
+                    sharded_state_dict=checkpoint, checkpoint_dir=tmp_model_weights_dir
+                )
+                checkpoint['state_dict'].update(peft_state_dict)
+                instance.on_load_checkpoint(checkpoint)
+                if hasattr(instance, 'setup_transformer_engine_tp_groups'):
+                    instance.setup_transformer_engine_tp_groups()
+
+        else:
+            if (
+                self.peft_model_nemo_path is None and self.peft_model_ckpt_dir is None
+            ):  # we have this check only for training PEFT from scratch
+                peft_state_dict = instance.get_peft_state_dict()
+                state_dict.update(peft_state_dict)
+            state_dict = self.modify_state_dict(conf, state_dict)
+            self.load_instance_with_state_dict(instance, state_dict, strict)
+
 
         logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
         return instance

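For context, the branch added above handles base models saved with megatron-core distributed checkpointing: there the weights live in a model_weights directory of shards rather than a single model_weights.ckpt file that torch.load can read, so state_dict arrives at restore_from as None (and the base state dict in _load_state_dict_from_disk is empty, hence the new guard in the first hunk). Below is a minimal sketch of how this code path is reached. The class names match this file, but the constructor arguments and the pre-built trainer are illustrative assumptions, not part of the commit.

    # Sketch only: assumes a configured PyTorch Lightning Trainer in `trainer`
    # and a base .nemo file saved with distributed checkpointing, so that the
    # extracted model_weights directory triggers the state_dict-is-None branch.
    from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
    from nemo.collections.nlp.parts.nlp_overrides import PEFTSaveRestoreConnector

    connector = PEFTSaveRestoreConnector(
        peft_model_nemo_path=None,  # both None: training PEFT from scratch,
        peft_model_ckpt_dir=None,   # so peft_state_dict stays empty on load
    )
    model = MegatronGPTSFTModel.restore_from(
        restore_path="base_model.nemo",
        trainer=trainer,
        save_restore_connector=connector,
    )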