diff --git a/docs/source/core/exp_manager.rst b/docs/source/core/exp_manager.rst
index e813b8f16ac4..ce5f7a9cb087 100644
--- a/docs/source/core/exp_manager.rst
+++ b/docs/source/core/exp_manager.rst
@@ -248,6 +248,48 @@ You might also want to adjust the callback parameters:
 
 Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes).
 
+.. _exp_manager_straggler_det_support-label:
+
+.. note::
+    The straggler detection feature is included in the optional NeMo resiliency package.
+
+Distributed training can be affected by stragglers: slower workers that delay the overall training process.
+NeMo provides a straggler detection feature that can identify slower GPUs.
+
+This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default.
+
+The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best).
+A performance score can be interpreted as the ratio of current performance to reference performance.
+
+The callback provides two types of performance scores:
+    - Relative GPU performance score: the best-performing GPU in the workload is used as a reference.
+    - Individual GPU performance score: the best historical performance of the GPU is used as a reference.
+
+Examples:
+    - If the relative performance score is 0.5, the GPU is half as fast as the fastest GPU in the workload.
+    - If the individual performance score is 0.5, the GPU is half as fast as its best observed performance.
+
+If a GPU performance score drops below the specified threshold, the GPU is flagged as a straggler.
+
+To enable straggler detection, add ``create_straggler_detection_callback: True`` under ``exp_manager`` in the config YAML file.
+You might also want to adjust the callback parameters:
+
+.. code-block:: yaml
+
+    exp_manager:
+        ...
+        create_straggler_detection_callback: True
+        straggler_detection_callback_params:
+            report_time_interval: 300           # Interval [seconds] of the straggler check
+            calc_relative_gpu_perf: True        # Calculate relative GPU performance
+            calc_individual_gpu_perf: True      # Calculate individual GPU performance
+            num_gpu_perf_scores_to_log: 5       # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected
+            gpu_relative_perf_threshold: 0.7    # Threshold for relative GPU performance scores
+            gpu_individual_perf_threshold: 0.7  # Threshold for individual GPU performance scores
+            stop_if_detected: True              # Terminate the workload if stragglers are detected
+
+Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes).
+
 Fault Tolerance
 ---------------
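For intuition, here is a minimal sketch (not part of the patch) of how normalized scores of the kind described above could be derived from per-rank throughput measurements; the function names are hypothetical, and the 0.7 threshold mirrors the config defaults:

```python
def relative_perf_scores(throughputs):
    """Score each GPU against the best-performing GPU in the workload."""
    best = max(throughputs)
    return [t / best for t in throughputs]


def individual_perf_score(current, best_historical):
    """Score a GPU against its own best observed performance."""
    return current / best_historical


scores = relative_perf_scores([410.0, 400.0, 205.0, 395.0])
stragglers = [rank for rank, s in enumerate(scores) if s < 0.7]
print(stragglers)  # [2]: the GPU running at roughly half the speed of the fastest
```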
diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml
index 2570251bcdee..ce8311daf95c 100644
--- a/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_inference.yaml
@@ -31,6 +31,7 @@ hparams_file: null # model configuration file, only used for PTL checkpoint load
 prompts: # prompts for GPT inference
   - "Q: How are you?"
   - "Q: How big is the universe?"
+prompts_jsonl: null # optional path to a JSONL file with one prompt per line
 server: False # whether launch the API server
 port: 5555 # the port number for the inference server
 web_server: False # whether launch the web inference server
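A quick illustration (not part of the patch) of how `prompts_jsonl` is consumed by `load_prompts()` in the diff below; the file name is arbitrary:

```python
import json

# Write two prompts, one JSON document per line.
with open("prompts.jsonl", "wt") as fp:
    fp.write(json.dumps("Q: What is the capital of France?") + "\n")
    fp.write(json.dumps("Q: How many moons does Mars have?") + "\n")

# With prompts_jsonl=prompts.jsonl, load_prompts() appends:
# ["Q: What is the capital of France?", "Q: How many moons does Mars have?"]
# All prompts must share one datatype (all strings, or all dicts for
# chat-template input); load_prompts() asserts this.
```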
diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py
index f3413a5fa92e..362a2ae3e298 100644
--- a/examples/nlp/language_modeling/megatron_gpt_eval.py
+++ b/examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -14,6 +14,7 @@
 
 import asyncio
 import datetime
+import json
 import os
 import threading
 from functools import partial
@@ -166,20 +167,7 @@ def remove_padded_prompts(response, nb_paddings):
     return result
 
 
-@hydra_runner(config_path="conf", config_name="megatron_gpt_inference")
-def main(cfg) -> None:
-
-    callbacks = []
-    # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
-    if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar:
-        callbacks.append(CustomProgressBar())
-    # trainer required for restoring model parallel models
-    trainer = Trainer(
-        strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
-        **cfg.trainer,
-        callbacks=callbacks,
-    )
-
+def load_model_from_config(trainer, cfg):
     if cfg.gpt_model_file is not None:
         if (
             cfg.tensor_model_parallel_size < 0
@@ -285,7 +273,50 @@ def main(cfg) -> None:
         model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer)
     else:
         raise ValueError("need at least a nemo file or checkpoint dir")
+    return model
+
+
+def load_prompts(cfg):
+    prompts = []
+    if (cfg_prompts := getattr(cfg, 'prompts', None)) is not None:
+        prompts = OmegaConf.to_container(cfg_prompts)
+    if (prompts_jsonl := getattr(cfg, 'prompts_jsonl', None)) is not None:
+        with open(prompts_jsonl, 'rt') as fp:
+            lines = [line.rstrip() for line in fp]
+        try:
+            prompts += list(map(json.loads, lines))
+        except json.JSONDecodeError:
+            # Not valid JSONL; treat each line as a plain-text prompt.
+            prompts += lines
+    # Make sure non-empty input
+    assert len(prompts) > 0, "Expected at least one prompt"
+    # Make sure all have the same type
+    assert all(
+        map(lambda x: isinstance(x, type(prompts[0])), prompts)
+    ), "Expected all prompts to have the same datatype"
+    return prompts
+
+
+def round_to_mult(n, mult=8):
+    """
+    Rounds number n up to the nearest multiple of mult
+    """
+    return ((n + mult - 1) // mult) * mult
+
+
+@hydra_runner(config_path="conf", config_name="megatron_gpt_inference")
+def main(cfg) -> None:
+
+    callbacks = []
+    # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
+    if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar:
+        callbacks.append(CustomProgressBar())
+    # trainer required for restoring model parallel models
+    trainer = Trainer(
+        strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
+        **cfg.trainer,
+        callbacks=callbacks,
+    )
+
+    model = load_model_from_config(trainer, cfg)
 
     model.freeze()
 
     # Have to turn off activations_checkpoint_method for inference
@@ -311,17 +342,17 @@ def main(cfg) -> None:
         "end_strings": cfg.inference.end_strings,
     }
 
+    prompts = load_prompts(cfg)
+
     fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True)
-    if fp8_enabled:
-        nb_paddings = 0
-        while len(cfg.prompts) % 8 != 0:
-            cfg.prompts.append("")
-            nb_paddings += 1
+    if fp8_enabled and len(prompts) > 0:
+        padded_len = round_to_mult(len(prompts), 8)
+        nb_paddings = padded_len - len(prompts)
+        if nb_paddings > 0:
+            prompts += [''] * nb_paddings
 
     # First method of running text generation, call model.generate method
-    response = model.generate(
-        inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params
-    )
+    response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params)
 
     if fp8_enabled:
         response = remove_padded_prompts(response, nb_paddings)
@@ -331,7 +362,7 @@ def main(cfg) -> None:
 
     # Second method of running text generation, call trainer.predict [recommended]
     bs = 8 if fp8_enabled else 2
-    ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
+    ds = RequestDataSet(prompts)
     request_dl = DataLoader(dataset=ds, batch_size=bs)
     config = OmegaConf.to_container(cfg.inference)
     model.set_inference_config(config)
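A small worked example (not part of the patch) of the fp8 padding logic introduced above: with fp8 enabled, the prompt list is padded up to the next multiple of 8 so batch dimensions stay fp8-friendly, and the padding is stripped from the response afterwards.

```python
def round_to_mult(n, mult=8):
    # Round n up to the nearest multiple of mult.
    return ((n + mult - 1) // mult) * mult


prompts = ["Q: How are you?", "Q: How big is the universe?", "Q: Why?"]
padded_len = round_to_mult(len(prompts))  # 3 -> 8
nb_paddings = padded_len - len(prompts)   # 5
prompts += [''] * nb_paddings             # pad with empty prompts
assert len(prompts) == 8                  # fp8-friendly batch size
```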
[/INST]", None + """ + ans = [] + for i, template_part in enumerate(extract_template_parts(template)): + if is_template_var(template_part): + template_part = strip_template_wrap(template_part) + if template_part == 'content': + ans.append(message['content']) + else: + # assert i == len(template_parts) - 1, "unsupported" + yield ''.join(ans), template_part + ans = [] + else: + # Otherwise it is literal string + ans.append(template_part) + yield ''.join(ans), None + + +def encode_string_with_special_token(tokenizer, inputs, special_token): + """ + Tokenizes a string or a list of string into their corresponding token_ids + and appends (at the end) a special_token if present. + + Args: + tokenizer: (SPM) + inputs: (Str, List[Str]) + e.g. "Alex" or ["Alex", "nvidia"] + special_token: (Str): + e.g. "eos" + + Returns: + (list[int]): list of token_ids + e.g. + input="Alex", special_token="eos" + Alex->[3413] + eos->[2] + + Will return the following: + [3413, 2] + """ + ans = [] + if isinstance(inputs, str) and inputs != '': + ans += tokenizer.text_to_ids(inputs) + elif isinstance(inputs, list) and len(inputs) > 0: + ans += tokenizer.text_to_ids(''.join(inputs)) + if special_token is not None: + # TODO(@akoumparouli): limit which attributes user-defined string can query. + assert hasattr(tokenizer, special_token), f"Special_token {special_token} is not part of tokenizer" + ans += [getattr(tokenizer, special_token)] + return ans + + +def tokenize_with_chat_template(tokenizer, messages, template): + assert is_chat_input(messages), "Expected input to be chat-template" + assert len(messages) > 0, "Expected non-empty messages" + assert 'roles' in template, "Expected template to have key `roles`." + ans = [] + encode = lambda x, y: encode_string_with_special_token(tokenizer, x, y) + if 'prefix' in template: + for part, special_token in render_chat_turn('', template['prefix']): + ans += encode(part, special_token) + buffer = [] + for message in messages: + assert message['role'] in template['roles'], (message['role'], template['roles']) + msg_template = template['roles'][message['role']] + for templated_messages, special_token in render_chat_turn(message, msg_template): + buffer += [templated_messages] + if special_token is not None: + ans += encode(buffer, special_token) + buffer = [] + # handle tail + ans += encode(buffer, None) + assert len(ans) > 0, 'Expected non-empty output' + return ans + + +def extract_turns(messages, axis): + """ + a collated messages can have multiple chat messages in each dict, + this extracts (vertically) one of them, for example: + + messages = [ + {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']}, + {'role': ['assistant', 'assistant'], 'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]}, + {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']} + ] + ans = extract_turns(messages, axis=1) + + ans = [ + {'role': ['user'], 'content': ['What is your favourite fruit?']}, + {'role': ['assistant'], 'content': ["good squeeze of fresh lemon"]}, + {'role': ['user'], 'content': ['Do you have tomato salad recipes?']} + ] + """ + ans = [] + for turn in messages: + ans.append({k: v[axis] for k, v in turn.items()}) + return ans + + +def explode_chat_template_input(messages): + """ + Example input + [ + {'role': ['user', 'user'], 'content': ['What is your favourite condiment?', 'What is your favourite fruit?']}, + {'role': ['assistant', 'assistant'], 
'content': ["Well, I'm quite partial to a ", "good squeeze of fresh lemon"]}, + {'role': ['user', 'user'], 'content': ['Do you have mayonnaise recipes?', 'Do you have tomato salad recipes?']} + ] + + Notice the 2D axis system of the messages variable, one for the list and one for each item in the list (i.e. + the 'content' contains multiple messages). + """ + assert isinstance(messages, list), "Expected messages to be a list" + assert len(messages) > 0, "Expected non empty messages" + assert all(map(lambda x: isinstance(x, dict), messages)), "Expected messages to contain dicts" + assert all( + map(lambda x: 'role' in x and 'content' in x, messages) + ), "Expected messages each dict to contain 'role' and 'content' fields" + n = len(messages[0]['role']) + assert all( + map(lambda x: len(x['role']) == n, messages) + ), "Expected all batch messages to contain equal number of roles in all turns" + for i in range(n): + yield extract_turns(messages, axis=i) + + +def is_chat_input(messages): + # TOOD(@akoumparouli): improve validation. + return isinstance(messages, list) and len(messages) > 0 and isinstance(messages[0], dict) diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index 4a47f0e49b1e..00893b6f379f 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -20,13 +20,14 @@ import torch from nemo.collections.common.parts.utils import if_exist +from nemo.collections.common.tokenizers.chat_template_mixin import ChatTemplateMixin from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging __all__ = ['SentencePieceTokenizer', 'create_spt_model'] -class SentencePieceTokenizer(TokenizerSpec): +class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin): """ Sentencepiecetokenizer https://github.com/google/sentencepiece. 
diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
index 4a47f0e49b1e..00893b6f379f 100644
--- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
+++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py
@@ -20,13 +20,14 @@
 import torch
 
 from nemo.collections.common.parts.utils import if_exist
+from nemo.collections.common.tokenizers.chat_template_mixin import ChatTemplateMixin
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.utils import logging
 
 __all__ = ['SentencePieceTokenizer', 'create_spt_model']
 
 
-class SentencePieceTokenizer(TokenizerSpec):
+class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin):
     """
     Sentencepiecetokenizer https://github.com/google/sentencepiece.
 
@@ -38,8 +39,13 @@ class SentencePieceTokenizer(TokenizerSpec):
     """
 
     def __init__(
-        self, model_path: str, special_tokens: Optional[Union[Dict[str, str], List[str]]] = None, legacy: bool = False
+        self,
+        model_path: str,
+        special_tokens: Optional[Union[Dict[str, str], List[str]]] = None,
+        legacy: bool = False,
+        chat_template: Optional[Dict] = None,
     ):
+        self.chat_template = chat_template
         if not model_path or not os.path.exists(model_path):
             raise ValueError(f"model_path: {model_path} is invalid")
         self.tokenizer = sentencepiece.SentencePieceProcessor()
@@ -89,6 +95,14 @@ def text_to_tokens(self, text):
             return self.tokenizer.encode_as_pieces(text)
 
     def text_to_ids(self, text, sample_alpha=None):
+        if isinstance(text, str):
+            return self._text_to_ids(text, sample_alpha)
+        elif isinstance(text, list):
+            return self.apply_chat_template(text)
+        else:
+            raise ValueError(f"Expected either str or list input, but got {type(text)}")
+
+    def _text_to_ids(self, text, sample_alpha=None):
         if self.legacy:
             ids = []
             idx = 0
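Usage sketch (not part of the patch) of the new dispatch in `text_to_ids()`: plain strings keep the old behaviour, while a list of chat messages is routed through the chat template. The model path is a placeholder.

```python
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

tokenizer = SentencePieceTokenizer(
    model_path="/path/to/tokenizer.model",  # placeholder path
    chat_template=chat_template,            # e.g. the dict sketched earlier
)

ids = tokenizer.text_to_ids("plain text")  # str input: regular tokenization
chat_ids = tokenizer.text_to_ids(          # list input: chat-template path
    [{'role': 'user', 'content': 'Hi'}]
)
```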
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
index ae659e757496..f7b53a95c19a 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py
@@ -431,6 +431,7 @@ def _build_tokenizer(self):
             special_tokens=self.cfg.tokenizer.get('special_tokens', None),
             trust_remote_code=self.cfg.tokenizer.get('trust_remote_code', False),
             legacy=legacy,
+            chat_template=getattr(self._cfg.tokenizer, "chat_template", None),
         )
 
         if self._cfg.tokenizer.get('additional_special_tokens', None) is not None:
diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py
index e8e2859e439f..238c01695f42 100644
--- a/nemo/collections/nlp/modules/common/text_generation_strategy.py
+++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -21,6 +21,8 @@
 import torch
 from transformers import CLIPImageProcessor
 
+from nemo.collections.common.tokenizers.chat_template_mixin import explode_chat_template_input, is_chat_input
 from nemo.collections.nlp.modules.common.lm_utils import pad_batch
 from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
@@ -94,7 +96,12 @@ def tokenize_batch(self, sentences, max_len, add_BOS):
             Tuple[torch.Tensor], the tokenized and padded torch tensor and the token context length tensor.
         """
         tokenizer = self.model.tokenizer
-        if add_BOS:
+        if is_chat_input(sentences):
+            assert getattr(
+                tokenizer, 'has_chat_template', False
+            ), "Got chat-template input but tokenizer does not support chat template formatting."
+            context_tokens = list(map(tokenizer.text_to_ids, explode_chat_template_input(sentences)))
+        elif add_BOS:
             context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences]
         elif hasattr(tokenizer.tokenizer, "get_prefix_tokens"):
             # chatglm: add tokenizer.gmask_id, tokenizer.sop_id
diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py
index 498d9e9a09da..cd02f5409679 100644
--- a/nemo/collections/nlp/modules/common/text_generation_utils.py
+++ b/nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -122,31 +122,26 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
         compute_prob_response = get_computeprob_response(tokenizer, response, inputs)
         return compute_prob_response
 
-    if isinstance(inputs, (list, tuple)):
-        if isinstance(inputs[0], (str, torch.Tensor)):
-            output = generate(
-                model,
-                inputs=inputs,
-                tokens_to_generate=length_params['max_length'],
-                all_probs=sampling_params['all_probs'],
-                compute_logprob=sampling_params['compute_logprob'],
-                temperature=sampling_params['temperature'],
-                add_BOS=sampling_params['add_BOS'],
-                top_k=sampling_params['top_k'],
-                top_p=sampling_params['top_p'],
-                greedy=sampling_params['use_greedy'],
-                repetition_penalty=sampling_params['repetition_penalty'],
-                end_strings=sampling_params['end_strings'],
-                min_tokens_to_generate=length_params['min_length'],
-                **strategy_args,
-            )
-            return output
-        elif isinstance(inputs[0], dict):
-            raise NotImplementedError("json object not implemented")
-        else:
-            raise NotImplementedError("unknown type is not implemented")
-    else:
-        raise NotImplementedError("unknown type is not implemented")
+    if not isinstance(inputs, (list, tuple)):
+        raise NotImplementedError(f"unknown type {type(inputs)} is not implemented")
+
+    output = generate(
+        model,
+        inputs=inputs,
+        tokens_to_generate=length_params['max_length'],
+        all_probs=sampling_params['all_probs'],
+        compute_logprob=sampling_params['compute_logprob'],
+        temperature=sampling_params['temperature'],
+        add_BOS=sampling_params['add_BOS'],
+        top_k=sampling_params['top_k'],
+        top_p=sampling_params['top_p'],
+        greedy=sampling_params['use_greedy'],
+        repetition_penalty=sampling_params['repetition_penalty'],
+        end_strings=sampling_params['end_strings'],
+        min_tokens_to_generate=length_params['min_length'],
+        **strategy_args,
+    )
+    return output
 
 
 def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_params, inference_config, **strategy_args):
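For clarity, a sketch (not part of the patch) of the collated chat-input format that `is_chat_input()` matches in `tokenize_batch()`:

```python
# Each turn carries parallel lists with one entry per sample (batch size 2);
# explode_chat_template_input() splits this into two per-sample conversations.
batch = [
    {'role': ['user', 'user'],
     'content': ['What is your favourite condiment?', 'What is your favourite fruit?']},
]
# is_chat_input(batch) -> True, so tokenize_batch() takes the chat-template
# branch instead of the plain add_BOS path.
```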
""" if special_tokens is None: @@ -116,7 +117,10 @@ def get_tokenizer( if tokenizer_name == 'sentencepiece': logging.info("tokenizer_model: " + str(tokenizer_model)) return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, special_tokens=special_tokens, legacy=True + model_path=tokenizer_model, + special_tokens=special_tokens, + legacy=True, + chat_template=chat_template, ) elif tokenizer_name == 'word': return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict) @@ -151,6 +155,7 @@ def get_nmt_tokenizer( legacy: Optional[bool] = False, delimiter: Optional[str] = None, trust_remote_code: Optional[bool] = False, + chat_template: Optional[Dict] = None, ): """ Args: @@ -187,7 +192,9 @@ def get_nmt_tokenizer( elif library == 'sentencepiece': logging.info(f'Getting SentencePiece with model: {tokenizer_model}') return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer( - model_path=tokenizer_model, legacy=legacy + model_path=tokenizer_model, + legacy=legacy, + chat_template=chat_template, ) elif library == 'byte-level': logging.info(f'Using byte-level tokenization') @@ -209,7 +216,9 @@ def get_nmt_tokenizer( logging.info( f'Getting Megatron tokenizer for pretrained model name: {model_name}, custom vocab file: {vocab_file}, and merges file: {merges_file}' ) - return get_tokenizer(tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file) + return get_tokenizer( + tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template + ) elif library == 'tabular': return TabularTokenizer(vocab_file, delimiter=delimiter) else: