Commit

adapter inference for t5
Signed-off-by: arendu <[email protected]>
arendu committed Sep 20, 2022
1 parent 9582bfe commit 8f05db9
Showing 2 changed files with 180 additions and 0 deletions.
39 changes: 39 additions & 0 deletions examples/nlp/language_modeling/tuning/conf/megatron_t5_adapter_inference.yaml
@@ -0,0 +1,39 @@
inference:
  greedy: True # Whether to use greedy decoding instead of sampling
  top_k: 0 # The number of highest-probability vocabulary tokens to keep for top-k filtering.
  top_p: 0.9 # If set to float < 1, only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation.
  temperature: 1.0 # sampling temperature
  add_BOS: True # add the bos token at the beginning of the prompt
  tokens_to_generate: 30 # The maximum number of tokens to generate.
  all_probs: False # whether to return the log probabilities for all tokens in the vocab
  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
  compute_logprob: False # whether to compute the log probability of the input text; a special inference mode, default False


trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  logger: False # logger is provided by exp_manager
  precision: 16 # 16, 32, or bf16

data:
  test_ds: ???
  num_workers: 1
  global_batch_size: 4
  micro_batch_size: 4

tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
pipeline_model_parallel_split_rank: 0 # used for encoder-decoder models
pretrained_language_model_file: ??? # T5 .nemo file path; used when starting from a .nemo file
adapter_model_file: ??? # .nemo file saved during training (using megatron_t5_adapter_tuning.py)
output_file: null # save predictions to this file
checkpoint_dir: null # checkpoint file dir; used to load the PTL checkpoint generated during T5 training
checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
hparams_file: null # model configuration file, only used for PTL checkpoint loading
server: False # whether to launch the inference server
port: 5555 # the port number for the inference server
batch_size: 8
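For orientation, here is a minimal sketch of how the sampling knobs in the inference section interact; this is illustrative only, not NeMo's actual generation code. With greedy: True the model simply takes the argmax at each step and top_k, top_p, and temperature are ignored; otherwise the logits are roughly filtered before sampling like this:

import torch

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.9, temperature: float = 1.0) -> torch.Tensor:
    """Temperature, then top-k, then top-p (nucleus) filtering of a 1-D logits tensor."""
    logits = logits / temperature
    if top_k > 0:
        kth_best = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_best] = float("-inf")  # keep only the top_k tokens
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[1:] = remove[:-1].clone()  # shift right: keep the first token that crosses top_p
        remove[0] = False
        logits[sorted_idx[remove]] = float("-inf")
    return logits

raw_logits = torch.randn(32000)  # hypothetical vocabulary logits from one decoding step
next_id = torch.multinomial(torch.softmax(filter_logits(raw_logits), dim=-1), 1)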

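The ??? entries (test_ds, pretrained_language_model_file, adapter_model_file) are mandatory values that must be supplied at launch time. A minimal sketch of how OmegaConf enforces this, assuming the file above is saved as conf/megatron_t5_adapter_inference.yaml (the config_name used by the script below):

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.load("conf/megatron_t5_adapter_inference.yaml")
print(OmegaConf.is_missing(cfg, "adapter_model_file"))  # True until overridden
try:
    _ = cfg.adapter_model_file  # accessing a ??? value raises
except MissingMandatoryValue:
    print("pass adapter_model_file=... on the command line")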
141 changes: 141 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_t5_adapter_eval.py
@@ -0,0 +1,141 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from apex.transformer import parallel_state
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.nlp.models.language_modeling.megatron_t5_adapter_model import MegatronT5AdapterLearningModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
from nemo.core.config import hydra_runner
from nemo.utils.app_state import AppState

"""
This is the script to run an Adapter Tuned GPT Model for text generation.
Usage:
Assume the model has TP=1, PP=1 in the following use cases.
a. run greedy inference using a base gpt nemo file, and an adapter nemo file:
python megatron_gpt_ia3_eval.py \
gpt_model_file=PATH TO GPT MODEL NEMO FILE \
adapter_model_file=PATH TO ADAPTER MODEL NEMO FILE (generated by training script: ./megatron_gpt_ia3_tuning.py) \
data_paths=[PATH TO A JSONL FILE CONTAINING PROMPTS], \
output_file=PATH TO OUTPUT FILE TO DUMP PREDICTIONS
"""

if not torch.cuda.is_available():
    raise EnvironmentError("A GPU is required for inference")


@hydra_runner(config_path="conf", config_name="megatron_t5_adapter_inference")
def main(cfg) -> None:

    # trainer required for restoring model parallel models
    trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)

    app_state = AppState()
    if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
        app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
        (
            app_state.tensor_model_parallel_rank,
            app_state.pipeline_model_parallel_rank,
            app_state.model_parallel_size,
            app_state.data_parallel_size,
            app_state.pipeline_model_parallel_split_rank,
        ) = fake_initialize_model_parallel(
            world_size=app_state.model_parallel_size,
            rank=trainer.global_rank,
            tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
            pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
            pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
        )
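        # Illustrative arithmetic (hypothetical values): with tensor_model_parallel_size=2
        # and pipeline_model_parallel_size=2, each model replica spans 2 * 2 = 4 GPUs,
        # so a world size of 8 yields data_parallel_size = 8 // 4 = 2 replicas.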

    # Load an adapter model; both paths must be provided in the config
    if cfg.get("adapter_model_file", None) is not None and cfg.get("pretrained_language_model_file", None) is not None:
        # Update the frozen T5 model path in case it has changed
        adapter_tuning_cfg = MegatronT5AdapterLearningModel.restore_from(
            cfg.adapter_model_file, trainer=trainer, return_config=True
        )
        with open_dict(adapter_tuning_cfg):
            adapter_tuning_cfg.pretrained_language_model_path = cfg.pretrained_language_model_file
            adapter_tuning_cfg.micro_batch_size = cfg.get("micro_batch_size", 4)
            adapter_tuning_cfg.global_batch_size = cfg.get("global_batch_size", 4)

        # Now load the adapter model on top of the frozen T5 base model
        model = MegatronT5AdapterLearningModel.restore_from(
            restore_path=cfg.adapter_model_file, trainer=trainer, override_config_path=adapter_tuning_cfg
        )
    else:
        raise NotImplementedError(
            "This script runs inference from an adapter-tuned T5 model; the config must contain both an "
            "adapter_model_file and a pretrained_language_model_file"
        )
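    # Note: the adapter .nemo file stores only the small adapter weights plus config;
    # the frozen T5 weights come from pretrained_language_model_file, which is why
    # both paths are required here.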

    # check whether DDP is initialized
    if parallel_state.is_unitialized():  # note: apex spells this method name without the second "i"

        def dummy():
            return

        if trainer.strategy.launcher is not None:
            trainer.strategy.launcher.launch(dummy, trainer=trainer)
        trainer.strategy.setup_environment()

    model.freeze()

    # Have to turn off activations_checkpoint_method for inference
    try:
        model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    try:
        model.frozen_model.model.language_model.encoder.activations_checkpoint_method = None
    except AttributeError:
        pass

    test_ds, test_dl = model.build_virtual_prompt_dataset(
        dataset_paths=cfg.data.test_ds,
        batch_size=cfg.data.global_batch_size,
        for_train=False,
        drop_last=False,
        shuffle=False,
        num_workers=cfg.data.num_workers,
        pin_memory=True,
    )
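    # cfg.data.test_ds points to JSON-lines prompt file(s). An illustrative record
    # (hypothetical; the exact keys depend on the task templates the adapter was
    # trained with) might look like:
    #   {"taskname": "squad", "context": "...", "question": "..."}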

    config = OmegaConf.to_container(cfg.inference)
    model.set_inference_config(config)
    response = trainer.predict(model, test_dl)
    print("***************************")
    if cfg.output_file is not None:
        with open(cfg.output_file, "w", encoding="utf-8") as f:
            for batch in response:
                for inp, pred in zip(batch['enc_input'], batch['predicted_token_ids']):
                    inp = ' '.join(inp.split('\n'))
                    pred = ' '.join(pred.split('\n'))
                    f.write(f'{inp} {pred}\n')
        print("predictions saved to {}".format(cfg.output_file))
    else:
        print(response)
    print("***************************")


if __name__ == '__main__':
    main()  # noqa pylint: disable=no-value-for-parameter
