tjruwase commented Sep 6, 2021

Tools for converting checkpoints.

tjruwase requested a review from ShadenSmith on September 6, 2021

tjruwase commented Sep 6, 2021

@stas00 FYI

stas00 commented Sep 10, 2021

OK, to complete the conversion and make the model usable with "--finetune" the missing bits are:

sd["args"].tensor_model_parallel_size = 1
sd["args"].pipeline_model_parallel_size = 1
sd["args"].consumed_train_samples = 0
sd["args"].consumed_valid_samples = 0

The first 2 of course need to be adjusted to the target tp/pp sizes.

The last 2 need to be reset; otherwise Megatron tries to resume training from some very high sample count, when it should be 0.
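e.g. something along these lines (untested sketch; the path is illustrative, not what the scripts produce):

import torch

path = "iter_0000001/mp_rank_00/model_optim_rng.pt"  # illustrative location of the converted checkpoint
sd = torch.load(path, map_location="cpu")

sd["args"].tensor_model_parallel_size = 1    # adjust to the target TP size
sd["args"].pipeline_model_parallel_size = 1  # adjust to the target PP size
sd["args"].consumed_train_samples = 0        # reset so Megatron doesn't try to resume
sd["args"].consumed_valid_samples = 0

torch.save(sd, path)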


I haven't quite figured out how to solve padded_vocab_size being larger than the vocab. The embeddings probably need to be truncated to the vocab size before saving.

The workaround is to use the actual padded vocab size when finetuning, i.e.:

--make-vocab-size-divisible-by 50688 

for when the padded vocab ends up being 50688.

stas00 commented Sep 10, 2021

Then there is the file layout. Meg-LM clearly expects this layout:

iter_0000001/mp_rank_00_000/model_optim_rng.pt
latest_checkpointed_iteration.txt

whereas in the Meg-DS tree it wants:

iter_0000001/mp_rank_00/model_optim_rng.pt
latest_checkpointed_iteration.txt

No _000 at the end. I wonder if just creating one and adding a symlink to the other would do the trick (this is with tp=1/pp=1); at least this is how I'm getting around it while testing with both trees.

The first segment of the path is:

directory = 'iter_{:07d}'.format(iteration)

in the meg code.

Additionally, we could probably convert global_step37876 to iter_0037876 to help the user know which iteration the training is coming from, rather than iter=1. Incidentally, you can then save it in the dict as iteration=37876, which it wants anyway for non-finetune.
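Roughly, something like this could cover both points (untested sketch, made-up paths, just to illustrate):

import os
import re

ds_name = "global_step37876"                      # example DeepSpeed checkpoint dir
iteration = int(re.match(r"global_step(\d+)", ds_name).group(1))
meg_dir = "iter_{:07d}".format(iteration)         # same format string as the meg code
os.makedirs(meg_dir, exist_ok=True)               # just for this illustration

# Meg-LM looks for mp_rank_00_000, Meg-DS for mp_rank_00 (tp=1/pp=1);
# a relative symlink inside the iteration dir lets one checkpoint serve both trees.
os.symlink("mp_rank_00", os.path.join(meg_dir, "mp_rank_00_000"))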

tjruwase commented:

> The first 2 of course need to be adjusted to the target tp/pp sizes.
>
> The last 2 need to be reset; otherwise Megatron tries to resume training from some very high sample count, when it should be 0.

Thanks for the feedback. I can update deepspeed_to_megatron.py for the first 2 to fix an inconsistency in the checkpoint state. However, I am unsure where to handle the last 2 since it would prevent the converted checkpoint from being used for continued training. So perhaps the finetuning script should handle the last 2. What do you think?

stas00 commented Sep 10, 2021

But this checkpoint can't be loaded for continued training at the moment; e.g., it lacks the iteration entry, and Megatron crashes without --finetune because of that.

I'm not sure how you'd change the finetune script to ignore consumed_train_samples, because it also checkpoints and should be able to resume from its own ongoing consumed_train_samples record.

Perhaps we have 2 different modes here:

  1. reshape the checkpoint, but presume continued training
  2. convert for a release purpose, resetting some counters (see the sketch below).
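For mode 2, the converter could perhaps grow a flag, e.g. (purely hypothetical sketch; the --for-release name is just a placeholder, not an existing option):

import argparse

parser = argparse.ArgumentParser(description="DeepSpeed to Megatron checkpoint converter")
parser.add_argument("--for-release", action="store_true",
                    help="reset consumed_train_samples / consumed_valid_samples so the "
                         "checkpoint suits inference or finetuning on a new dataset")
args = parser.parse_args()

if args.for_release:
    # mode 2: zero the counters, as discussed above
    pass
else:
    # mode 1: keep the counters, presuming continued training
    pass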

stas00 commented Sep 10, 2021

Found one more culprit: remember how Jared's script was asking for a Megatron clone path when doing the conversion? That proved to be essential, since if we use the default Meg-DS when converting, torch.load then fails when attempting to use Meg-LM:

Traceback (most recent call last):
  File "pretrain_gpt.py", line 124, in <module>
    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
  File "/mnt/nvme1/code/huggingface/Megatron-LM/megatron/training.py", line 112, in pretrain
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
  File "/mnt/nvme1/code/huggingface/Megatron-LM/megatron/training.py", line 325, in setup_model_and_optimizer
    args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
  File "/mnt/nvme1/code/huggingface/Megatron-LM/megatron/checkpointing.py", line 314, in load_checkpoint
    state_dict = torch.load(checkpoint_name, map_location='cpu')
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/home/stas/anaconda3/envs/py38-pt19/lib/python3.8/site-packages/torch/serialization.py", line 875, in find_class
    return super().find_class(mod_name, name)
ModuleNotFoundError: No module named 'megatron.enums'

Things have changed in Meg-DS, and now it can't find megatron.enums. In the bigscience Meg-LM tree this is megatron.model.enums instead.

Grr, this appears really tricky now that the codebases are starting to diverge.

I can't even load the bigscience Meg-DS checkpoint using the Meg-DS codebase in PYTHONPATH.

OK solved this by adding both clones to PYTHONPATH explicitly:

PYTHONPATH=/hf/Megatron-DeepSpeed-master:/hf/Megatron-DeepSpeed-microsoft python tools/convert_checkpoint/deepspeed_to_megatron.py ...

But it still doesn't work when I then try to train with the Megatron-LM tree.

I'm asking to have that restored:

bigscience-workshop/Megatron-DeepSpeed#7 (comment)

Unless you have some bright ideas on how not to pickle structures that may be missing in the target?
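One speculative, untested idea: alias the missing module path onto the module that does exist in the current tree before calling torch.load, assuming the enum definitions are compatible between the two codebases:

import sys
import torch
import megatron.model.enums  # location in the bigscience tree, per the note above

sys.modules["megatron.enums"] = megatron.model.enums  # the path baked into the pickle
sd = torch.load("iter_0000001/mp_rank_00/model_optim_rng.pt", map_location="cpu")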

tjruwase commented Sep 10, 2021

> Additionally, we could probably convert global_step37876 to iter_0037876 to help the user know which iteration the training is coming from, rather than iter=1. Incidentally, you can then save it in the dict as iteration=37876, which it wants anyway for non-finetune.

I am a bit confused by this. Are you seeing iter=1 in the converted checkpoint?

stas00 commented Sep 10, 2021

I think it perhaps expects the top-level iteration key and not args.iteration? Remember how I mentioned there were 3 keys it expects? So I think it wants:

ITERATION_KEY = 'iteration'
...
    checkpoint_sd[ITERATION_KEY] = iteration

in your code.

> Are you seeing iter=1 in the converted checkpoint?

I see args.iteration=94743, but the checkpoint sub-dir is global_step97500, so it's odd. It must be a lack of sync between Meg and Meg-DS.

tjruwase commented:

> I think it perhaps expects the top-level iteration key and not args.iteration? Remember how I mentioned there were 3 keys it expects? So I think it wants:

It would be great to get some clarity on which iteration to use here.

Also, should it be

  1. checkpoint_sd['args'][ITERATION_KEY], or
  2. checkpoint_sd[ITERATION_KEY]?

stas00 commented Sep 10, 2021

> > I think it perhaps expects the top-level iteration key and not args.iteration? Remember how I mentioned there were 3 keys it expects? So I think it wants:
>
> It would be great to get some clarity on which iteration to use here.

The problem is that Meg-DS doesn't seem to update Meg's native iteration variable soon enough, so the saved checkpoint reports the outdated, lower iteration. Most likely the fix is needed inside Meg-DS, to at the very least sync Meg's native iteration variable with the Meg-DS value of the same.

> Also, should it be
>
>   1. checkpoint_sd['args'][ITERATION_KEY], or
>   2. checkpoint_sd[ITERATION_KEY]?

The latter.

But it seems to be pointless, because once I add the missing key, it then fails with:

 loading checkpoint from /hf/Megatron-DeepSpeed-master/data/1B3-PP4-TP4-Meg at iteration 1
 checkpoint version 3.0
Unable to load optimizer from checkpoint /hf/Megatron-DeepSpeed-master/data/1B3-PP4-TP4-Meg/iter_0000001/mp_rank_00/model_optim_rng.pt. Specify --no-load-optim or --finetune to prevent attempting to load the optimizer state, exiting ...

We can't manifest optimizer states for Meg-LM out of nowhere, so it appears that after the conversion only inference or finetuning is possible. In that case it's probably pointless to try to keep consumed_train_samples and its friend.

Perhaps, for now, let's handle just the clear case of inference/finetuning, with the assumption that finetuning will require a different dataset?

I think it's only when we reshape the checkpoint to support changing the degree of TP (as discussed a few days later) that we would try to preserve everything, but that's when saving from Meg-DS back to Meg-DS.

stas00 commented Sep 10, 2021

So I think the only remaining thing to address (other than embeddings) is: #14 (comment)

And let's set:

checkpoint_sd[ITERATION_KEY]=xxx

to whatever global_stepxxx points to, and match iter_0000xxx to it.

Sorry, brainstorming here... but then what if the input checkpoint isn't named /path/to/global_stepxxx?
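Maybe parse it with a fallback, e.g. (untested sketch; the fallback choice is just a guess on my part):

import os
import re

def infer_iteration(checkpoint_dir: str, state_dict: dict) -> int:
    m = re.match(r"global_step(\d+)$", os.path.basename(os.path.normpath(checkpoint_dir)))
    if m:
        return int(m.group(1))
    return state_dict.get("iteration", 0)  # fall back to whatever the checkpoint recorded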

stas00 commented Sep 10, 2021

In another checkpoint I was given, the discrepancy is 2x:

the Meg-DS file is global_step37876, but args.iteration=18931, which is ~0.5x the former. Weird. It looks like the 2 counters aren't in sync at all.

This tells me that args.iteration=18931 is incorrect and somehow the DS integration forgets to update it, and that global_step37876 is the correct iteration since it matches the log files.

stas00 commented Sep 11, 2021

OK, I have figured this one out. You were getting the wrong iteration; you need this one:

f = "global_step37876/mp_rank_00_model_states.pt"
sd = torch.load(f)
sd["iteration"]

This is the real iteration; args.iteration in the original code is whatever the iteration was at the start of training. Does that make sense?

sd["iteration"] # right
sd["args"].iteration # wrong

stas00 commented Sep 14, 2021

@tjruwase, 2 more things that I discovered are different from the checkpoint generated by Meg-LM natively. So these need to be changed as demonstrated:

- embeddings["word_embeddings.weight"]
+ embeddings["word_embeddings"]["weight"]

and:

+ embeddings["position_embeddings.weight"]
- embeddings["position_embeddings"]["weight"]

So I think after this fix the resulting checkpoint will match the native one.

Plus the resulting file structure: #14 (comment)

And then it's good to be merged.

I did multiple tests on the final stage meg2hf and the conversion appears to be correct.

stas00 commented Sep 18, 2021

This is good to merge, @tjruwase! Thank you!

stas00 commented Sep 20, 2021

@tjruwase, I made a small change to your work to separate the creation of the checkpoint from saving it, so that I could re-use it to create the HF transformers checkpoint on the fly.

Also made the PP/TP sizes 1 by default, since in the HF case that's always what they are for now.

If it looks acceptable to you, perhaps let's merge this back into your master tree? If so, please cherry-pick these 2:

  1. bigscience-workshop/Megatron-DeepSpeed@bea5ded
  2. bigscience-workshop/Megatron-DeepSpeed@d6c2a80

Thank you!

stas00 commented Sep 22, 2021

Here you go; I added 2 more commits where I made the scripts executable ;)

git checkout -b convert-meg-ds-to-hf
git remote add other https://github.com/bigscience-workshop/Megatron-DeepSpeed
git fetch other
git cherry-pick bea5ded
git cherry-pick d6c2a80
git cherry-pick 26f18b5
git cherry-pick 2f662e8
git push --set-upstream origin convert-meg-ds-to-hf
