refactor pipelines. Add pipeline factory
Signed-off-by: Greg Clark <[email protected]>
messiaen committed Jul 19, 2023
1 parent d52d38c commit d2e58cb
Showing 10 changed files with 627 additions and 725 deletions.
25 changes: 1 addition & 24 deletions gpt_infer.yaml
@@ -10,34 +10,11 @@ inference:
min_tokens_to_generate: 1 # The minimum length of the sequence to be generated.
compute_logprob: False # whether to compute the logprob of all the input text; a special case of inference (default: False)


trainer:
devices: 1
num_nodes: 1
accelerator: gpu
logger: False # logger provided by exp_manager
precision: 16 # 16, 32, or bf16

tensor_model_parallel_size: -1
pipeline_model_parallel_size: -1
pipeline_model_parallel_split_rank: -1 # used for encoder-decoder models (0 for others)
gpt_model_file: /models/gpt2b
checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training
checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
hparams_file: null # model configuration file, only used for PTL checkpoint loading
prompts: # prompts for GPT inference
- "Q: How are you?"
- "Q: How big is the universe?"
server: False # whether to launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether to launch the web inference server
share: False # whether to create a public URL
username: test # user name for web client
password: test2 # password for web client
web_port: 9889 # the port number of the web server
chat: False # use the chat interface
chatbot_config:
value: False # whether to inject the value attributes
user: User
assistant: Assistant
system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
model_path: /models/gpt2b
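
For orientation, a minimal sketch of the trimmed gpt_infer.yaml after this change (assumed: the trainer block survives, since load_for_inference below reads cfg.trainer and cfg.model_path):

inference:
  min_tokens_to_generate: 1 # minimum length of the sequence to be generated
  compute_logprob: False # whether to compute the logprob of the input text

trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  logger: False # logger provided by exp_manager
  precision: 16 # 16, 32, or bf16

model_path: /models/gpt2b
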
6 changes: 4 additions & 2 deletions mvp_gpt_infer.py
@@ -1,10 +1,12 @@
from nemo.core.classes.modelPT import ModelPT


def main():
model = ModelPT.auto_load("/models/gpt2b", trainer_args={"devices": 1, "num_nodes": 1, "accelerator": "gpu", "logger": False, "precision": 16})
model = ModelPT.load_for_inference("gpt_infer.yaml")

output = model.generate_text(["Deep learning is"], end_strings=["."])
output = model.generate_text(["Deep learning is"], end_strings=["."], tokens_to_generate=50)
print(output)


if __name__ == "__main__":
main()
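
For comparison, a hedged sketch of the same driver built around an in-memory config instead of a YAML path, assuming the base-class load_for_inference accepts a DictConfig the way the MegatronGPTModel override shown below does:

from omegaconf import OmegaConf

from nemo.core.classes.modelPT import ModelPT

# Same keys that gpt_infer.yaml carries, constructed in memory.
cfg = OmegaConf.create(
    {
        "trainer": {"devices": 1, "num_nodes": 1, "accelerator": "gpu", "logger": False, "precision": 16},
        "model_path": "/models/gpt2b",
    }
)
model = ModelPT.load_for_inference(cfg)  # the rest of main() proceeds as above
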
40 changes: 17 additions & 23 deletions nemo/collections/asr/modules/transformer/text_generation.py
@@ -102,24 +102,21 @@ def generate(
raise NotImplementedError("please implement this method")

def generate_text(
self,
prompts: List[str],
max_length: int = 10,
min_length: int = 1,
use_greedy: bool = False,
temperature: float = 0.5,
top_k: int = 1,
top_p: float = 1.0,
repetition_penalty: float = 1.0,
add_BOS: bool = False,
all_probs: bool = False,
compute_logprob: bool = True,
end_strings: List[str] = [],
self,
prompts: List[str],
max_length: int = 10,
min_length: int = 1,
use_greedy: bool = False,
temperature: float = 0.5,
top_k: int = 1,
top_p: float = 1.0,
repetition_penalty: float = 1.0,
add_BOS: bool = False,
all_probs: bool = False,
compute_logprob: bool = True,
end_strings: List[str] = [],
) -> OutputType:
length_param: LengthParam = {
"max_length": max_length,
"min_length": min_length
}
length_param: LengthParam = {"max_length": max_length, "min_length": min_length}

sample_param: SamplingParam = {
"use_greedy": use_greedy,
@@ -129,12 +126,9 @@ def generate_text(
"repetition_penalty": repetition_penalty,
"add_BOS": add_BOS,
"all_probs": all_probs,
"compute_logprob": compute_logprob
"compute_logprob": compute_logprob,
}

return self.generate(
prompts,
length_params=length_param,
sampling_params=sample_param,
end_strings=end_strings,
)
prompts, length_params=length_param, sampling_params=sample_param, end_strings=end_strings,
)
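
From the caller's side the reformatted helper behaves as before: keyword arguments are packed into a LengthParam and a SamplingParam and handed off to generate(). A small usage sketch (prompt and parameter values are illustrative):

output = model.generate_text(
    ["Q: How big is the universe?"],
    max_length=30,      # becomes length_params["max_length"]
    top_k=1,            # becomes sampling_params["top_k"]
    end_strings=["."],  # forwarded to generate() unchanged
)
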
114 changes: 69 additions & 45 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -19,15 +19,13 @@
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
from omegaconf.dictconfig import DictConfig
from omegaconf import DictConfig, OmegaConf, open_dict
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.plugins.environments import LightningEnvironment
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector, SaveRestoreConnector
from nemo.core.classes import ModelPT
from pytorch_lightning.trainer.trainer import Trainer
from typing_extensions import override

from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
MegatronPretrainingRandomSampler,
@@ -57,9 +55,18 @@
SamplingParam,
TextGeneration,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.collections.nlp.pipelines.text_generation_pipeline import (
TextGenerationPipeline,
TextGenerationStage,
TextGenerattionPostProcStage,
TextGenerattionPreProcStage,
load_tokenizer,
)
from nemo.core.classes import Exportable
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.inference_pipeline import InferencePipeline, InferencePipelineFactory, PipelineStageType
from nemo.core.neural_types import ChannelType, NeuralType
from nemo.utils import logging

@@ -189,7 +196,34 @@ def output_names(self) -> List[str]:
return ['logits']


class MegatronGPTModel(MegatronBaseModel, TextGeneration):
class MegatronGPTTextGenerationInferencePipeline(TextGenerationPipeline):
def load_nemo_pipeline(self, parts: Optional[List[Union[str, PipelineStageType]]] = None):
cfg = self.inference_config
if parts is None:
return
tokenizer = None
if "preprocessor" in parts or "postprocessor" in parts or PipelineStageType.NEMO_PROC in parts:
tokenizer_cfg = self.model_config.tokenizer
with open_dict(tokenizer_cfg):
# TODO pass model_path field name
if tokenizer_cfg.model is not None and tokenizer_cfg.model.startswith("nemo:"):
tokenizer_cfg.model = os.path.join(cfg.model_path, tokenizer_cfg.model.split(":", 1)[1])
if tokenizer_cfg.vocab_file is not None and tokenizer_cfg.vocab_file.startswith("nemo:"):
tokenizer_cfg.vocab_file = os.path.join(cfg.model_path, tokenizer_cfg.vocab_file.split(":", 1)[1])
if tokenizer_cfg.merge_file is not None and tokenizer_cfg.merge_file.startswith("nemo:"):
tokenizer_cfg.merge_file = os.path.join(cfg.model_path, tokenizer_cfg.merge_file.split(":", 1)[1])

tokenizer = load_tokenizer(tokenizer_cfg)
if "preprocessor" in parts or PipelineStageType.NEMO_PROC in parts:
self.set_stage_exec("preprocessor", TextGenerattionPreProcStage(tokenizer, MegatronGPTModel))
if "postprocessor" in parts or PipelineStageType.NEMO_PROC in parts:
self.set_stage_exec("postprocessor", TextGenerattionPostProcStage(tokenizer))
if "text_generation" in parts or PipelineStageType.MULTI_STEP_NNET in parts:
model = MegatronGPTModel.load_for_inference(self.inference_config)
self.set_stage_exec("text_generation", TextGenerationStage(model))


class MegatronGPTModel(MegatronBaseModel, TextGeneration, InferencePipelineFactory):
"""
Megatron GPT pretraining
"""
@@ -1140,7 +1174,9 @@ def dummy():
if length_params is None:
length_params = get_default_length_params()

return megatron_gpt_generate(self.cuda(), inputs, self.tokenizer, length_params, sampling_params, end_strings=end_strings)
return megatron_gpt_generate(
self.cuda(), inputs, self.tokenizer, length_params, sampling_params, end_strings=end_strings
)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any:
inference_config = self.get_inference_config()
@@ -1335,50 +1371,14 @@ def _restore_sequence_parallelism_args(self):
for mod in module.modules():
if hasattr(mod, "sequence_parallel"):
mod.sequence_parallel = self.last_sequence_parallel

@classmethod
def auto_load(
cls,
restore_path: str,
trainer_args: Dict = {},
):
#cfg = ModelPT.restore_from(restore_path=restore_path, return_config=True)
trainer = Trainer(plugins=[LightningEnvironment()], strategy=NLPDDPStrategy(), **trainer_args)
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(restore_path):
save_restore_connector.model_extracted_dir = restore_path

pretrained_cfg = MegatronGPTModel.restore_from(
restore_path=restore_path,
trainer=trainer,
return_config=True,
save_restore_connector=save_restore_connector,
)
OmegaConf.set_struct(pretrained_cfg, True)
with open_dict(pretrained_cfg):
pretrained_cfg.sequence_parallel = False
pretrained_cfg.activations_checkpoint_granularity = None
pretrained_cfg.activations_checkpoint_method = None
pretrained_cfg.precision = trainer.precision
if trainer.precision == "16":
pretrained_cfg.megatron_amp_O2 = False
model = MegatronGPTModel.restore_from(
restore_path=restore_path,
trainer=trainer,
override_config_path=pretrained_cfg,
save_restore_connector=save_restore_connector,
map_location=f'cuda:{trainer.local_rank}',
)

return model

@classmethod
def load_for_inference(cls, config: Union[str, DictConfig]):
if isinstance(config, str):
cfg = OmegaConf.load(config)
else:
cfg = config

trainer = Trainer(plugins=[LightningEnvironment()], strategy=NLPDDPStrategy(), **cfg.trainer)
model_path = cfg.model_path
save_restore_connector = NLPSaveRestoreConnector()
@@ -1406,4 +1406,28 @@ def load_for_inference(cls, config: Union[str, DictConfig]):
save_restore_connector=save_restore_connector,
map_location=f'cuda:{trainer.local_rank}', # map_location is needed for converted models
)
return model

if parallel_state.is_unitialized():

def dummy():
return

if trainer.strategy.launcher is not None:
trainer.strategy.launcher.launch(dummy, trainer=trainer)
trainer.strategy.setup_environment()

if model.cfg.get('transformer_engine', False):
model.setup_transformer_engine_tp_groups()
return model

@override
@classmethod
def inference_pipeline(
cls,
task_name: Optional[str] = None,
inference_config: Optional[DictConfig] = None,
model_config: Optional[DictConfig] = None,
) -> InferencePipeline:
if task_name != "text_completion":
raise NotImplementedError(f"No pipeline for task {task_name}")
return MegatronGPTTextGenerationInferencePipeline(inference_config, model_config)
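
Taken together, a minimal sketch of the call path through the new factory; the driver below is hypothetical and only assumes the config shapes used elsewhere in this commit:

from omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

inference_cfg = OmegaConf.load("gpt_infer.yaml")
model_cfg = MegatronGPTModel.restore_from("/models/gpt2b", return_config=True)

# The factory maps the task name to a pipeline; stage executors are bound on demand.
pipeline = MegatronGPTModel.inference_pipeline(
    task_name="text_completion", inference_config=inference_cfg, model_config=model_cfg
)
pipeline.load_nemo_pipeline(parts=["preprocessor", "text_generation", "postprocessor"])
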
38 changes: 37 additions & 1 deletion nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -77,7 +77,7 @@ def tokenize_batch(self, sentences, max_len, add_BOS):
Tuple[torch.Tensor], the tokenized and padded torch tensor and the token context length tensor.
"""
return self._tokenize_batch(self.model.tokenizer, sentences, max_len, add_BOS)

@classmethod
def _tokenize_batch(cls, tokenizer, sentences, max_len, add_BOS):
"""
@@ -341,6 +341,42 @@ def post_process(self, tokens: torch.Tensor, new_tokens: torch.Tensor, context_l
tokens[:, :context_length][(tokens[:, :context_length] >= pseudo_token_ids_start)] = tokenizer.unk_id


def model_static_inference_strategy_dispatcher(model, **args):
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import (
MegatronGPTPromptLearningModel,
)
from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.retro_inference_strategies import (
RetroFileQAModelTextGenerationStrategy,
RetroModelTextGenerationStrategy,
RetroQAModelTextGenerationStrategy,
)

model_cls = model
if not isinstance(model_cls, type):
model_cls = type(model)

if issubclass(model_cls, MegatronGPTPromptLearningModel):
return PromptLearningModelTextGenerationStrategy
elif issubclass(model_cls, MegatronGPTModel):
return GPTModelTextGenerationStrategy
elif issubclass(model_cls, MegatronRetrievalModel):
strategy_name = args['strategy']
if strategy_name == 'RetroModelTextGenerationStrategy':
return RetroModelTextGenerationStrategy
elif strategy_name == 'RetroQAModelTextGenerationStrategy':
return RetroQAModelTextGenerationStrategy
elif strategy_name == 'RetroFileQAModelTextGenerationStrategy':
return RetroFileQAModelTextGenerationStrategy
else:
raise ValueError(f'{strategy_name} is not supported for inference')
else:
raise ValueError(f'{model} is not supported for inference')

# Should call GPTModel or Megatron Retrieval Model's forward method


def model_inference_strategy_dispatcher(model, **args):
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import (
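
The new class-based dispatcher mirrors the existing instance-based one but can be consulted before any checkpoint is loaded. A hedged usage sketch; this call site is assumed, not part of the diff:

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel

# Resolve the strategy type from the model class alone; no model instance is required.
strategy_cls = model_static_inference_strategy_dispatcher(MegatronGPTModel)
assert strategy_cls is GPTModelTextGenerationStrategy
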