Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the retrieval services for microservice architecture #5910

Merged
merged 19 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 3 additions & 36 deletions examples/nlp/language_modeling/conf/megatron_retro_inference.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,39 +39,6 @@ retrieval_service:
frequent_query: False # for the current token generation, frequently update the retrieval context. If false, update it every 64 tokens
pad_tokens: True # pad the tokens at the beginning to make it minimum of 64 tokens for retrieving at least once
store_retrieved: False # whether store the retrieved documents, so it can be checked
weights: [0.5, 0.5] # weight for different retrieval services
sentence_bert: # define a few sentence bert models for different retrieval services to use
default:
devices: '0,1,2'
sentence_bert: 'all-mpnet-base-v2'
sentence_bert_batch: 4
qa_ctx:
devices: '0,1,2'
sentence_bert: 'facebook-dpr-ctx_encoder-multiset-base'
sentence_bert_batch: 4
qa_question:
devices: '0,1,2'
sentence_bert: 'facebook-dpr-question_encoder-multiset-base'
sentence_bert_batch: 4
services:
- type: FaissRetrievalService
faiss_devices: '0,1,2'
faiss_index: null # the faiss index file that is used to find KNN
nprobe: 100
retrieval_index: null
query_bert: 'default' # the bert model to encode the query str
- type: DynamicFaissRetrievalService
faiss_devices: '0,1,2'
faiss_index: null # the faiss index to load from file, if null, start from scratch
store_file: null # the retrieval service storage to load from file, if null, start from scratch
chunk_size: 64
stride: 32
ctx_bert: 'qa_ctx' # the bert model to encode the ctx that is used to construct the dynamic retrieval index
query_bert: 'qa_question' # the bert model to encode the query str
output_filename: 'dynamic_db' # the filename of serialized dynamic retrieval service, used for both Faiss index and data storage
server: False # whether launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether launch the web inference server
share: False # whether create a public URL
username: test # user name for web client
password: test2 # password for web client
combo_service:
service_ip: '0.0.0.0'
service_port: 17181
55 changes: 16 additions & 39 deletions examples/nlp/language_modeling/megatron_retro_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,13 @@
# limitations under the License.

import os
import threading

import torch
from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.megatron_web_server import get_retro_demo
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
Expand Down Expand Up @@ -114,41 +109,23 @@ def dummy():
retrieval_service = OmegaConf.to_container(cfg.retrieval_service)
model.set_inference_config(config, retrieval_service)

# running text generation, use inference server
if cfg.server:
if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
if cfg.web_server:
thread = threading.Thread(
target=get_retro_demo, daemon=True, args=(cfg.share, cfg.username, cfg.password)
)
thread.start()
server = MegatronServer(model.cuda(), inference_strategy=model.inference_strategy)
server.run("0.0.0.0", port=cfg.port)

while True:
choice = torch.cuda.LongTensor(1)
torch.distributed.broadcast(choice, 0)
if choice[0].item() == 0:
generate(model.cuda(), strategy=model.inference_strategy)
if not cfg.use_predict_method:
# First method of running text generation, call model.generate method
response = model.generate(
inputs=OmegaConf.to_container(cfg.prompts),
length_params=length_params,
sampling_params=sampling_params,
strategy=model.inference_strategy,
)
else:

if not cfg.use_predict_method:
# First method of running text generation, call model.generate method
response = model.generate(
inputs=OmegaConf.to_container(cfg.prompts),
length_params=length_params,
sampling_params=sampling_params,
strategy=model.inference_strategy,
)
else:
# Second method of running text generation, call trainer.predict
ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size)
response = trainer.predict(model, request_dl)

print("***************************")
print(response)
print("***************************")
# Second method of running text generation, call trainer.predict
ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size)
response = trainer.predict(model, request_dl)

print("***************************")
print(response)
print("***************************")


if __name__ == '__main__':
Expand Down
Loading