generalized chat sft prompt (#7655)
* fix dataset issues
* working version
* all passed
* refactor tests
* all pass
* working version
* use end name signal for labels
* all fixed
* update doc
* style fix
* remove unused imports
* make sure nccl not timing out
* style fix
* generate example template
* generic end of name token
* style fix
* add the chat prompt format into the config
* make sure sft working
* address reviewer comment
* fix non
* try openAI prompt
* remove unused imports
* remove human labels from the data
* use hf dataset to clean
* reviewer comments

---------

Signed-off-by: Yi Dong <[email protected]>
yidong72 authored and yaoyu-33 committed Oct 13, 2023
1 parent b7bcf08 commit b3da442
Showing 9 changed files with 488 additions and 238 deletions.
7 changes: 6 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import datetime
import os
import threading
from functools import partial
@@ -167,7 +168,11 @@ def remove_padded_prompts(response, nb_paddings):
def main(cfg) -> None:

    # trainer required for restoring model parallel models
    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer, callbacks=[CustomProgressBar()])
    trainer = Trainer(
        strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
        **cfg.trainer,
        callbacks=[CustomProgressBar()],
    )

    if cfg.gpt_model_file is not None:
        if (
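This change gives NLPDDPStrategy an explicit five-hour NCCL timeout (18000 s), per the "make sure nccl not timing out" commit above, so ranks that sit idle during long generation runs do not hit the default 30-minute collective timeout. A minimal sketch of the same idea with plain torch.distributed, assuming an NCCL backend and the same illustrative timeout value:

    import datetime

    import torch.distributed as dist

    # Raise the collective timeout; ranks that wait on others during long
    # generation or preprocessing would otherwise abort after the default
    # 30 minutes.
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=18000))
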
@@ -71,6 +71,12 @@ model:

  data:
    chat: False # whether to use chatbot data or not
    chat_prompt_tokens: # special tokens for the chat prompts, a dictionary of {token_type: token}. Note that some tokenizers may merge the characters at the junction between {end_of_turn}{turn_start}, e.g. in '<im end><im start>' the '><' is sometimes merged into a single token. This is not supported; avoid such combinations
      system_turn_start: '<extra_id_0>'
      turn_start: '<extra_id_1>'
      label_start: '<extra_id_2>'
      end_of_turn: "\x0A" # \x0A is '\n'
      end_of_name: "\x0A" # \x0A is '\n'
    train_ds:
      # Example of how to specify paths to multiple datasets
      # file_names:
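For illustration, a sketch of how the tokens above might compose into a single training example; the real template string is produced by get_prompt_template_example, so the role names and bracketed placeholders below are assumptions rather than that function's exact output:

    # Hypothetical one-turn conversation assembled from the special tokens.
    tokens = {
        'system_turn_start': '<extra_id_0>',
        'turn_start': '<extra_id_1>',
        'label_start': '<extra_id_2>',
        'end_of_turn': '\n',
        'end_of_name': '\n',
    }
    prompt = (
        tokens['system_turn_start'] + 'System' + tokens['end_of_name']
        + '{system message}' + tokens['end_of_turn']
        + tokens['turn_start'] + 'User' + tokens['end_of_name']
        + '{user turn}' + tokens['end_of_turn']
        + tokens['turn_start'] + 'Assistant' + tokens['end_of_name']
        + tokens['label_start'] + '{label}' + tokens['end_of_name']
        + '{assistant reply}' + tokens['end_of_turn']
    )

Because end_of_turn and turn_start sit back to back between turns, they must remain separate tokens after tokenization, which is exactly the merging pitfall the config comment warns about.
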
12 changes: 11 additions & 1 deletion examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -15,11 +15,12 @@
import os
import tempfile

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector

from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset import get_prompt_template_example
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import (
@@ -36,6 +37,8 @@
from nemo.utils.exp_manager import exp_manager
from nemo.utils.model_utils import inject_model_parallel_rank

mp.set_start_method("spawn", force=True)


def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
"""
@@ -71,6 +74,13 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
    gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
    gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)

    if cfg.model.data.get('chat', False):
        # chat model, overwrite the prompt template
        prompt_template = get_prompt_template_example(cfg.model.data.chat_prompt_tokens)
        gpt_cfg.data.train_ds.prompt_template = prompt_template
        gpt_cfg.data.validation_ds.prompt_template = prompt_template
        gpt_cfg.data.test_ds.prompt_template = prompt_template

    sft_cls = MegatronGPTSFTModel
    gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}"
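When model.data.chat is true, the script derives a single prompt template from chat_prompt_tokens and writes it into all three dataset splits, so the config tokens and the dataset template cannot drift apart. A minimal sketch of that overwrite pattern on a bare OmegaConf tree, with a placeholder value standing in for get_prompt_template_example's output:

    from omegaconf import OmegaConf

    gpt_cfg = OmegaConf.create({'data': {'train_ds': {}, 'validation_ds': {}, 'test_ds': {}}})

    # In the real script this value comes from
    # get_prompt_template_example(cfg.model.data.chat_prompt_tokens).
    prompt_template = '<extra_id_0>System\n...'  # placeholder
    for split in ('train_ds', 'validation_ds', 'test_ds'):
        gpt_cfg.data[split].prompt_template = prompt_template
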
