Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding server option to peft eval #7292

Merged
merged 4 commits into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,68 @@ inference:
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False
outfile_path: output.txt
compute_attention_mask: True
compute_attention_mask: True

# server-related configs
server: False # whether launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether launch the web inference server
share: True # whether create a public URL
username: test # user name for web client
password: test2 # password for web client
web_port: 9889 # the port number of the web server 1058
chat: False # use the chat interface
chatbot_config:
value: False # whether to inject the value attributes
attributes:
- name: Quality
min: 0
max: 4
key: quality
type: int
default: 4
- name: Toxicity
min: 0
max: 4
key: toxcity
type: int
default: 0
- name: Humor
min: 0
max: 4
key: humor
type: int
default: 0
- name: Creativity
min: 0
max: 4
key: creativity
type: int
default: 0
- name: Violence
min: 0
max: 4
key: violence
type: int
default: 0
- name: Helpfulness
min: 0
max: 4
key: helpfulness
type: int
default: 4
- name: Not_Appropriate
min: 0
max: 4
key: not_appropriate
type: int
default: 0
- name: Language
choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh']
key: lang
type: list
default: en

user: User
assistant: Assistant
system: "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
51 changes: 51 additions & 0 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# limitations under the License.


import asyncio
import json
import os
import threading
from functools import partial

import torch
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
Expand All @@ -25,6 +29,8 @@
from nemo.collections.nlp.models.language_modeling.megatron_gpt_peft_models import MegatronGPTPEFTModel
from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.models.nlp_model import NLPModel
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
Expand All @@ -36,6 +42,13 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging

try:
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True
except:

Check notice

Code scanning / CodeQL

Except block handles 'BaseException' Note

Except block directly handles BaseException.

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass

mp.set_start_method("spawn", force=True)
"""
This is the script to run inference with a PEFT model or an SFT Model.
Expand Down Expand Up @@ -179,6 +192,44 @@

trainer.test(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we skip traininer.test is cfg.server is used?


if cfg.server:
if not HAVE_MEGATRON_CORE:
raise ValueError('Megatron-core needs to be installed to use this feature!')

from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo

if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
if cfg.chat:
defaults = {
'user': cfg.chatbot_config.user,
'assistant': cfg.chatbot_config.assistant,
'system': cfg.chatbot_config.system,
}
web_ui = partial(
get_chatbot_demo,
defaults=defaults,
value=cfg.chatbot_config.value,
attributes=cfg.chatbot_config.attributes,
)
else:
web_ui = get_demo
loop = asyncio.new_event_loop()
thread = threading.Thread(
target=web_ui,
daemon=True,
args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
)
thread.start()
server = MegatronServer(model.cuda())
server.run("0.0.0.0", port=cfg.port)

while True:
choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice, 0)
if choice[0].item() == 0:
generate(model.cuda())


if __name__ == "__main__":
main()
Loading