generalized chat sft prompt #7655

Merged
merged 30 commits into main from sft_mcore
Oct 10, 2023
Changes from 23 commits

Commits
c66de49
fix dataset issues
yidong72 Oct 5, 2023
3e42b74
Merge branch 'main' into sft_mcore
yidong72 Oct 5, 2023
1f3d2d3
working version
yidong72 Oct 6, 2023
7fdc339
all passed
yidong72 Oct 6, 2023
87a01bb
refactor tests
yidong72 Oct 6, 2023
9c0ee5c
all pass
yidong72 Oct 6, 2023
ccaa6a0
working version
yidong72 Oct 6, 2023
14467c4
use end name signal for labels
yidong72 Oct 6, 2023
4a674e4
all fixed
yidong72 Oct 6, 2023
11bc6cd
update doc
yidong72 Oct 6, 2023
2e0285a
style fix
yidong72 Oct 6, 2023
4ba2395
remove unused imports
yidong72 Oct 6, 2023
d1b8328
make sure nccl not timing out
yidong72 Oct 6, 2023
5bf546e
style fix
yidong72 Oct 6, 2023
f945ec6
Merge branch 'main' into sft_mcore
yidong72 Oct 6, 2023
cd7c77a
generate example template
yidong72 Oct 6, 2023
d734830
generic end of name token
yidong72 Oct 6, 2023
33b7910
style fix
yidong72 Oct 6, 2023
e293336
Merge branch 'sft_mcore' of github.com:NVIDIA/NeMo into sft_mcore
yidong72 Oct 6, 2023
c99b55f
add the chat prompt format into the config
yidong72 Oct 6, 2023
b64f0bd
make sure sft working
yidong72 Oct 6, 2023
86bb7b0
address reviewer comment
yidong72 Oct 6, 2023
019afa4
Merge branch 'main' into sft_mcore
yidong72 Oct 6, 2023
3ddd9cd
fix non
yidong72 Oct 7, 2023
a1789e4
try openAI prompt
yidong72 Oct 7, 2023
4db2188
Merge branch 'main' into sft_mcore
yidong72 Oct 7, 2023
d36d3a9
remove unused imports
yidong72 Oct 7, 2023
162be79
remove human labels from the data
yidong72 Oct 9, 2023
700d9f2
use hf dataset to clean
yidong72 Oct 9, 2023
ed68643
reviewer comments
yidong72 Oct 10, 2023
7 changes: 6 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import datetime
import os
import threading
from functools import partial
@@ -167,7 +168,11 @@ def remove_padded_prompts(response, nb_paddings):
def main(cfg) -> None:

# trainer required for restoring model parallel models
trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer, callbacks=[CustomProgressBar()])
trainer = Trainer(
strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
**cfg.trainer,
callbacks=[CustomProgressBar()],
)

if cfg.gpt_model_file is not None:
if (
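For context on the change above: the PR forwards an 18000-second timeout to NLPDDPStrategy so that long checkpoint restores and generation runs do not hit the default distributed watchdog (see the commit "make sure nccl not timing out"). Below is a minimal sketch of the same idea in plain PyTorch; it is not part of this PR.

```python
import datetime

import torch.distributed as dist

# Not part of the PR: raising the collective timeout keeps long
# checkpoint-restore / generation phases from tripping the default
# distributed watchdog. 18000 s matches the value passed to
# NLPDDPStrategy in the diff above.
dist.init_process_group(
    backend="nccl",
    timeout=datetime.timedelta(seconds=18000),
)
```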
@@ -71,6 +71,12 @@ model:

data:
chat: False # whether use chatbot data or not
chat_prompt_tokens: # special tokens for the chat prompts, a dictionary of {token_type: token}. Note that some tokenizers may merge the characters at the junction between {end_of_turn}{turn_start}, e.g. in '<im end><im start>' the '><' can become a single token. This is not supported; avoid such combinations.
system_turn_start: '<extra_id_0>'
turn_start: '<extra_id_1>'
label_start: '<extra_id_2>'
end_of_turn: "\x0A" # \x0A is '\n'
end_of_name: "\x0A" # \x0A is '\n'
train_ds:
# Example of how to specify paths to multiple datasets
# file_names:
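To make the roles of these tokens concrete, here is a rough sketch of how they could compose into a rendered chat prompt. The `render_example` helper is hypothetical; the authoritative layout is whatever `get_prompt_template_example` (imported below in megatron_gpt_sft.py) produces.

```python
# Illustrative only: a sketch of how chat_prompt_tokens might be assembled
# into a rendered prompt. The real template comes from
# get_prompt_template_example() in gpt_sft_chat_dataset.py and may differ.
tokens = {
    "system_turn_start": "<extra_id_0>",
    "turn_start": "<extra_id_1>",
    "label_start": "<extra_id_2>",
    "end_of_turn": "\n",
    "end_of_name": "\n",
}

def render_example(t: dict) -> str:
    # Curly-brace placeholders stand in for actual conversation content.
    return (
        f"{t['system_turn_start']}System{t['end_of_name']}"
        f"{{system message}}{t['end_of_turn']}"
        f"{t['turn_start']}User{t['end_of_name']}"
        f"{{turn 1}}{t['end_of_turn']}"
        f"{t['turn_start']}Assistant{t['end_of_name']}"
        f"{t['label_start']}{{label 1}}{t['end_of_turn']}"
        f"{{turn 2}}{t['end_of_turn']}"
    )

print(render_example(tokens))
```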
12 changes: 11 additions & 1 deletion examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -15,11 +15,12 @@
import os
import tempfile

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector

from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import get_prompt_template_example
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
@@ -36,6 +37,8 @@
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank

mp.set_start_method("spawn", force=True)


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
@@ -71,6 +74,13 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)

if cfg.model.data.get('chat', False):
# chat model, overwrite the prompt template
prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
gpt_cfg.data.train_ds.prompt_template = prompt_template
gpt_cfg.data.validation_ds.prompt_template = prompt_template
gpt_cfg.data.test_ds.prompt_template = prompt_template

sft_cls = MegatronGPTSFTModel
gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"

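As a sanity check of the override flow above, the toy OmegaConf example below shows a generated prompt template being propagated to the train/validation/test dataset configs when chat mode is enabled. The placeholder string stands in for the return value of `get_prompt_template_example`; this is a sketch, not the script's actual code path.

```python
# Sketch of the chat-mode override in _modify_config, on a toy config.
from omegaconf import OmegaConf

gpt_cfg = OmegaConf.create(
    {
        "data": {
            "train_ds": {"prompt_template": None},
            "validation_ds": {"prompt_template": None},
            "test_ds": {"prompt_template": None},
        }
    }
)
cfg = OmegaConf.create({"model": {"data": {"chat": True}}})

if cfg.model.data.get("chat", False):
    # Placeholder for get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
    prompt_template = "<extra_id_0>System\n..."
    gpt_cfg.data.train_ds.prompt_template = prompt_template
    gpt_cfg.data.validation_ds.prompt_template = prompt_template
    gpt_cfg.data.test_ds.prompt_template = prompt_template

print(OmegaConf.to_yaml(gpt_cfg))
```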