
fix conversion and eval (#6648)
* fix conversion and eval

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and yaoyu-33 committed May 26, 2023
1 parent 83b77e6 commit cd15d97
Showing 2 changed files with 66 additions and 35 deletions.
17 changes: 16 additions & 1 deletion examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
@@ -29,12 +29,14 @@

import torch
from megatron.core import parallel_state
+from omegaconf import open_dict
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model
from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel
@@ -80,7 +82,11 @@ def get_args():
help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.",
)
parser.add_argument(
"--model_type", type=str, required=True, default="gpt", choices=["gpt", "t5", "bert", "nmt", "bart", "retro"]
"--model_type",
type=str,
required=True,
default="gpt",
choices=["gpt", "sft", "t5", "bert", "nmt", "bart", "retro"],
)
parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1))
parser.add_argument("--bcp", action="store_true", help="Whether on BCP platform")
@@ -138,6 +144,15 @@ def convert(local_rank, rank, world_size, args):

    if args.model_type == 'gpt':
        model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
+    elif args.model_type == 'sft':
+        model = MegatronGPTSFTModel.load_from_checkpoint(
+            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
+        )
+        # force the loaded model's target to the SFT class, because
+        # hparams.yaml sometimes records MegatronGPTModel as the target.
+        with open_dict(model.cfg):
+            model.cfg.target = f"{MegatronGPTSFTModel.__module__}.{MegatronGPTSFTModel.__name__}"
+
    elif args.model_type == 'bert':
        model = MegatronBertModel.load_from_checkpoint(
            checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
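
Why the new 'sft' branch overrides model.cfg.target: NeMo stores a fully qualified class path in a checkpoint's config and uses it to decide which class to instantiate on a later restore, so an SFT checkpoint whose hparams.yaml still names MegatronGPTModel would restore as the wrong class. A minimal sketch of that resolution step, assuming the module.ClassName convention shown above (resolve_model_class is a hypothetical helper, not NeMo API):

import importlib

def resolve_model_class(target: str):
    # Import the class named by a fully qualified 'module.ClassName' path.
    module_name, class_name = target.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

# After the override above, a converted SFT checkpoint resolves to
# MegatronGPTSFTModel instead of the base MegatronGPTModel.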
84 changes: 50 additions & 34 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -13,6 +13,7 @@
# limitations under the License.


+import json
import os

import torch.multiprocessing as mp
@@ -21,14 +22,9 @@
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from torch.utils.data import DataLoader

-from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import (
-    MegatronGPTAdapterModel,
-    MegatronGPTAdapterPTuningModel,
-    MegatronGPTIA3Model,
-    MegatronGPTLoRAModel,
-    MegatronGPTPEFTModel,
-    MegatronGPTPTuningModel,
-)
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import MegatronGPTPEFTModel
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.models.nlp_model import NLPModel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
@@ -42,27 +38,35 @@
from nemo.utils import logging

mp.set_start_method("spawn", force=True)

"""
-This is the script to train an Adapter infused GPT Model for text generation.
-A base GPT Model is required as a starting point. This script will then insert
-Adapters into each Transformer layer and will train/update only these adapters
-during training. The base GPT Model weights will remain frozen.
-During training this script will only save the newly trained Adapter weights
-in checkpoints. At the end of training a .nemo file of Adapter weights will
-be saved.
-Usage:
-Assuming the base model is a 125m GPT Model, with TP=1, PP=1:
-a. run a training run for a base gpt nemo file:
-python megatron_gpt_adapter_tuning.py \
-    "model.data.train_ds=[PATH TO TRAINING JSONL FILE]",
-    "model.data.validation_ds=[PATH TO VALIDATION JSONL FILE]",
-    model.language_model_path="PATH TO BASE GPT MODEL .nemo FILE"
-    name="NAME OF TRAINING RUN"
-    exp_manager.exp_dir="DIR TO SAVE CHECKPOINTS and .nemo FILE",
-    trainer.max_epochs=2
+This is the script to run inference with a PEFT model or an SFT model.
+
+If you want to evaluate an SFT .nemo file:
+python examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
+    model.restore_from_path=<path_to_sft_nemo_file> \
+    model.peft.restore_from_path=null \
+    trainer.devices=1 \
+    model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>,<path_to_test_jsonl_file2>\] \
+    model.data.test_ds.names=\['name_for_test_file1','name_for_test_file2'\] \  # identifiers, not filenames
+    model.data.test_ds.global_batch_size=4 \  # or some other value
+    model.data.test_ds.micro_batch_size=4 \
+    model.data.test_ds.tokens_to_generate=30 \
+    inference.greedy=True \
+    inference.outfile_path='<path_to_jsonl_output_file>'
+
+If you want to evaluate a PEFT model, provide both the base GPT model and the PEFT .nemo file:
+python examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
+    model.restore_from_path=<path_to_sft_nemo_file> \
+    model.peft.restore_from_path=<path_to_peft_nemo_file> \  # created by `megatron_gpt_peft_tuning.py`
+    trainer.devices=1 \
+    model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>,<path_to_test_jsonl_file2>\] \
+    model.data.test_ds.names=\['name_for_test_file1','name_for_test_file2'\] \  # identifiers, not filenames
+    model.data.test_ds.global_batch_size=4 \  # or some other value
+    model.data.test_ds.micro_batch_size=4 \
+    model.data.test_ds.tokens_to_generate=30 \
+    inference.greedy=True \
+    inference.outfile_path='<path_to_jsonl_output_file>'
"""


@@ -105,7 +109,7 @@ def main(cfg) -> None:
            restore_path=cfg.model.peft.restore_from_path, trainer=trainer, return_config=True,
        )
    else:
-        peft_model_cfg = MegatronGPTPEFTModel.restore_from(
+        peft_model_cfg = MegatronGPTSFTModel.restore_from(
            restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True,
        )

@@ -114,6 +118,8 @@
        # update the model config of the trained model with params we want to set at inference time.
        peft_model_cfg.precision = cfg.trainer.precision
        peft_model_cfg.data.test_ds = cfg.model.data.test_ds
+        peft_model_cfg.activations_checkpoint_granularity = None
+        peft_model_cfg.activations_checkpoint_method = None

    with open_dict(cfg):
        # update the config with the trained model config
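
The two added lines disable activation checkpointing for evaluation: checkpointing trades compute for memory during backpropagation, which inference never runs. They sit inside open_dict so that keys absent from the restored config can still be added, since configs restored from a .nemo file are struct-locked and OmegaConf rejects new keys outside that context. A standalone illustration with a toy, hypothetical config:

from omegaconf import OmegaConf, open_dict

cfg = OmegaConf.create({"precision": 32})
OmegaConf.set_struct(cfg, True)  # restored NeMo configs behave like this
with open_dict(cfg):
    cfg.activations_checkpoint_granularity = None  # adding a key is allowed here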
@@ -128,9 +134,8 @@
    else:
        save_restore_connector = NLPSaveRestoreConnector()

-    if os.path.isdir(peft_model_cfg.restore_from_path):
+    if os.path.isdir(cfg.model.restore_from_path):
        save_restore_connector.model_extracted_dir = cfg.model.restore_from_path
-    # peft_cls = _get_peft_scheme(peft_model_cfg)
    model = NLPModel.restore_from(
        restore_path=cfg.model.restore_from_path,
        trainer=trainer,
@@ -148,14 +153,25 @@
    config = OmegaConf.to_container(cfg.inference, resolve=True)
    model.set_inference_config(config)
    response = trainer.predict(model, request_dl)
+
    if model.global_rank == 0:
        print("***************************")
        if cfg.inference.outfile_path is not None:
            with open(cfg.inference.outfile_path, "w", encoding="utf-8") as f:
                for batch in response:
-                    for sentence in batch["sentences"]:
-                        s = " ".join(sentence.split("\n"))
-                        f.write(s + "\n")
+                    batch_sentences = [s for s in batch['sentences']]
+                    batch_tokens = [s for s in batch['tokens']]
+                    batch_logprob = [s.tolist() for s in batch['logprob']]
+                    for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
+                        if cfg.inference.get("verbose", False):
+                            d = {
+                                'sentence': s,
+                                'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
+                            }
+                            f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
+                        else:
+                            d = {'sentence': s}
+                            f.write(json.dumps(d) + '\n')
            print("predictions saved to {}".format(cfg.inference.outfile_path))
        else:
            print(response)
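
In the default (non-verbose) path each output line is a standalone JSON object, so the file reads back as JSONL; the verbose branch pretty-prints multi-line JSON (indent=2), which a line-by-line reader would not handle. A short sketch of loading a non-verbose run's output, assuming a hypothetical predictions.jsonl:

import json

# Read one JSON object per line, skipping blanks; "sentence" matches the
# key written by the non-verbose branch above.
with open("predictions.jsonl", encoding="utf-8") as f:
    sentences = [json.loads(line)["sentence"] for line in f if line.strip()]
print(f"loaded {len(sentences)} predictions")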
