Rename Finetuning Scripts (#8201)
* rename all scripts

Signed-off-by: Chen Cui <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CodeQL error

Signed-off-by: Chen Cui <[email protected]>

* rename finetune_generate to just generate

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: stevehuang52 <[email protected]>
2 people authored and stevehuang52 committed Jan 31, 2024
1 parent 41c76c4 commit e1e2786
Showing 16 changed files with 379 additions and 37 deletions.
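For anyone updating local scripts or CI that call the old entry points, the renames exercised in the Jenkinsfile below boil down to the following mapping (a sketch assembled from this diff; the dictionary name is purely illustrative). The old GPT scripts are also kept in place as deprecated shims that forward to the new configs, as shown further down.

# Old -> new script names under examples/nlp/language_modeling/tuning/,
# as reflected in the Jenkinsfile changes in this commit.
RENAMED_SCRIPTS = {
    "megatron_gpt_peft_tuning.py": "megatron_gpt_finetuning.py",
    "megatron_gpt_peft_eval.py": "megatron_gpt_generate.py",
    "megatron_t5_peft_tuning.py": "megatron_t5_finetuning.py",
    "megatron_t5_peft_eval.py": "megatron_t5_generate.py",
}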
16 changes: 8 additions & 8 deletions Jenkinsfile
@@ -3949,7 +3949,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
}
failFast true
steps {
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py \
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=2 \
@@ -3978,7 +3978,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
model.data.validation_ds.names=[quarel]"
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py \
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=1 \
@@ -4054,7 +4054,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
failFast true
steps {
sh "rm -rf examples/nlp/language_modeling/gpt_peft_lora_results_pp2"
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py \
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.max_epochs=9999 \
@@ -4089,7 +4089,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
failFast true
steps {
sh "rm -rf /home/TestData/nlp/lora_tuning_tp2"
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py \
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.max_epochs=9999 \
@@ -4111,7 +4111,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
model.data.validation_ds.names=[quarel]"
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/TP2/megatron_gpt_tp2.nemo \
model.peft.restore_from_path=/home/TestData/nlp/lora_tuning_tp2/megatron_gpt_peft_lora_tuning/checkpoints/megatron_gpt_peft_lora_tuning.nemo \
model.peft.restore_from_ckpt_name=null \
@@ -4176,7 +4176,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
}
failFast true
steps{
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=/home/TestData/nlp/megatron_gpt_sft/megatron_gpt_rope_sft.nemo \
model.peft.restore_from_path=null \
model.data.test_ds.file_names=['/home/TestData/nlp/megatron_gpt_sft/sample.jsonl'] \
@@ -4995,7 +4995,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
failFast true
steps {
sh "rm -rf /home/TestData/nlp/t5_lora_tuning_tp2"
sh "python examples/nlp/language_modeling/tuning/megatron_t5_peft_tuning.py \
sh "python examples/nlp/language_modeling/tuning/megatron_t5_finetuning.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.max_epochs=9999 \
@@ -5017,7 +5017,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
model.data.validation_ds.names=[quarel]"
sh "python examples/nlp/language_modeling/tuning/megatron_t5_peft_eval.py \
sh "python examples/nlp/language_modeling/tuning/megatron_t5_generate.py \
model.restore_from_path=/home/TestData/nlp/megatron_t5/8m/megatron_t5_8m_tp2.nemo \
model.peft.restore_from_path=/home/TestData/nlp/t5_lora_tuning_tp2/megatron_t5_peft_lora_tuning/checkpoints/megatron_t5_peft_lora_tuning.nemo \
model.peft.restore_from_ckpt_name=null \
File renamed without changes.
File renamed without changes.
81 changes: 81 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py
@@ -0,0 +1,81 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP

from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)

"""
This is the script to finetune a GPT Model with any PEFT method.
A base GPT Model is required as a starting point. This script will then insert
Adapters into each Transformer layer and will train/update only these adapters
during training. The base GPT Model weights will remain frozen.
During training this script will only save the newly trained Adapter weights
in checkpoints. At the end of training a .nemo file of Adapter weights will
be saved.
Usage:
Assuming the base model is a 125m GPT Model, with TP=1, PP=1:
a. run a PEFT training job on the base GPT .nemo file:
python megatron_gpt_finetuning.py \
"model.data.train_ds.file_names=[PATH TO TRAINING JSONL FILE]",
"model.data.train_ds.concat_sampling_probabilities=[SAMPLING VAL]",
"model.data.validation_ds.file_names=[PATH TO VALIDATION JSONL FILE]",
"model.data.validation_ds.names=[NAME FOR METRIC LOGGING]",
model.restore_from_path="PATH TO BASE GPT MODEL .nemo FILE"
model.peft.peft_scheme='lora' # lora, ptuning, adapter, ia3, or none for full finetuning
name="NAME OF TRAINING RUN"
exp_manager.exp_dir="DIR TO SAVE CHECKPOINTS and .nemo FILE",
Please see lora.ipynb for a step-by-step guide.
"""


@hydra_runner(config_path="conf", config_name="megatron_gpt_finetuning_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()
exp_manager(trainer, cfg.exp_manager)

model_cfg = MegatronGPTSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg)
model = MegatronGPTSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)
peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

if cfg.model.peft.restore_from_path is not None:
# initialize peft weights from a checkpoint instead of randomly
# This is not the same as resume training because optimizer states are not restored.
logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path)
model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg))
elif peft_cfg_cls is not None:
logging.info("Adding adapter weights to the model for PEFT")
model.add_adapter(peft_cfg_cls(model_cfg))
else:
logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")

trainer.fit(model)


if __name__ == '__main__':
main()
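The docstring above summarizes the PEFT setup: the base GPT weights stay frozen and only a small set of adapter weights is trained and checkpointed. As a rough, framework-agnostic illustration of that idea (plain PyTorch with hypothetical names, not the MegatronGPTSFTModel/PEFT_CONFIG_MAP code path used by the script), a LoRA-style adapter around a frozen linear layer looks roughly like this:

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """A frozen linear layer plus a small trainable low-rank adapter (illustrative sketch only)."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # freeze the pretrained weights
            p.requires_grad = False
        self.lora_a = nn.Parameter(torch.randn(base.in_features, rank) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(rank, base.out_features))
        self.scale = alpha / rank

    def forward(self, x):
        # Frozen base projection plus the trainable low-rank update.
        return self.base(x) + (x @ self.lora_a @ self.lora_b) * self.scale


layer = LoRALinear(nn.Linear(512, 512))
trainable = [n for n, p in layer.named_parameters() if p.requires_grad]
print(trainable)  # only ['lora_a', 'lora_b'] are trainable

Only the adapter tensors show up as trainable, which mirrors why the finetuning script can save just the adapter weights into a small .nemo file while the base model stays untouched.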
143 changes: 143 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_generate.py
@@ -0,0 +1,143 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import threading
from functools import partial

import torch
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf


from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging

try:
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):

HAVE_MEGATRON_CORE = False

mp.set_start_method("spawn", force=True)
"""
This is the script to run inference with a PEFT model or an SFT Model.
If you want to evaluate an SFT .nemo file:
python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=<path_to_sft_nemo_file> \
model.peft.restore_from_path=null \
trainer.devices=1 model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>, <path_to_test_jsonl_file2>] \
model.data.test_ds.names=\['name_for_test_file1', 'name_for_test_file2'] \ # this is not the filename just some identifier
model.data.test_ds.global_batch_size=4 \ # or some other value
model.data.test_ds.micro_batch_size=4 \
model.data.test_ds.tokens_to_generate=30 \
inference.greedy=True \
inference.outfile_path=\'<path_to_jsonl_output_file>'
If you want to evaluate a PEFT Model, you should provide a base GPT model and a PEFT model .nemo file
python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=<path_to_sft_nemo_file> \
model.peft.restore_from_path=<path_to_peft_nemo_file> \ # this will be created if you use `megatron_gpt_finetuning.py`
trainer.devices=1 model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>, <path_to_test_jsonl_file2>] \
model.data.test_ds.names=\['name_for_test_file1', 'name_for_test_file2'] \ # this is not the filename just some identifier
model.data.test_ds.global_batch_size=4 \ # or some other value
model.data.test_ds.micro_batch_size=4 \
model.data.test_ds.tokens_to_generate=30 \
inference.greedy=True \
inference.outfile_path=\'<path_to_jsonl_output_file>'
"""


def use_inference_server(cfg, model, trainer):
if not HAVE_MEGATRON_CORE:
raise ValueError('Megatron-core needs to be installed to use this feature!')

from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo

trainer.test(model, dataloaders=None)

if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
if cfg.chat:
defaults = {
'user': cfg.chatbot_config.user,
'assistant': cfg.chatbot_config.assistant,
'system': cfg.chatbot_config.system,
}
web_ui = partial(
get_chatbot_demo,
defaults=defaults,
value=cfg.chatbot_config.value,
attributes=cfg.chatbot_config.attributes,
)
else:
web_ui = get_demo
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=web_ui, daemon=True, args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
)
thread.start()
server = MegatronServer(model.cuda())
server.run("0.0.0.0", port=cfg.port)

while True:
choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice, 0)
if choice[0].item() == 0:
generate(model.cuda())


@hydra_runner(config_path="conf", config_name="megatron_gpt_generate_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()

if cfg.model.peft.restore_from_path:
model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg)
else:
model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg)

model = MegatronGPTSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)

if cfg.model.peft.restore_from_path:
model.load_adapters(cfg.model.peft.restore_from_path)

model.freeze()
logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}")

if not cfg.model.get('use_flash_attention', False):
cfg.inference.compute_attention_mask = True
config = OmegaConf.to_container(cfg.inference, resolve=True)
model.set_inference_config(config)

if not cfg.server:
trainer.test(model)
else:
use_inference_server(cfg, model, trainer)


if __name__ == "__main__":
main()
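With cfg.server=True, rank 0 serves generation requests via MegatronServer while the remaining ranks wait in the broadcast loop above. A hypothetical client sketch follows; the route, port, and payload field names are assumptions to verify against NeMo's text_generation_server implementation (the actual port is whatever cfg.port is set to).

import json
import requests  # assumed available in the client environment

# Hypothetical request against the generate endpoint exposed by MegatronServer.
# Verify the route ("/generate") and field names against text_generation_server.py.
payload = {
    "sentences": ["What is the capital of France?"],
    "tokens_to_generate": 30,
    "temperature": 1.0,
    "add_BOS": True,
}
resp = requests.put(
    "http://localhost:5555/generate",  # replace 5555 with the value of cfg.port
    data=json.dumps(payload),
    headers={"Content-Type": "application/json"},
)
print(resp.json())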
21 changes: 16 additions & 5 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#############################
# THIS SCRIPT IS DEPRECATED #
#############################

import asyncio
import threading
@@ -28,6 +30,7 @@
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.decorators import deprecated

try:
from megatron.core import parallel_state
@@ -42,7 +45,7 @@
If you want to evaluate an SFT .nemo file:
python examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=<path_to_sft_nemo_file> \
model.peft.restore_from_path=null \
trainer.devices=1 model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>, <path_to_test_jsonl_file2>] \
@@ -55,9 +58,9 @@
If you want to evaluate a PEFT Model, you should provide a base GPT model and a PEFT model .nemo file
python examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py \
python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=<path_to_sft_nemo_file> \
model.peft.restore_from_path=<path_to_peft_nemo_file> \ # this will be created if you use `megatron_gpt_peft_tuning.py`
model.peft.restore_from_path=<path_to_peft_nemo_file> \ # this will be created if you use `megatron_gpt_finetuning.py`
trainer.devices=1 model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>, <path_to_test_jsonl_file2>] \
model.data.test_ds.names=\['name_for_test_file1', 'name_for_test_file2'] \ # this is not the filename just some identifier
model.data.test_ds.global_batch_size=4 \ # or some other value
@@ -108,7 +111,15 @@ def use_inference_server(cfg, model, trainer):
generate(model.cuda())


@hydra_runner(config_path="conf", config_name="megatron_gpt_peft_eval_config")
banner = '\n'.join(["*" * 80] * 5)


@deprecated(
wait_seconds=20,
explanation=f"\n{banner}\nmegatron_gpt_peft_eval.py is renamed to megatron_gpt_generate.py with the "
f"same functionality. \nPlease switch to the new name.\n{banner}\n",
)
@hydra_runner(config_path="conf", config_name="megatron_gpt_generate_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
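The deprecation shims in this commit decorate main() with nemo.utils.decorators.deprecated, passing wait_seconds and an explanation banner. NeMo's real implementation is not part of this diff; a minimal sketch of a decorator with that shape (an illustrative assumption, not the library code) could look like:

import functools
import time
import warnings


def deprecated(wait_seconds: int = 0, explanation: str = ""):
    """Warn that the wrapped entry point is deprecated, pause briefly, then run it (sketch only)."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(f"{func.__name__} is deprecated. {explanation}", DeprecationWarning)
            time.sleep(wait_seconds)  # give the user a chance to notice the rename banner
            return func(*args, **kwargs)

        return wrapper

    return decorator

The wait gives interactive users time to read the banner before the old entry point runs exactly as before.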
16 changes: 13 additions & 3 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py
@@ -11,7 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#############################
# THIS SCRIPT IS DEPRECATED #
#############################
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf

@@ -21,6 +23,7 @@

from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.decorators import deprecated
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)
@@ -38,7 +41,7 @@
Usage:
Assuming the base model is a 125m GPT Model, with TP=1, PP=1:
a. run a training run for a base gpt nemo file:
python megatron_gpt_peft_tuning.py \
python megatron_gpt_finetuning.py \
"model.data.train_ds.file_names=[PATH TO TRAINING JSONL FILE]",
"model.data.train_ds.concat_sampling_probabilities=[SAMPLING VAL]",
"model.data.validation_ds.file_names=[PATH TO VALIDATION JSONL FILE]",
@@ -50,8 +53,15 @@
Please see lora.ipynb for a step-by-step guide.
"""

banner = '\n'.join(["*" * 80] * 5)


@hydra_runner(config_path="conf", config_name="megatron_gpt_peft_tuning_config")
@deprecated(
wait_seconds=20,
explanation=f"\n{banner}\nmegatron_gpt_peft_tuning.py is renamed to megatron_gpt_finetuning.py with the "
f"same functionality. \nPlease switch to the new name.\n{banner}\n",
)
@hydra_runner(config_path="conf", config_name="megatron_gpt_finetuning_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')