Commit
Merge branch 'main' into akoumparouli/mistral_import_instruct_chat_template_fix
akoumpa committed Jul 4, 2024
2 parents a2afc5c + bf82737 commit 564104e
Showing 9 changed files with 334 additions and 55 deletions.
42 changes: 42 additions & 0 deletions docs/source/core/exp_manager.rst
@@ -248,6 +248,48 @@ You might also want to adjust the callback parameters:
Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes).

.. _exp_manager_straggler_det_support-label:

.. note::
    The straggler detection feature is included in the optional NeMo resiliency package.

Distributed training can be affected by stragglers, which are abnormally slow workers that delay the overall training process.
NeMo provides a straggler detection feature that can identify these slower GPUs.

This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default.

The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best).
A performance score can be interpreted as the ratio of current performance to reference performance.

There are two types of performance scores provided by the callback:
- Relative GPU performance score: The best-performing GPU in the workload is used as a reference.
- Individual GPU performance score: The best historical performance of the GPU is used as a reference.

Examples:
- If the relative performance score is 0.5, the GPU is running at half the speed of the fastest GPU in the workload.
- If the individual performance score is 0.5, the GPU is running at half of its best observed performance.

If a GPU performance score drops below the specified threshold, it is identified as a straggler.
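
The following sketch illustrates how such scores can be interpreted; it is an illustrative example with made-up throughput numbers, not the actual NeMo implementation:

.. code-block:: python

    # Hypothetical per-GPU throughput samples (e.g. tokens/sec) and historical bests.
    throughputs = {"gpu0": 980.0, "gpu1": 990.0, "gpu2": 495.0}
    best_historical = {"gpu0": 1000.0, "gpu1": 990.0, "gpu2": 985.0}

    threshold = 0.7
    best_now = max(throughputs.values())
    for gpu, current in throughputs.items():
        relative_score = current / best_now                # reference: fastest GPU in this workload
        individual_score = current / best_historical[gpu]  # reference: this GPU's best observed performance
        if min(relative_score, individual_score) < threshold:
            print(f"{gpu} flagged as straggler "
                  f"(relative={relative_score:.2f}, individual={individual_score:.2f})")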

To enable straggler detection, add ``create_straggler_detection_callback: True`` under ``exp_manager`` in the config YAML file.
You might also want to adjust the callback parameters:

.. code-block:: yaml

    exp_manager:
        ...
        create_straggler_detection_callback: True
        straggler_detection_callback_params:
            report_time_interval: 300           # Interval [seconds] of the straggler check
            calc_relative_gpu_perf: True        # Calculate relative GPU performance
            calc_individual_gpu_perf: True      # Calculate individual GPU performance
            num_gpu_perf_scores_to_log: 5       # Log the 5 best and 5 worst GPU performance scores, even if no stragglers are detected
            gpu_relative_perf_threshold: 0.7    # Threshold for relative GPU performance scores
            gpu_individual_perf_threshold: 0.7  # Threshold for individual GPU performance scores
            stop_if_detected: True              # Terminate the workload if stragglers are detected

Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes).

Fault Tolerance
---------------

@@ -31,6 +31,7 @@ hparams_file: null # model configuration file, only used for PTL checkpoint load
prompts: # prompts for GPT inference
- "Q: How are you?"
- "Q: How big is the universe?"
prompts_jsonl: null
server: False # whether launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether launch the web inference server
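# Illustrative only (hypothetical file, not part of this commit): `prompts_jsonl` can point to a
# JSONL file whose lines are parsed with json.loads, one prompt per line, e.g.:
#   "Q: What is the capital of France?"
#   "Q: How many moons does Jupiter have?"
# If any line is not valid JSON, all lines are used verbatim as plain-text prompts.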
77 changes: 54 additions & 23 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -14,6 +14,7 @@

import asyncio
import datetime
import json
import os
import threading
from functools import partial
@@ -166,20 +167,7 @@ def remove_padded_prompts(response, nb_paddings):
return result


@hydra_runner(config_path="conf", config_name="megatron_gpt_inference")
def main(cfg) -> None:

callbacks = []
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar:
callbacks.append(CustomProgressBar())
# trainer required for restoring model parallel models
trainer = Trainer(
strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
**cfg.trainer,
callbacks=callbacks,
)

def load_model_from_config(trainer, cfg):
if cfg.gpt_model_file is not None:
if (
cfg.tensor_model_parallel_size < 0
@@ -285,7 +273,50 @@ def main(cfg) -> None:
model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
else:
raise ValueError("need at least a nemo file or checkpoint dir")
return model


def load_prompts(cfg):
prompts = []
if (cfg_prompts := getattr(cfg, 'prompts', None)) is not None:
prompts = OmegaConf.to_container(cfg_prompts)
if (prompts_jsonl := getattr(cfg, 'prompts_jsonl', None)) is not None:
        with open(prompts_jsonl, 'rt') as fp:
            lines = [line.rstrip() for line in fp]
        try:
            # Each line is expected to be a JSON document (e.g. a string or a list of chat messages).
            prompts += list(map(json.loads, lines))
        except json.JSONDecodeError:
            # If any line is not valid JSON, fall back to using every line verbatim as a plain-text prompt.
            prompts += lines
# Make sure non-empty input
assert len(prompts) > 0, "Expected at least one prompt"
# Make sure all have the same type
assert all(
map(lambda x: isinstance(x, type(prompts[0])), prompts)
), "Expected all prompts to have the same datatype"
return prompts


def round_to_mult(n, mult=8):
"""
    Rounds n up to the nearest multiple of mult.
"""
return ((n + mult - 1) // mult) * mult
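
# Illustrative example: round_to_mult(13) == 16 and round_to_mult(16) == 16 with the default mult=8.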


@hydra_runner(config_path="conf", config_name="megatron_gpt_inference")
def main(cfg) -> None:

callbacks = []
# enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar:
callbacks.append(CustomProgressBar())
# trainer required for restoring model parallel models
trainer = Trainer(
strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
**cfg.trainer,
callbacks=callbacks,
)

model = load_model_from_config(trainer, cfg)
model.freeze()

# Have to turn off activations_checkpoint_method for inference
@@ -311,17 +342,17 @@ def main(cfg) -> None:
"end_strings": cfg.inference.end_strings,
}

prompts = load_prompts(cfg)

fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True)
if fp8_enabled:
nb_paddings = 0
while len(cfg.prompts) % 8 != 0:
cfg.prompts.append("")
nb_paddings += 1
if fp8_enabled and len(prompts) > 0:
padded_len = round_to_mult(len(prompts), 8)
nb_paddings = padded_len - len(prompts)
        if nb_paddings > 0:
            # Pad with empty prompts so the batch length is a multiple of 8, as required for FP8.
            prompts += [''] * nb_paddings

# First method of running text generation, call model.generate method
response = model.generate(
inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params
)
response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params)

if fp8_enabled:
response = remove_padded_prompts(response, nb_paddings)
@@ -331,7 +362,7 @@ def main(cfg) -> None:

# Second method of running text generation, call trainer.predict [recommended]
bs = 8 if fp8_enabled else 2
ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
ds = RequestDataSet(prompts)
request_dl = DataLoader(dataset=ds, batch_size=bs)
config = OmegaConf.to_container(cfg.inference)
model.set_inference_config(config)
179 changes: 179 additions & 0 deletions nemo/collections/common/tokenizers/chat_template_mixin.py
@@ -0,0 +1,179 @@
import re
from functools import cache

TEMPLATE_VAR_VALIDATION_PAT = re.compile(r'^\{_[A-Za-z][A-Za-z0-9_]*_\}$')
TEMPLATE_VAR_SEARCH_PAT = re.compile('({_[^}]+_})')


class ChatTemplateMixin:
def apply_chat_template(self, messages):
assert self.chat_template is not None
return tokenize_with_chat_template(self, messages, self.chat_template)

@property
def has_chat_template(self):
return self.chat_template is not None


@cache
def is_template_var(s):
# It should start with {_ and end with _}, be non-empty and not contain { or } within.
return re.match(TEMPLATE_VAR_VALIDATION_PAT, s)


def extract_template_parts(template, skip_empty=True):
for part in re.split(TEMPLATE_VAR_SEARCH_PAT, template):
# skip empty parts
if skip_empty and part == '':
continue
yield part


def strip_template_wrap(s):
if not is_template_var(s):
return s
# Strip the "{_" prefix and the "_}" suffix
return s[2:-2]


def render_chat_turn(message, template):
"""Renders a chat turn based on template
Args:
message (Dict)
e.g. {'role': ['user'], 'content': ['What is your favourite fruit?']},
template (Str):
"[INST] {_content_} [/INST]",
Returns:
(str, token_id/None): the template formatted message
e.g.
"[INST] What is your favourite fruit? [/INST]", None
"""
ans = []
for i, template_part in enumerate(extract_template_parts(template)):
if is_template_var(template_part):
template_part = strip_template_wrap(template_part)
if template_part == 'content':
ans.append(message['content'])
else:
# assert i == len(template_parts) - 1, "unsupported"
yield ''.join(ans), template_part
ans = []
else:
# Otherwise it is literal string
ans.append(template_part)
yield ''.join(ans), None


def encode_string_with_special_token(tokenizer, inputs, special_token):
"""
    Tokenizes a string or a list of strings into their corresponding token_ids
    and appends (at the end) the id of special_token, if one is provided.
Args:
tokenizer: (SPM)
inputs: (Str, List[Str])
e.g. "Alex" or ["Alex", "nvidia"]
special_token: (Str):
e.g. "eos"
Returns:
(list[int]): list of token_ids
e.g.
input="Alex", special_token="eos"
Alex->[3413]
eos->[2]
Will return the following:
[3413, 2]
"""
ans = []
if isinstance(inputs, str) and inputs != '':
ans += tokenizer.text_to_ids(inputs)
elif isinstance(inputs, list) and len(inputs) > 0:
ans += tokenizer.text_to_ids(''.join(inputs))
if special_token is not None:
# TODO(@akoumparouli): limit which attributes user-defined string can query.
assert hasattr(tokenizer, special_token), f"Special_token {special_token} is not part of tokenizer"
ans += [getattr(tokenizer, special_token)]
return ans


def tokenize_with_chat_template(tokenizer, messages, template):
assert is_chat_input(messages), "Expected input to be chat-template"
assert len(messages) > 0, "Expected non-empty messages"
assert 'roles' in template, "Expected template to have key `roles`."
ans = []
encode = lambda x, y: encode_string_with_special_token(tokenizer, x, y)
if 'prefix' in template:
for part, special_token in render_chat_turn('', template['prefix']):
ans += encode(part, special_token)
buffer = []
for message in messages:
assert message['role'] in template['roles'], (message['role'], template['roles'])
msg_template = template['roles'][message['role']]
for templated_messages, special_token in render_chat_turn(message, msg_template):
buffer += [templated_messages]
if special_token is not None:
ans += encode(buffer, special_token)
buffer = []
# handle tail
ans += encode(buffer, None)
assert len(ans) > 0, 'Expected non-empty output'
return ans
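
# Illustrative example (hypothetical template, not one shipped by NeMo): a chat template is a dict with
# an optional 'prefix' and a 'roles' mapping of role name -> format string. '{_content_}' is replaced by
# the message text; any other '{_..._}' variable (e.g. '{_eos_id_}') names a special-token attribute that
# must exist on the tokenizer.
#
#   template = {
#       'prefix': "{_bos_id_}",
#       'roles': {
#           'user': "[INST] {_content_} [/INST]",
#           'assistant': "{_content_}{_eos_id_}",
#       },
#   }
#   messages = [
#       {'role': 'user', 'content': 'What is your favourite fruit?'},
#       {'role': 'assistant', 'content': 'Lemon.'},
#   ]
#   token_ids = tokenize_with_chat_template(tokenizer, messages, template)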


def extract_turns(messages, axis):
"""
a collated messages can have multiple chat messages in each dict,
this extracts (vertically) one of them, for example:
messages = [
{'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']},
{'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]},
{'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']}
]
ans = extract_turns(messages, axis=1)
ans = [
{'role': ['user'], 'content': ['What is your favourite fruit?']},
{'role': ['assistant'], 'content': ["good squeeze of fresh lemon"]},
{'role': ['user'], 'content': ['Do you have tomato salad recipes?']}
]
"""
ans = []
for turn in messages:
ans.append({k: v[axis] for k, v in turn.items()})
return ans


def explode_chat_template_input(messages):
"""
Example input
[
{'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']},
{'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]},
{'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']}
]
Notice the 2D axis system of the messages variable, one for the list and one for each item in the list (i.e.
the 'content' contains multiple messages).
"""
assert isinstance(messages, list), "Expected messages to be a list"
assert len(messages) > 0, "Expected non empty messages"
assert all(map(lambda x: isinstance(x, dict), messages)), "Expected messages to contain dicts"
assert all(
map(lambda x: 'role' in x and 'content' in x, messages)
), "Expected messages each dict to contain 'role' and 'content' fields"
n = len(messages[0]['role'])
assert all(
map(lambda x: len(x['role']) == n, messages)
), "Expected all batch messages to contain equal number of roles in all turns"
for i in range(n):
yield extract_turns(messages, axis=i)


def is_chat_input(messages):
    # TODO(@akoumparouli): improve validation.
return isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict)
18 changes: 16 additions & 2 deletions nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -20,13 +20,14 @@
import torch

from nemo.collections.common.parts.utils import if_exist
from nemo.collections.common.tokenizers.chat_template_mixin import ChatTemplateMixin
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging

__all__ = ['SentencePieceTokenizer', 'create_spt_model']


class SentencePieceTokenizer(TokenizerSpec):
class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin):
"""
Sentencepiecetokenizer https://github.com/google/sentencepiece.
@@ -38,8 +39,13 @@ class SentencePieceTokenizer(TokenizerSpec):
"""

def __init__(
self, model_path: str, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, legacy: bool = False
self,
model_path: str,
special_tokens: Optional[Union[Dict[str, str], List[str]]] = None,
legacy: bool = False,
chat_template: Optional[Dict] = None,
):
self.chat_template = chat_template
if not model_path or not os.path.exists(model_path):
raise ValueError(f"model_path: {model_path} is invalid")
self.tokenizer = sentencepiece.SentencePieceProcessor()
@@ -89,6 +95,14 @@ def text_to_tokens(self, text):
return self.tokenizer.encode_as_pieces(text)

def text_to_ids(self, text, sample_alpha=None):
if isinstance(text, str):
return self._text_to_ids(text, sample_alpha)
elif isinstance(text, list):
return self.apply_chat_template(text)
else:
raise ValueError(f"Expected either str or list input, but got {type(text)}")

def _text_to_ids(self, text, sample_alpha=None):
if self.legacy:
ids = []
idx = 0