From 582d6c1b5fbdfde7bfa4b74af3466dd0ab848f94 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Tue, 31 Jan 2023 22:44:56 +0000 Subject: [PATCH 01/17] retrieval service seperation Signed-off-by: Yi Dong --- .../modules/common/megatron/bert_service.py | 37 ++----- .../common/megatron/retrieval_service.py | 101 ++++++------------ .../conf/bert_service.yaml | 13 +++ .../conf/dynamic_retrieval_service.yaml | 21 ++++ .../conf/static_retrieval_service.yaml | 15 +++ .../start_bert_service.py | 51 +++++++++ .../start_dynamic_retrieval_service.py | 59 ++++++++++ .../start_static_retrieval_service.py | 55 ++++++++++ 8 files changed, 255 insertions(+), 97 deletions(-) create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/conf/bert_service.yaml create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py diff --git a/nemo/collections/nlp/modules/common/megatron/bert_service.py b/nemo/collections/nlp/modules/common/megatron/bert_service.py index 658aab825194..0c9f6e07ee2d 100644 --- a/nemo/collections/nlp/modules/common/megatron/bert_service.py +++ b/nemo/collections/nlp/modules/common/megatron/bert_service.py @@ -16,7 +16,6 @@ import pickle import threading import time -from collections import OrderedDict from typing import List, Union import torch @@ -26,14 +25,7 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -PORT_NUM_START = 17190 -# a global dict, map bert name to port number -BERT_MODEL_MAP = OrderedDict() - - -def get_available_port_num(): - output = PORT_NUM_START - return output + len(BERT_MODEL_MAP) +BERT_RETRIEVER_PORT_NUM = 17190 class SentenceBertResource(Resource): @@ -108,8 +100,9 @@ def __init__( resource_class_args=[self.bert_model, self.tokenizer, self.pool, self.sentence_bert_batch,], ) - def run(self, url): - port = BERT_MODEL_MAP[self.name] + def run(self, url, port=None): + if port is None: + port = BERT_RETRIEVER_PORT_NUM threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() @@ -119,6 +112,7 @@ def start_sentence_bert_server( tokenizer: TokenizerSpec, sentence_bert: str = 'all-mpnet-base-v2', sentence_bert_batch: int = 4, + port: int = None, ): """ Start the sentence bert server method. @@ -126,20 +120,7 @@ def start_sentence_bert_server( Doesn't support multiple nodes yet. """ # register the bert model port number - port_num = get_available_port_num() - BERT_MODEL_MAP[name] = port_num - - if torch.distributed.is_initialized(): - # doesn't handle multiple nodes yet. - # need to set ip address properly for it to work in multiple nodes environment - if torch.distributed.get_rank() == 0: - server = SentenceBertServer(name, devices, tokenizer, sentence_bert, sentence_bert_batch,) - server.run("0.0.0.0") - # sleep to make sure the sentence bert server is full started. - time.sleep(2) - torch.distributed.barrier() - else: - server = SentenceBertServer(name, devices, tokenizer, sentence_bert, sentence_bert_batch,) - server.run("0.0.0.0") - # sleep to make sure the sentence bert server is full started. 
- time.sleep(2) + server = SentenceBertServer(name, devices, tokenizer, sentence_bert, sentence_bert_batch,) + server.run("0.0.0.0", port=port) + # sleep to make sure the sentence bert server is full started. + time.sleep(2) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_service.py index 1b439a94f85e..f982ca4ce381 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_service.py @@ -30,9 +30,8 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import MMapRetrievalIndexedDataset -from nemo.collections.nlp.modules.common.megatron.bert_service import BERT_MODEL_MAP -log = logging.getLogger('werkzeug') +log = logging.getLogger('retrieval') log.setLevel(logging.ERROR) lock = threading.Lock() @@ -42,8 +41,8 @@ PORT_NUM_DYN = 17180 -def request_data(data, port=PORT_NUM): - resp = requests.put('http://localhost:{}/knn'.format(port), data=json.dumps(data), headers=headers) +def request_data(data, ip='localhost', port=None): + resp = requests.put(f'http://{ip}:{port}/knn', data=json.dumps(data), headers=headers) return resp.json() @@ -98,12 +97,13 @@ class FaissRetrievalResource(Resource): """ def __init__( - self, index, tokenizer, ds, query_bert_port, + self, index, tokenizer, ds, query_bert_ip, query_bert_port, ): # server self.index = index self.tokenizer = tokenizer self.ds = ds + self.query_bert_ip = query_bert_ip self.query_bert_port = query_bert_port def put(self): @@ -126,7 +126,7 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int): text = self.tokenizer.ids_to_text(q) sentence_list.append(text) query = sentence_list - emb = request_data(query, self.query_bert_port) + emb = request_data(query, self.query_bert_ip, self.query_bert_port) emb_data = base64.b64decode(emb.encode()) emb = pickle.loads(emb_data) if self.index.ntotal == 0: @@ -160,7 +160,8 @@ def __init__( nprobe: int, retrieval_index: str, tokenizer: TokenizerSpec, - query_bert_port: int, + query_bert_ip: str, + query_bert_port: int = None, ): self.app = Flask(__name__, static_url_path='') # server @@ -186,10 +187,10 @@ def __init__( self.ds = MMapRetrievalIndexedDataset(retrieval_index) api = Api(self.app) api.add_resource( - FaissRetrievalResource, '/knn', resource_class_args=[self.index, self.tokenizer, self.ds, query_bert_port], + FaissRetrievalResource, '/knn', resource_class_args=[self.index, self.tokenizer, self.ds, query_bert_ip, query_bert_port], ) - def run(self, url, port=PORT_NUM): + def run(self, url, port=None): threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() @@ -199,15 +200,13 @@ class DynamicRetrievalResource(FaissRetrievalResource): The PUT method is to get KNN tokens, add new chunks, reset index. 
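+    A short sketch of the JSON payloads this resource's `put` handler
+    accepts (key names are taken from the handler itself, not from any
+    published API spec):
+
+        {"sentences": [...], "neighbors": k}     # KNN query
+        {"reset": true}                          # clear the index and chunk store
+        {"index_name": "prefix"}                 # serialize index and store to disk
+        {"sentences": [...], "add_eos": bool}    # add new documents to the index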
""" - def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_port, query_bert_port, output_filename): - self.index = index - self.tokenizer = tokenizer + def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_ip, ctx_bert_port, query_bert_ip, query_bert_port, output_filename): + super().__init__(index, tokenizer, store, query_bert_ip, query_bert_port) self.chunk_size = chunk_size self.stride = stride self.pad_id = self.tokenizer.pad_id - self.ds = store + self.ctx_bert_ip = ctx_bert_ip self.ctx_bert_port = ctx_bert_port - self.query_bert_port = query_bert_port self.output_filename = output_filename def put(self): @@ -268,7 +267,7 @@ def add_docs_to_index(self, docs: List[str], add_eos: bool = True): chunk = np_array[i : i + 2 * self.chunk_size] self.ds.add(chunk) chunk_texts.append(self.tokenizer.ids_to_text(chunk)) - emb = request_data(chunk_texts, self.ctx_bert_port) + emb = request_data(chunk_texts, self.ctx_bert_ip, self.ctx_bert_port) emb_data = base64.b64decode(emb.encode()) emb = pickle.loads(emb_data) self.index.add(emb) # add vectors to the index @@ -287,13 +286,15 @@ def __init__( stride: int = 32, faiss_index: str = None, store_file: str = None, + ctx_bert_ip: str = None, ctx_bert_port: int = 0, + query_bert_ip: str = None, query_bert_port: int = 0, output_filename: str = 'dynamic_db', ): self.app = Flask(__name__, static_url_path='') has_gpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") - embedding_dim = request_data({}, ctx_bert_port)['dim'] + embedding_dim = request_data({}, ctx_bert_ip, ctx_bert_port)['dim'] if faiss_index is not None: self.index = faiss.read_index(faiss_index) @@ -335,13 +336,15 @@ def __init__( self.chunk_size, self.stride, self.store, + ctx_bert_ip, ctx_bert_port, + query_bert_ip, query_bert_port, output_filename, ], ) - def run(self, url, port=PORT_NUM_DYN): + def run(self, url, port=None): threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() @@ -359,26 +362,19 @@ def __init__( nprobe: int, retrieval_index: str, tokenizer: TokenizerSpec, - query_bert: str = None, + query_bert_ip: str = None, + query_bert_port: int = None, ): self.updatable = False self.tokenizer = tokenizer ds = MMapRetrievalIndexedDataset(retrieval_index) self.chunk_size = ds.chunk_size pad_id = self.tokenizer.pad_id - query_bert_port = BERT_MODEL_MAP[query_bert] + # query_bert_port = BERT_MODEL_MAP[query_bert] # batch, neighbors, 2*chunk_size self.no_retrieval = np.ones((1, 1, 2 * self.chunk_size), dtype=ds._index.dtype) * pad_id - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - server = RetrievalServer( - faiss_index, faiss_devices, nprobe, retrieval_index, tokenizer, query_bert_port - ) - server.run("0.0.0.0") - torch.distributed.barrier() - else: - server = RetrievalServer(faiss_index, faiss_devices, nprobe, retrieval_index, tokenizer, query_bert_port) - server.run("0.0.0.0") + self.query_bert_ip = query_bert_ip + self.query_bert_port = query_bert_port def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): if isinstance(query, torch.Tensor): @@ -392,7 +388,7 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) data = {'sentences': query} data['neighbors'] = neighbors - result = request_data(data, PORT_NUM) + result = request_data(data, self.query_bert_ip, self.query_bert_port) result = np.array(result) return result @@ -407,52 +403,19 @@ class 
DynamicFaissRetrievalService(RetrievalService): def __init__( self, - faiss_devices: str, tokenizer: TokenizerSpec, chunk_size: int, - stride: int, - faiss_index: str = None, - store_file: str = None, - ctx_bert: str = None, - query_bert: str = None, - output_filename: str = 'dynamic_db', + service_ip: str, + service_port: int, ): self.updatable = True self.tokenizer = tokenizer self.chunk_size = chunk_size pad_id = self.tokenizer.pad_id - ctx_bert_port = BERT_MODEL_MAP[ctx_bert] - query_bert_port = BERT_MODEL_MAP[query_bert] # batch, neighbors, 2*chunk_size self.no_retrieval = np.ones((1, 1, 2 * self.chunk_size), dtype=np.int64) * pad_id - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - server = DynamicRetrievalServer( - faiss_devices, - tokenizer, - chunk_size, - stride, - faiss_index, - store_file, - ctx_bert_port, - query_bert_port, - output_filename, - ) - server.run("0.0.0.0") - torch.distributed.barrier() - else: - server = DynamicRetrievalServer( - faiss_devices, - tokenizer, - chunk_size, - stride, - faiss_index, - store_file, - ctx_bert_port, - query_bert_port, - output_filename, - ) - server.run("0.0.0.0") + self.service_ip = service_ip + self.service_port = service_port def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): if isinstance(query, torch.Tensor): @@ -466,7 +429,7 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) data = {'sentences': query} data['neighbors'] = neighbors - result = request_data(data, PORT_NUM_DYN) + result = request_data(data, self.service.ip, self.service_port) result = np.array(result) return result @@ -484,7 +447,7 @@ def add_docs_to_index(self, query: List[str], add_eos: bool = True): sentence_list.append(text) query = sentence_list data = {'sentences': query, 'add_eos': add_eos} - return request_data(data, PORT_NUM_DYN) + return request_data(data, self.service.ip, self.service_port) class ComboRetrievalService(RetrievalService): diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/bert_service.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/bert_service.yaml new file mode 100644 index 000000000000..5f729cd51c8d --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/bert_service.yaml @@ -0,0 +1,13 @@ +name: default # the name of the service +tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer +sentence_bert: # define a few sentence bert models for different retrieval services to use + devices: '0,1,2' + sentence_bert: 'all-mpnet-base-v2' + sentence_bert_batch: 4 + port: 17190 # service port number \ No newline at end of file diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml new file mode 100644 index 000000000000..0dd9e2d84298 --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml @@ -0,0 +1,21 @@ +tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer +service: + faiss_devices: '0,1,2' + faiss_index: null # the faiss index file that is used to find KNN + store_file: null # the retrieval service storage to load from file, if null, 
start from scratch
+  chunk_size: 64
+  stride: 32
+  ctx_bert_ip: '0.0.0.0' # the bert service ip used to encode the context chunks that build the dynamic retrieval index
+  ctx_bert_port: 17190 # port number
+  query_bert_ip: '0.0.0.0' # the bert service ip used to encode the query string
+  query_bert_port: 17190 # port number
+  output_filename: 'dynamic_db' # the filename of the serialized dynamic retrieval service, used for both the Faiss index and the data storage
+  port: 17180 # server port number
+
+server: False # whether to launch the API server
\ No newline at end of file
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml
new file mode 100644
index 000000000000..15b125b0e874
--- /dev/null
+++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml
@@ -0,0 +1,15 @@
+tokenizer:
+  library: 'megatron'
+  type: 'GPT2BPETokenizer'
+  model: null
+  vocab_file: null
+  merge_file: null
+  delimiter: null # only used for tabular tokenizer
+service:
+  faiss_devices: '0,1,2'
+  faiss_index: null # the faiss index file that is used to find KNN
+  nprobe: 100
+  retrieval_index: null
+  query_bert_ip: '0.0.0.0' # the bert model service host ip
+  query_bert_port: 17190 # the bert model service port number
+  port: 17179 # server port number
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py
new file mode 100644
index 000000000000..24d1f31be668
--- /dev/null
+++ b/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
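+
+# A minimal launch example (the hydra overrides are illustrative; every field
+# is defined in conf/bert_service.yaml):
+#
+#   python start_bert_service.py \
+#       sentence_bert.devices=\'0,1,2\' \
+#       sentence_bert.port=17190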
+""" +""" + +from nemo.collections.nlp.modules.common.megatron.bert_service import start_sentence_bert_server +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.core.config import hydra_runner + + +def get_tokenizer(args): + tokenizer = get_nmt_tokenizer( + library=args.library, + model_name=args.type, + tokenizer_model=args.model, + vocab_file=args.vocab_file, + merges_file=args.merge_file, + delimiter=args.delimiter, + ) + if not hasattr(tokenizer, "pad_id"): + tokenizer.add_special_tokens({'pad_token': ''}) + elif hasattr(tokenizer, "pad_id") and (tokenizer.pad_id is None or tokenizer.pad_id < 0): + tokenizer.add_special_tokens({'pad_token': ''}) + return tokenizer + + +@hydra_runner(config_path="conf", config_name="bert_service") +def main(cfg) -> None: + tokenizer = get_tokenizer(cfg.tokenizer) + start_sentence_bert_server(cfg.name, + cfg.sentence_bert.devices, + tokenizer, + cfg.sentence_bert.sentence_bert, + cfg.sentence_bert.sentence_bert_batch, + port=cfg.sentence_bert.port + ) + + +if __name__ == "__main__": + main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py new file mode 100644 index 000000000000..73aad2df984c --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
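+
+# A minimal launch example (the overrides are illustrative; the ctx/query BERT
+# services referenced by the port fields must already be running):
+#
+#   python start_dynamic_retrieval_service.py \
+#       service.ctx_bert_port=17190 \
+#       service.query_bert_port=17190 \
+#       service.port=17180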
+""" +""" + +from nemo.collections.nlp.modules.common.megatron.retrieval_service import DynamicRetrievalServer +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.core.config import hydra_runner + + +def get_tokenizer(args): + tokenizer = get_nmt_tokenizer( + library=args.library, + model_name=args.type, + tokenizer_model=args.model, + vocab_file=args.vocab_file, + merges_file=args.merge_file, + delimiter=args.delimiter, + ) + if not hasattr(tokenizer, "pad_id"): + tokenizer.add_special_tokens({'pad_token': ''}) + elif hasattr(tokenizer, "pad_id") and (tokenizer.pad_id is None or tokenizer.pad_id < 0): + tokenizer.add_special_tokens({'pad_token': ''}) + return tokenizer + + +@hydra_runner(config_path="conf", config_name="dynamic_retrieval_service") +def main(cfg) -> None: + tokenizer = get_tokenizer(cfg.tokenizer) + + server = DynamicRetrievalServer( + cfg.service.faiss_devices, + tokenizer, + cfg.service.chunk_size, + cfg.service.stride, + cfg.service.faiss_index, + cfg.service.store_file, + cfg.service.ctx_bert_ip, + cfg.service.ctx_bert_port, + cfg.service.query_bert_ip, + cfg.service.query_bert_port, + cfg.service.output_filename, + ) + server.run("0.0.0.0", cfg.service.port) + + +if __name__ == "__main__": + main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py new file mode 100644 index 000000000000..3b1a3bdd93c0 --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py @@ -0,0 +1,55 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
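+
+# A minimal launch example (the two index paths are placeholders you must
+# supply; the remaining fields default from conf/static_retrieval_service.yaml):
+#
+#   python start_static_retrieval_service.py \
+#       service.faiss_index=/path/to/knn.index \
+#       service.retrieval_index=/path/to/retrieval_data \
+#       service.port=17179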
+""" +""" + +from nemo.collections.nlp.modules.common.megatron.bert_service import start_sentence_bert_server +from nemo.collections.nlp.modules.common.megatron.retrieval_service import RetrievalServer +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.core.config import hydra_runner + + +def get_tokenizer(args): + tokenizer = get_nmt_tokenizer( + library=args.library, + model_name=args.type, + tokenizer_model=args.model, + vocab_file=args.vocab_file, + merges_file=args.merge_file, + delimiter=args.delimiter, + ) + if not hasattr(tokenizer, "pad_id"): + tokenizer.add_special_tokens({'pad_token': ''}) + elif hasattr(tokenizer, "pad_id") and (tokenizer.pad_id is None or tokenizer.pad_id < 0): + tokenizer.add_special_tokens({'pad_token': ''}) + return tokenizer + + +@hydra_runner(config_path="conf", config_name="static_retrieval_service") +def main(cfg) -> None: + tokenizer = get_tokenizer(cfg.tokenizer) + server = RetrievalServer( + cfg.service.faiss_index, + cfg.service.faiss_devices, + cfg.service.nprobe, + cfg.service.retrieval_index, + tokenizer, + cfg.service.query_bert_ip, + cfg.service.query_bert_port, + ) + server.run("0.0.0.0", cfg.service.port) + + +if __name__ == "__main__": + main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file From b7c659c2cfe3a567af6063f8b35c44385be9951c Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Tue, 31 Jan 2023 23:34:25 +0000 Subject: [PATCH 02/17] refactor service code Signed-off-by: Yi Dong --- .../megatron/retrieval_services/__init__.py | 0 .../{ => retrieval_services}/bert_service.py | 2 +- .../dynamic_retrieve_server.py | 208 ++++++++++ .../retrieval_service.py | 391 ++++-------------- .../static_retrieve_server.py | 134 ++++++ .../megatron/retrieval_services/util.py | 30 ++ .../nlp/modules/common/megatron_web_server.py | 2 +- .../common/retro_inference_strategies.py | 4 +- .../start_bert_service.py | 2 +- .../start_dynamic_retrieval_service.py | 2 +- .../start_static_retrieval_service.py | 3 +- 11 files changed, 468 insertions(+), 310 deletions(-) create mode 100644 nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py rename nemo/collections/nlp/modules/common/megatron/{ => retrieval_services}/bert_service.py (98%) create mode 100644 nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py rename nemo/collections/nlp/modules/common/megatron/{ => retrieval_services}/retrieval_service.py (58%) create mode 100644 nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieve_server.py create mode 100644 nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/nlp/modules/common/megatron/bert_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/bert_service.py similarity index 98% rename from nemo/collections/nlp/modules/common/megatron/bert_service.py rename to nemo/collections/nlp/modules/common/megatron/retrieval_services/bert_service.py index 0c9f6e07ee2d..b69bf04c769f 100644 --- a/nemo/collections/nlp/modules/common/megatron/bert_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/bert_service.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py new file mode 100644 index 000000000000..df9997f58c6d --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py @@ -0,0 +1,208 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import logging +import pickle +import threading +import time +from typing import List + +import faiss +import numpy as np +import torch +from flask import Flask, jsonify, request +from flask_restful import Api +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieve_server import FaissRetrievalResource +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock + + +class ChunkStore: + """ + ChunkStore maps chunk id to tokens. It is used as an in memory storage for dynamic retrieval DB. + """ + + def __init__(self, chunk_size, pad_id): + self.store = {} + self._count = 0 + self.no_retrieval = np.ones(2 * chunk_size, dtype=np.int64) * pad_id + self.store[-1] = self.no_retrieval + + def add(self, chunk): + self.store[self._count] = chunk + self._count += 1 + + def get_chunk(self, neighbor_id): + return self.store[neighbor_id] + + def reset(self): + self._count = 0 + self.store = {} + self.store[-1] = self.no_retrieval + + +class DynamicRetrievalResource(FaissRetrievalResource): + """ + Dynamic Faiss Retrieval Flask resource. + The PUT method is to get KNN tokens, add new chunks, reset index. 
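+
+    A hedged usage example, assuming the server was launched on port 17180
+    (the default in conf/dynamic_retrieval_service.yaml); it mirrors the
+    `request_data` helper in retrieval_services/util.py:
+
+        import json
+
+        import requests
+
+        resp = requests.put(
+            'http://localhost:17180/knn',
+            data=json.dumps({'sentences': ['some text'], 'neighbors': 2}),
+            headers={'Content-Type': 'application/json'},
+        )
+        # nested list of retrieved token ids, shape [batch, k, 2 * chunk_size]
+        neighbors = resp.json()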
+ """ + + def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_ip, ctx_bert_port, query_bert_ip, query_bert_port, output_filename): + super().__init__(index, tokenizer, store, query_bert_ip, query_bert_port) + self.chunk_size = chunk_size + self.stride = stride + self.pad_id = self.tokenizer.pad_id + self.ctx_bert_ip = ctx_bert_ip + self.ctx_bert_port = ctx_bert_port + self.output_filename = output_filename + + def put(self): + data = request.get_json() + if 'neighbors' in data: + sentences = data['sentences'] + # do knn query + num_neighbors = data['neighbors'] + with lock: # Need to get lock to keep multiple threads from hitting code + neighbors = self.get_knn(sentences, num_neighbors) + return jsonify(neighbors.tolist()) + elif 'reset' in data: + with lock: # Need to get lock to keep multiple threads from hitting code + self.reset() + return "success" + elif 'index_name' in data: + with lock: + # serialize the index + index = self.index + if hasattr(faiss, 'index_gpu_to_cpu'): + index = faiss.index_gpu_to_cpu(index) + faiss.write_index(index, data['index_name'] + '_' + self.output_filename + '.index') + # save the data + with open(self.output_filename + '.pkl', 'bw') as f: + pickle.dump(self.ds, f) + else: + sentences = data['sentences'] + add_eos = data['add_eos'] + # update the index + with lock: # Need to get lock to keep multiple threads from hitting code + self.add_docs_to_index(sentences, add_eos) + return "success" + + def reset(self): + self.index.reset() + self.ds.reset() + + def add_docs_to_index(self, docs: List[str], add_eos: bool = True): + """ + Add documents to the Faiss index + Args: + docs: List[str], list of documents that is going to be added to the index + add_eos: bool, whether add the eos in the end + """ + for doc in docs: + token_ids = self.tokenizer.text_to_ids(doc) + # append eos in the end + if add_eos: + token_ids.append(self.tokenizer.eos_id) + np_array = np.array(token_ids, dtype=np.int32) + padded_size = self.chunk_size - (len(np_array) % self.chunk_size) + # for retrieval database, added one more chunk in the end as padding + padded_size += self.chunk_size + np_array = np.pad(np_array, (0, padded_size), 'constant', constant_values=self.pad_id) + chunk_texts = [] + for i in range(0, len(np_array), self.stride): + if i + 2 * self.chunk_size <= len(np_array): + chunk = np_array[i : i + 2 * self.chunk_size] + self.ds.add(chunk) + chunk_texts.append(self.tokenizer.ids_to_text(chunk)) + emb = request_data(chunk_texts, self.ctx_bert_ip, self.ctx_bert_port) + emb_data = base64.b64decode(emb.encode()) + emb = pickle.loads(emb_data) + self.index.add(emb) # add vectors to the index + + +class DynamicRetrievalServer(object): + """ + Flask Dynamic Retrieval server, which helps to build dynamic retrieval index. 
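+
+    A minimal construction sketch (the argument values are illustrative,
+    `tokenizer` is any TokenizerSpec instance, and the ctx/query BERT
+    services must already be running on the given ports, since __init__
+    queries the ctx service for its embedding dimension):
+
+        server = DynamicRetrievalServer(
+            faiss_devices='0',
+            tokenizer=tokenizer,
+            chunk_size=64,
+            stride=32,
+            ctx_bert_ip='0.0.0.0',
+            ctx_bert_port=17190,
+            query_bert_ip='0.0.0.0',
+            query_bert_port=17190,
+        )
+        server.run('0.0.0.0', port=17180)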
+ """ + + def __init__( + self, + faiss_devices: str, + tokenizer: TokenizerSpec, + chunk_size: int = 64, + stride: int = 32, + faiss_index: str = None, + store_file: str = None, + ctx_bert_ip: str = None, + ctx_bert_port: int = 0, + query_bert_ip: str = None, + query_bert_port: int = 0, + output_filename: str = 'dynamic_db', + ): + self.app = Flask(__name__, static_url_path='') + has_gpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") + embedding_dim = request_data({}, ctx_bert_ip, ctx_bert_port)['dim'] + + if faiss_index is not None: + self.index = faiss.read_index(faiss_index) + else: + self.index = faiss.IndexFlatL2(embedding_dim) # build the index + self.pad_id = tokenizer.pad_id + self.chunk_size = chunk_size + self.stride = stride + if store_file is not None: + with open(store_file, 'rb') as f: + self.store = pickle.load(f) + else: + self.store = ChunkStore(chunk_size, self.pad_id) + + if faiss_devices is None or not torch.cuda.is_available(): + device_list = None + else: + device_list = ['cuda:' + str(device) for device in faiss_devices.split(',')] + + if has_gpu and device_list is not None: + beg = time.time() + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = True + co.usePrecomputed = False + co.shard = True + self.index = faiss.index_cpu_to_all_gpus(self.index, co, ngpu=len(device_list)) + end = time.time() + logging.info(f'convert Faiss db to GPU takes {end - beg} s') + + self.tokenizer = tokenizer + + api = Api(self.app) + api.add_resource( + DynamicRetrievalResource, + '/knn', + resource_class_args=[ + self.index, + self.tokenizer, + self.chunk_size, + self.stride, + self.store, + ctx_bert_ip, + ctx_bert_port, + query_bert_ip, + query_bert_port, + output_filename, + ], + ) + + def run(self, url, port=None): + threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() + diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py similarity index 58% rename from nemo/collections/nlp/modules/common/megatron/retrieval_service.py rename to nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index f982ca4ce381..30ece14be9bb 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,38 +14,30 @@ import abc import base64 -import json import logging import pickle import threading -import time from typing import List, Union import faiss import numpy as np -import requests import torch from flask import Flask, jsonify, request from flask_restful import Api, Resource from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import MMapRetrievalIndexedDataset +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data log = logging.getLogger('retrieval') log.setLevel(logging.ERROR) lock = threading.Lock() -headers = {"Content-Type": "application/json"} PORT_NUM = 17179 PORT_NUM_DYN = 17180 -def request_data(data, ip='localhost', port=None): - resp = requests.put(f'http://{ip}:{port}/knn', data=json.dumps(data), headers=headers) - return resp.json() - - class RetrievalService: """ Abstract class for Retrieval Service. @@ -66,288 +58,6 @@ def add_docs_to_index(self, docs: List[str], add_eos: bool = True): raise NotImplementedError() -class ChunkStore: - """ - ChunkStore maps chunk id to tokens. It is used as an in memory storage for dynamic retrieval DB. - """ - - def __init__(self, chunk_size, pad_id): - self.store = {} - self._count = 0 - self.no_retrieval = np.ones(2 * chunk_size, dtype=np.int64) * pad_id - self.store[-1] = self.no_retrieval - - def add(self, chunk): - self.store[self._count] = chunk - self._count += 1 - - def get_chunk(self, neighbor_id): - return self.store[neighbor_id] - - def reset(self): - self._count = 0 - self.store = {} - self.store[-1] = self.no_retrieval - - -class FaissRetrievalResource(Resource): - """ - Static Faiss Retrieval Flask resource. - The PUT method is to get KNN tokens. 
- """ - - def __init__( - self, index, tokenizer, ds, query_bert_ip, query_bert_port, - ): - # server - self.index = index - self.tokenizer = tokenizer - self.ds = ds - self.query_bert_ip = query_bert_ip - self.query_bert_port = query_bert_port - - def put(self): - data = request.get_json() - sentences = data['sentences'] - num_neighbors = data['neighbors'] - with lock: # Need to get lock to keep multiple threads from hitting code - neighbors = self.get_knn(sentences, num_neighbors) - return jsonify(neighbors.tolist()) - # check keys - - def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int): - single_sentence = False - if isinstance(query, str): - single_sentence = True - query = [query] - elif isinstance(query, torch.Tensor): - sentence_list = [] - for q in query: - text = self.tokenizer.ids_to_text(q) - sentence_list.append(text) - query = sentence_list - emb = request_data(query, self.query_bert_ip, self.query_bert_port) - emb_data = base64.b64decode(emb.encode()) - emb = pickle.loads(emb_data) - if self.index.ntotal == 0: - # A workaround to fix searching an empty Faiss index - knn = [[-1] * neighbors for i in range(len(emb))] - else: - _, knn = self.index.search(emb, neighbors) - results = [] - for sentence_neighbors in knn: - chunks = [] - for neighbor_chunk_id in sentence_neighbors: - chunk_id = self.ds.get_chunk(neighbor_chunk_id) - chunks.append(chunk_id) - chunks = np.stack(chunks, axis=0).astype(np.int64) - results.append(chunks) - if single_sentence: - # unpack the single sentence input - return results[0] - return np.stack(results, axis=0).astype(np.int64) - - -class RetrievalServer(object): - """ - Flask Retrieval server, which helps to get the KNN tokens given the query chunk - """ - - def __init__( - self, - faiss_index: str, - faiss_devices: str, - nprobe: int, - retrieval_index: str, - tokenizer: TokenizerSpec, - query_bert_ip: str, - query_bert_port: int = None, - ): - self.app = Flask(__name__, static_url_path='') - # server - has_gpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") - - if faiss_devices is None or not torch.cuda.is_available(): - device_list = None - else: - device_list = ['cuda:' + str(device) for device in faiss_devices.split(',')] - - self.index = faiss.read_index(faiss_index) - if has_gpu and device_list is not None: - beg = time.time() - co = faiss.GpuMultipleClonerOptions() - co.useFloat16 = True - co.usePrecomputed = False - co.shard = True - self.index = faiss.index_cpu_to_all_gpus(self.index, co, ngpu=len(device_list)) - end = time.time() - logging.info(f'convert Faiss db to GPU takes {end - beg} s') - self.index.nprobe = nprobe - self.tokenizer = tokenizer - self.ds = MMapRetrievalIndexedDataset(retrieval_index) - api = Api(self.app) - api.add_resource( - FaissRetrievalResource, '/knn', resource_class_args=[self.index, self.tokenizer, self.ds, query_bert_ip, query_bert_port], - ) - - def run(self, url, port=None): - threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() - - -class DynamicRetrievalResource(FaissRetrievalResource): - """ - Dynamic Faiss Retrieval Flask resource. - The PUT method is to get KNN tokens, add new chunks, reset index. 
- """ - - def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_ip, ctx_bert_port, query_bert_ip, query_bert_port, output_filename): - super().__init__(index, tokenizer, store, query_bert_ip, query_bert_port) - self.chunk_size = chunk_size - self.stride = stride - self.pad_id = self.tokenizer.pad_id - self.ctx_bert_ip = ctx_bert_ip - self.ctx_bert_port = ctx_bert_port - self.output_filename = output_filename - - def put(self): - data = request.get_json() - if 'neighbors' in data: - sentences = data['sentences'] - # do knn query - num_neighbors = data['neighbors'] - with lock: # Need to get lock to keep multiple threads from hitting code - neighbors = self.get_knn(sentences, num_neighbors) - return jsonify(neighbors.tolist()) - elif 'reset' in data: - with lock: # Need to get lock to keep multiple threads from hitting code - self.reset() - return "success" - elif 'index_name' in data: - with lock: - # serialize the index - index = self.index - if hasattr(faiss, 'index_gpu_to_cpu'): - index = faiss.index_gpu_to_cpu(index) - faiss.write_index(index, data['index_name'] + '_' + self.output_filename + '.index') - # save the data - with open(self.output_filename + '.pkl', 'bw') as f: - pickle.dump(self.ds, f) - else: - sentences = data['sentences'] - add_eos = data['add_eos'] - # update the index - with lock: # Need to get lock to keep multiple threads from hitting code - self.add_docs_to_index(sentences, add_eos) - return "success" - - def reset(self): - self.index.reset() - self.ds.reset() - - def add_docs_to_index(self, docs: List[str], add_eos: bool = True): - """ - Add documents to the Faiss index - Args: - docs: List[str], list of documents that is going to be added to the index - add_eos: bool, whether add the eos in the end - """ - for doc in docs: - token_ids = self.tokenizer.text_to_ids(doc) - # append eos in the end - if add_eos: - token_ids.append(self.tokenizer.eos_id) - np_array = np.array(token_ids, dtype=np.int32) - padded_size = self.chunk_size - (len(np_array) % self.chunk_size) - # for retrieval database, added one more chunk in the end as padding - padded_size += self.chunk_size - np_array = np.pad(np_array, (0, padded_size), 'constant', constant_values=self.pad_id) - chunk_texts = [] - for i in range(0, len(np_array), self.stride): - if i + 2 * self.chunk_size <= len(np_array): - chunk = np_array[i : i + 2 * self.chunk_size] - self.ds.add(chunk) - chunk_texts.append(self.tokenizer.ids_to_text(chunk)) - emb = request_data(chunk_texts, self.ctx_bert_ip, self.ctx_bert_port) - emb_data = base64.b64decode(emb.encode()) - emb = pickle.loads(emb_data) - self.index.add(emb) # add vectors to the index - - -class DynamicRetrievalServer(object): - """ - Flask Dynamic Retrieval server, which helps to build dynamic retrieval index. 
- """ - - def __init__( - self, - faiss_devices: str, - tokenizer: TokenizerSpec, - chunk_size: int = 64, - stride: int = 32, - faiss_index: str = None, - store_file: str = None, - ctx_bert_ip: str = None, - ctx_bert_port: int = 0, - query_bert_ip: str = None, - query_bert_port: int = 0, - output_filename: str = 'dynamic_db', - ): - self.app = Flask(__name__, static_url_path='') - has_gpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") - embedding_dim = request_data({}, ctx_bert_ip, ctx_bert_port)['dim'] - - if faiss_index is not None: - self.index = faiss.read_index(faiss_index) - else: - self.index = faiss.IndexFlatL2(embedding_dim) # build the index - self.pad_id = tokenizer.pad_id - self.chunk_size = chunk_size - self.stride = stride - if store_file is not None: - with open(store_file, 'rb') as f: - self.store = pickle.load(f) - else: - self.store = ChunkStore(chunk_size, self.pad_id) - - if faiss_devices is None or not torch.cuda.is_available(): - device_list = None - else: - device_list = ['cuda:' + str(device) for device in faiss_devices.split(',')] - - if has_gpu and device_list is not None: - beg = time.time() - co = faiss.GpuMultipleClonerOptions() - co.useFloat16 = True - co.usePrecomputed = False - co.shard = True - self.index = faiss.index_cpu_to_all_gpus(self.index, co, ngpu=len(device_list)) - end = time.time() - logging.info(f'convert Faiss db to GPU takes {end - beg} s') - - self.tokenizer = tokenizer - - api = Api(self.app) - api.add_resource( - DynamicRetrievalResource, - '/knn', - resource_class_args=[ - self.index, - self.tokenizer, - self.chunk_size, - self.stride, - self.store, - ctx_bert_ip, - ctx_bert_port, - query_bert_ip, - query_bert_port, - output_filename, - ], - ) - - def run(self, url, port=None): - threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() - - class FaissRetrievalService(RetrievalService): """ Top level static retrieval service class. 
@@ -357,13 +67,10 @@ class FaissRetrievalService(RetrievalService): def __init__( self, - faiss_index: str, - faiss_devices: str, - nprobe: int, retrieval_index: str, tokenizer: TokenizerSpec, - query_bert_ip: str = None, - query_bert_port: int = None, + service_ip: str = None, + service_port: int = None, ): self.updatable = False self.tokenizer = tokenizer @@ -373,8 +80,8 @@ def __init__( # query_bert_port = BERT_MODEL_MAP[query_bert] # batch, neighbors, 2*chunk_size self.no_retrieval = np.ones((1, 1, 2 * self.chunk_size), dtype=ds._index.dtype) * pad_id - self.query_bert_ip = query_bert_ip - self.query_bert_port = query_bert_port + self.service_ip = service_ip + self.service_port = service_port def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): if isinstance(query, torch.Tensor): @@ -388,7 +95,7 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) data = {'sentences': query} data['neighbors'] = neighbors - result = request_data(data, self.query_bert_ip, self.query_bert_port) + result = request_data(data, self.service_ip, self.service_port) result = np.array(result) return result @@ -429,7 +136,7 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) data = {'sentences': query} data['neighbors'] = neighbors - result = request_data(data, self.service.ip, self.service_port) + result = request_data(data, self.service_ip, self.service_port) result = np.array(result) return result @@ -447,7 +154,7 @@ def add_docs_to_index(self, query: List[str], add_eos: bool = True): sentence_list.append(text) query = sentence_list data = {'sentences': query, 'add_eos': add_eos} - return request_data(data, self.service.ip, self.service_port) + return request_data(data, self.service_ip, self.service_port) class ComboRetrievalService(RetrievalService): @@ -505,3 +212,83 @@ def add_docs_to_index(self, query: List[str], add_eos: bool = True): if service.updatable: service.add_docs_to_index(query, add_eos) return output + + +class ComboRetrievalResource(Resource): + """ + Dynamic Faiss Retrieval Flask resource. + The PUT method is to get KNN tokens, add new chunks, reset index. 
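+
+    Maintenance operations sketched with the `request_data` helper from
+    retrieval_services/util.py (the port value is illustrative and must
+    match the port this resource is served on):
+
+        from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data
+
+        request_data({'sentences': ['a new document'], 'add_eos': True}, port=17180)  # index new docs
+        request_data({'index_name': 'my_db'}, port=17180)  # persist index and chunk store to disk
+        request_data({'reset': True}, port=17180)  # clear the index and chunk store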
+ """ + + def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_ip, ctx_bert_port, query_bert_ip, query_bert_port, output_filename): + super().__init__(index, tokenizer, store, query_bert_ip, query_bert_port) + self.chunk_size = chunk_size + self.stride = stride + self.pad_id = self.tokenizer.pad_id + self.ctx_bert_ip = ctx_bert_ip + self.ctx_bert_port = ctx_bert_port + self.output_filename = output_filename + + def put(self): + data = request.get_json() + if 'neighbors' in data: + sentences = data['sentences'] + # do knn query + num_neighbors = data['neighbors'] + with lock: # Need to get lock to keep multiple threads from hitting code + neighbors = self.get_knn(sentences, num_neighbors) + return jsonify(neighbors.tolist()) + elif 'reset' in data: + with lock: # Need to get lock to keep multiple threads from hitting code + self.reset() + return "success" + elif 'index_name' in data: + with lock: + # serialize the index + index = self.index + if hasattr(faiss, 'index_gpu_to_cpu'): + index = faiss.index_gpu_to_cpu(index) + faiss.write_index(index, data['index_name'] + '_' + self.output_filename + '.index') + # save the data + with open(self.output_filename + '.pkl', 'bw') as f: + pickle.dump(self.ds, f) + else: + sentences = data['sentences'] + add_eos = data['add_eos'] + # update the index + with lock: # Need to get lock to keep multiple threads from hitting code + self.add_docs_to_index(sentences, add_eos) + return "success" + + def reset(self): + self.index.reset() + self.ds.reset() + + def add_docs_to_index(self, docs: List[str], add_eos: bool = True): + """ + Add documents to the Faiss index + Args: + docs: List[str], list of documents that is going to be added to the index + add_eos: bool, whether add the eos in the end + """ + for doc in docs: + token_ids = self.tokenizer.text_to_ids(doc) + # append eos in the end + if add_eos: + token_ids.append(self.tokenizer.eos_id) + np_array = np.array(token_ids, dtype=np.int32) + padded_size = self.chunk_size - (len(np_array) % self.chunk_size) + # for retrieval database, added one more chunk in the end as padding + padded_size += self.chunk_size + np_array = np.pad(np_array, (0, padded_size), 'constant', constant_values=self.pad_id) + chunk_texts = [] + for i in range(0, len(np_array), self.stride): + if i + 2 * self.chunk_size <= len(np_array): + chunk = np_array[i : i + 2 * self.chunk_size] + self.ds.add(chunk) + chunk_texts.append(self.tokenizer.ids_to_text(chunk)) + emb = request_data(chunk_texts, self.ctx_bert_ip, self.ctx_bert_port) + emb_data = base64.b64decode(emb.encode()) + emb = pickle.loads(emb_data) + self.index.add(emb) # add vectors to the index + diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieve_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieve_server.py new file mode 100644 index 000000000000..f5c4065b626d --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieve_server.py @@ -0,0 +1,134 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import logging +import pickle +import threading +import time +from typing import List, Union + +import faiss +import numpy as np +import torch +from flask import Flask, jsonify, request +from flask_restful import Api, Resource + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import MMapRetrievalIndexedDataset +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock + + +class FaissRetrievalResource(Resource): + """ + Static Faiss Retrieval Flask resource. + The PUT method is to get KNN tokens. + """ + + def __init__( + self, index, tokenizer, ds, query_bert_ip, query_bert_port, + ): + # server + self.index = index + self.tokenizer = tokenizer + self.ds = ds + self.query_bert_ip = query_bert_ip + self.query_bert_port = query_bert_port + + def put(self): + data = request.get_json() + sentences = data['sentences'] + num_neighbors = data['neighbors'] + with lock: # Need to get lock to keep multiple threads from hitting code + neighbors = self.get_knn(sentences, num_neighbors) + return jsonify(neighbors.tolist()) + # check keys + + def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int): + single_sentence = False + if isinstance(query, str): + single_sentence = True + query = [query] + elif isinstance(query, torch.Tensor): + sentence_list = [] + for q in query: + text = self.tokenizer.ids_to_text(q) + sentence_list.append(text) + query = sentence_list + emb = request_data(query, self.query_bert_ip, self.query_bert_port) + emb_data = base64.b64decode(emb.encode()) + emb = pickle.loads(emb_data) + if self.index.ntotal == 0: + # A workaround to fix searching an empty Faiss index + knn = [[-1] * neighbors for i in range(len(emb))] + else: + _, knn = self.index.search(emb, neighbors) + results = [] + for sentence_neighbors in knn: + chunks = [] + for neighbor_chunk_id in sentence_neighbors: + chunk_id = self.ds.get_chunk(neighbor_chunk_id) + chunks.append(chunk_id) + chunks = np.stack(chunks, axis=0).astype(np.int64) + results.append(chunks) + if single_sentence: + # unpack the single sentence input + return results[0] + return np.stack(results, axis=0).astype(np.int64) + + +class RetrievalServer(object): + """ + Flask Retrieval server, which helps to get the KNN tokens given the query chunk + """ + + def __init__( + self, + faiss_index: str, + faiss_devices: str, + nprobe: int, + retrieval_index: str, + tokenizer: TokenizerSpec, + query_bert_ip: str, + query_bert_port: int = None, + ): + self.app = Flask(__name__, static_url_path='') + # server + has_gpu = torch.cuda.is_available() and hasattr(faiss, "index_gpu_to_cpu") + + if faiss_devices is None or not torch.cuda.is_available(): + device_list = None + else: + device_list = ['cuda:' + str(device) for device in faiss_devices.split(',')] + + self.index = faiss.read_index(faiss_index) + if has_gpu and device_list is not None: + beg = time.time() + co = faiss.GpuMultipleClonerOptions() + co.useFloat16 = True + co.usePrecomputed = False + co.shard 
= True + self.index = faiss.index_cpu_to_all_gpus(self.index, co, ngpu=len(device_list)) + end = time.time() + logging.info(f'convert Faiss db to GPU takes {end - beg} s') + self.index.nprobe = nprobe + self.tokenizer = tokenizer + self.ds = MMapRetrievalIndexedDataset(retrieval_index) + api = Api(self.app) + api.add_resource( + FaissRetrievalResource, '/knn', resource_class_args=[self.index, self.tokenizer, self.ds, query_bert_ip, query_bert_port], + ) + + def run(self, url, port=None): + threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py new file mode 100644 index 000000000000..2e9400bb57d1 --- /dev/null +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import requests +import json +import threading + + +headers = {"Content-Type": "application/json"} + +lock = threading.Lock() + +__all__ = ["request_data", "lock"] + + +def request_data(data, ip='localhost', port=None): + resp = requests.put(f'http://{ip}:{port}/knn', data=json.dumps(data), headers=headers) + return resp.json() + diff --git a/nemo/collections/nlp/modules/common/megatron_web_server.py b/nemo/collections/nlp/modules/common/megatron_web_server.py index 8df5a502a974..bd1d06107716 100644 --- a/nemo/collections/nlp/modules/common/megatron_web_server.py +++ b/nemo/collections/nlp/modules/common/megatron_web_server.py @@ -17,7 +17,7 @@ import gradio as gr import requests -from nemo.collections.nlp.modules.common.megatron.retrieval_service import PORT_NUM_DYN +from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import PORT_NUM_DYN PORT_NUM = 5555 headers = {"Content-Type": "application/json"} diff --git a/nemo/collections/nlp/modules/common/retro_inference_strategies.py b/nemo/collections/nlp/modules/common/retro_inference_strategies.py index 0f86d88b16e2..fa350ecdbc97 100644 --- a/nemo/collections/nlp/modules/common/retro_inference_strategies.py +++ b/nemo/collections/nlp/modules/common/retro_inference_strategies.py @@ -21,8 +21,8 @@ import torch.distributed as dist from nemo.collections.nlp.modules.common.lm_utils import pad_batch -from nemo.collections.nlp.modules.common.megatron.bert_service import start_sentence_bert_server -from nemo.collections.nlp.modules.common.megatron.retrieval_service import ( +from nemo.collections.nlp.modules.common.megatron.retrieval_services.bert_service import start_sentence_bert_server +from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import ( ComboRetrievalService, DynamicFaissRetrievalService, FaissRetrievalService, diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py 
b/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py index 24d1f31be668..43a0f68c432b 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py @@ -14,7 +14,7 @@ """ """ -from nemo.collections.nlp.modules.common.megatron.bert_service import start_sentence_bert_server +from nemo.collections.nlp.modules.common.megatron.retrieval_services.bert_service import start_sentence_bert_server from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.core.config import hydra_runner diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py index 73aad2df984c..c362f1bb095a 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py @@ -14,7 +14,7 @@ """ """ -from nemo.collections.nlp.modules.common.megatron.retrieval_service import DynamicRetrievalServer +from nemo.collections.nlp.modules.common.megatron.retrieval_services.dynamic_retrieve_server import DynamicRetrievalServer from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.core.config import hydra_runner diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py index 3b1a3bdd93c0..a40766a7e85c 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py @@ -14,8 +14,7 @@ """ """ -from nemo.collections.nlp.modules.common.megatron.bert_service import start_sentence_bert_server -from nemo.collections.nlp.modules.common.megatron.retrieval_service import RetrievalServer +from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieve_server import RetrievalServer from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.core.config import hydra_runner From b8886a57debd0b20ad70ddfb4b6e0e4c2fee7a87 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Tue, 31 Jan 2023 23:41:14 +0000 Subject: [PATCH 03/17] fix name Signed-off-by: Yi Dong --- .../megatron/retrieval_services/combo_retrieval_server.py | 0 .../{dynamic_retrieve_server.py => dynamic_retrieval_server.py} | 2 +- .../{static_retrieve_server.py => static_retrieval_server.py} | 0 .../service_launch_scripts/start_dynamic_retrieval_service.py | 2 +- .../service_launch_scripts/start_static_retrieval_service.py | 2 +- 5 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py rename nemo/collections/nlp/modules/common/megatron/retrieval_services/{dynamic_retrieve_server.py => dynamic_retrieval_server.py} (99%) rename nemo/collections/nlp/modules/common/megatron/retrieval_services/{static_retrieve_server.py => static_retrieval_server.py} (100%) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git 
a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py similarity index 99% rename from nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py rename to nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py index df9997f58c6d..33c388a2e569 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieve_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py @@ -25,7 +25,7 @@ from flask import Flask, jsonify, request from flask_restful import Api from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieve_server import FaissRetrievalResource +from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import FaissRetrievalResource from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieve_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py similarity index 100% rename from nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieve_server.py rename to nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py index c362f1bb095a..60aaecbc35c8 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py @@ -14,7 +14,7 @@ """ """ -from nemo.collections.nlp.modules.common.megatron.retrieval_services.dynamic_retrieve_server import DynamicRetrievalServer +from nemo.collections.nlp.modules.common.megatron.retrieval_services.dynamic_retrieval_server import DynamicRetrievalServer from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.core.config import hydra_runner diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py index a40766a7e85c..936ebebbdc2f 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py @@ -14,7 +14,7 @@ """ """ -from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieve_server import RetrievalServer +from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import RetrievalServer from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.core.config import hydra_runner From b2312ead2269b2f190a0a0c2b39d8356ae097c5a Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 1 Feb 2023 21:38:49 +0000 Subject: [PATCH 04/17] add combo server Signed-off-by: Yi Dong --- .../combo_retrieval_server.py | 195 ++++++++++++++++++ .../dynamic_retrieval_server.py | 9 + .../retrieval_services/retrieval_service.py | 59 ++++-- 
.../static_retrieval_server.py | 6 + .../conf/dynamic_retrieval_service.yaml | 4 +- .../conf/static_retrieval_service.yaml | 2 +- 6 files changed, 252 insertions(+), 23 deletions(-)
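A note on the new file below: the heart of this patch is the weighted KNN split in ComboRetrievalResource.get_knn, which apportions a neighbor budget across child services by normalized weight, truncating each share to an integer and giving the rounding remainder to the last service. A minimal sketch of that rule (split_neighbors is a hypothetical name, not part of the patch):

    from typing import List

    # Mirrors the splitting logic in ComboRetrievalResource.get_knn below.
    def split_neighbors(total: int, weights: List[float]) -> List[int]:
        norm = [w / sum(weights) for w in weights]  # normalized, as __init__ does
        counts, assigned = [], 0
        for i, w in enumerate(norm):
            k = int(total * w)  # integer truncation, matching the patch
            if i == len(norm) - 1:
                k = total - assigned  # last service absorbs the remainder
            assigned += k
            counts.append(k)
        return counts

    assert split_neighbors(5, [0.5, 0.5]) == [2, 3]
    assert split_neighbors(4, [1.0, 0.0]) == [4, 0]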
+ """ + + def __init__(self, retrieval_services, weight_container): + self.retrieval_services = retrieval_services + self.updatable = any([service.updatable for service in retrieval_services]) + + self.weight_container = weight_container + weights = np.array(weight_container[0]) + # normalize the weights + weights = weights / weights.sum() + self.weight_container[0] = weights + + self.chunk_size = self.retrieval_services[0].chunk_size + + def put(self): + data = request.get_json() + if 'neighbors' in data: + sentences = data['sentences'] + # do knn query + num_neighbors = data['neighbors'] + with lock: # Need to get lock to keep multiple threads from hitting code + neighbors = self.get_knn(sentences, num_neighbors) + return jsonify(neighbors.tolist()) + elif 'reset' in data: + with lock: # Need to get lock to keep multiple threads from hitting code + self.reset() + return "success" + elif 'update_weight' in data: + with lock: + self.update_weights(data['update_weight']) + elif 'index_name' in data: + with lock: + # serialize the index + index = self.index + if hasattr(faiss, 'index_gpu_to_cpu'): + index = faiss.index_gpu_to_cpu(index) + faiss.write_index(index, data['index_name'] + '_' + self.output_filename + '.index') + # save the data + with open(self.output_filename + '.pkl', 'bw') as f: + pickle.dump(self.ds, f) + else: + sentences = data['sentences'] + add_eos = data['add_eos'] + # update the index + with lock: # Need to get lock to keep multiple threads from hitting code + self.add_docs_to_index(sentences, add_eos) + return "success" + + def reset(self): + output = 'success' + if not self.updatable: + return 'no dynamic service, no action is performed' + for i, service in enumerate(self.retrieval_services): + if service.updatable: + service.reset() + return output + + def update_weights(self, weights): + weights = np.array(weights) + # normalize the weights + weights = weights / weights.sum() + self.weight_container[0] = weights + + def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): + weights = self.weight_container[0] + if neighbors == 0: + return self.retrieval_services[0].get_knn(query, 0) + total_neighbors = 0 + results = [] + for i, service in enumerate(self.retrieval_services): + k = int(neighbors * weights[i]) + if i == len(self.retrieval_services) - 1: + k = neighbors - total_neighbors + total_neighbors += k + if k == 0: + # empty, skip it + continue + result = service.get_knn(query, k) + results.append(result) + return np.concatenate(results, axis=1) + + def add_docs_to_index(self, query: List[str], add_eos: bool = True): + """ + Add documents to the Faiss index + Args: + docs: List[str], list of documents that is going to be added to the index + add_eos: bool, whether add the eos in the end + """ + output = 'success' + if not self.updatable: + if not self.updatable: + return 'no dynamic service, no action is performed' + for i, service in enumerate(self.retrieval_services): + if service.updatable: + service.add_docs_to_index(query, add_eos) + return output + + def write_index(self, index_name: str): + """ + write the dynamic index into a file + Args: + index_name: str, index name + """ + output = 'success' + if not self.updatable: + if not self.updatable: + return 'no dynamic service, no action is performed' + for i, service in enumerate(self.retrieval_services): + if service.updatable: + service.write_index(index_name) + return output + + +class ComboRetrievalServer(object): + """ + Flask Combo Retrieval server, which helps to aggregate different 
retrieval services + """ + + def __init__( + self, + tokenizer: TokenizerSpec, + services_cfg: list, + ): + services = [] + weights = [] + for service_cfg in services_cfg: + weights.append(service_cfg.weight) + if service_cfg.type == 'FaissRetrievalService': + service = FaissRetrievalService( + tokenizer=tokenizer, + service_ip=service_cfg.service_ip, + service_port=service_cfg.service_port) + elif service_cfg.type == 'DynamicFaissRetrievalService': + service = DynamicFaissRetrievalService( + tokenizer=tokenizer, + service_ip=service_cfg.service_ip, + service_port=service_cfg.service_port) + else: + raise ValueError(f'Unsupported retrieval service {service_cfg.type}') + services.append(service) + self.weight_container = [weights] + self.tokenizer = tokenizer + + api = Api(self.app) + api.add_resource( + ComboRetrievalResource, + '/knn', + resource_class_args=[ + services, + self.weight_container, + ], + ) + + def run(self, url, port=None): + threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py index 33c388a2e569..407ceb9ab2be 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py @@ -18,6 +18,7 @@ import threading import time from typing import List +from collections import namedtuple import faiss import numpy as np @@ -29,6 +30,10 @@ from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock +# define this type to mimic the indexed dataset +DType = namedtuple('DType', ['dtype']) + + class ChunkStore: """ ChunkStore maps chunk id to tokens. It is used as an in memory storage for dynamic retrieval DB. @@ -38,7 +43,10 @@ def __init__(self, chunk_size, pad_id): self.store = {} self._count = 0 self.no_retrieval = np.ones(2 * chunk_size, dtype=np.int64) * pad_id + self.chunk_size = chunk_size self.store[-1] = self.no_retrieval + field = DType(dtype=np.int64) + self._index = field def add(self, chunk): self.store[self._count] = chunk @@ -91,6 +99,7 @@ def put(self): # save the data with open(self.output_filename + '.pkl', 'bw') as f: pickle.dump(self.ds, f) + return "success" else: sentences = data['sentences'] add_eos = data['add_eos'] diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index 30ece14be9bb..8e5e4b9b7627 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -45,6 +45,12 @@ class RetrievalService: @abc.abstractmethod def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int): + """Get K-nearest neighbor chunks based on the input query + + Args: + query (Union[List[str], str, torch.Tensor]): query str, list of str or token ids in torch.Tensor type + neighbors (int): number of neighbors to query + """ pass @abc.abstractmethod @@ -60,30 +66,29 @@ def add_docs_to_index(self, docs: List[str], add_eos: bool = True): class FaissRetrievalService(RetrievalService): """ - Top level static retrieval service class. 
- It starts the server at rank 0 worker, currently doesn't support multiple nodes yet. + Static retrieval service client class. It implements the retrieval services interface, has a simple client to do KNN queries. """ def __init__( self, - retrieval_index: str, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, ): self.updatable = False self.tokenizer = tokenizer - ds = MMapRetrievalIndexedDataset(retrieval_index) - self.chunk_size = ds.chunk_size - pad_id = self.tokenizer.pad_id - # query_bert_port = BERT_MODEL_MAP[query_bert] - # batch, neighbors, 2*chunk_size - self.no_retrieval = np.ones((1, 1, 2 * self.chunk_size), dtype=ds._index.dtype) * pad_id self.service_ip = service_ip self.service_port = service_port def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): + """Get K-nearest neighbor chunks based on the input query + + Args: + query (Union[List[str], str, torch.Tensor]): query str, list of str or token ids in torch.Tensor type + neighbors (int): number of neighbors to query + """ + if isinstance(query, torch.Tensor): sentence_list = [] for q in query: @@ -102,8 +107,7 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): class DynamicFaissRetrievalService(RetrievalService): """ - Top level dynamic retrieval service class. - It starts the server at rank 0 worker, currently doesn't support multiple nodes yet. + Dynamic retrieval service client class. It implements the retrieval services interface, has a simple client to add, reset and query the dynamic retrieval index. """ @@ -111,29 +115,28 @@ class DynamicFaissRetrievalService(RetrievalService): def __init__( self, tokenizer: TokenizerSpec, - chunk_size: int, service_ip: str, service_port: int, ): self.updatable = True self.tokenizer = tokenizer - self.chunk_size = chunk_size - pad_id = self.tokenizer.pad_id - # batch, neighbors, 2*chunk_size - self.no_retrieval = np.ones((1, 1, 2 * self.chunk_size), dtype=np.int64) * pad_id self.service_ip = service_ip self.service_port = service_port def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): + """Get K-nearest neighbor chunks based on the input query + + Args: + query (Union[List[str], str, torch.Tensor]): query str, list of str or token ids in torch.Tensor type + neighbors (int): number of neighbors to query + """ + if isinstance(query, torch.Tensor): sentence_list = [] for q in query: text = self.tokenizer.ids_to_text(q) sentence_list.append(text) query = sentence_list - if neighbors == 0: - # use padding - return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) data = {'sentences': query} data['neighbors'] = neighbors result = request_data(data, self.service_ip, self.service_port) @@ -156,6 +159,24 @@ def add_docs_to_index(self, query: List[str], add_eos: bool = True): data = {'sentences': query, 'add_eos': add_eos} return request_data(data, self.service_ip, self.service_port) + def write_index(self, index_name: str): + """ + Write the dynamic index and document storage into file + Args: + index_name: str, the index name used for the file name + """ + data = {'index_name': index_name} + return request_data(data, self.service_ip, self.service_port) + + def reset(self, index_name: str): + """ + Write the dynamic index and document storage into file + Args: + index_name: str, the index name used for the file name + """ + data = {'reset': None} + return request_data(data, self.service_ip, self.service_port) + class ComboRetrievalService(RetrievalService): """ diff --git 
a/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py index f5c4065b626d..276a6f3e796a 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py @@ -45,6 +45,9 @@ def __init__( self.ds = ds self.query_bert_ip = query_bert_ip self.query_bert_port = query_bert_port + self.chunk_size = ds.chunk_size + pad_id = self.tokenizer.pad_id + self.no_retrieval = np.ones((1, 1, 2 * self.chunk_size), dtype=ds._index.dtype) * pad_id def put(self): data = request.get_json() @@ -56,6 +59,9 @@ def put(self): # check keys def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors: int): + if neighbors == 0: + # use padding + return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) single_sentence = False if isinstance(query, str): single_sentence = True
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml index 0dd9e2d84298..f6015e33e24d 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml +++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/dynamic_retrieval_service.yaml @@ -16,6 +16,4 @@ service: query_bert_ip: '0.0.0.0' # the bert service to encode the query str query_bert_port: 17190 # port number output_filename: 'dynamic_db' # the filename of serialized dynamic retrieval service, used for both Faiss index and data storage - port: 17180 # server port number - -server: False # whether launch the API server \ No newline at end of file + port: 17180 # server port number \ No newline at end of file
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml index 15b125b0e874..6bd60afb0c33 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml +++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/static_retrieval_service.yaml @@ -12,4 +12,4 @@ service: retrieval_index: null query_bert_ip: '0.0.0.0' # the bert model service host ip query_bert_port: 17190 # the bert model service port number - port: 17179 # server port number + port: 17179 # server port number \ No newline at end of file
From ebf1040bac4ba2e8626e7f559328792ff945ab27 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 1 Feb 2023 21:39:12 +0000 Subject: [PATCH 05/17] added combo files Signed-off-by: Yi Dong --- .../conf/combo_retrieval_service.yaml | 18 +++++++ .../start_combo_retrieval_service.py | 50 +++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/conf/combo_retrieval_service.yaml create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py
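A note on the neighbors == 0 fast path added to the static resource above: instead of querying the index, it returns a (batch, 1, 2 * chunk_size) block of pad tokens, so downstream RETRO code always receives a well-formed retrieval tensor. A shape check under assumed values (chunk_size and pad_id here are illustrative, not taken from the patch):

    import numpy as np

    chunk_size, pad_id = 64, 0  # assumed values for illustration
    no_retrieval = np.ones((1, 1, 2 * chunk_size), dtype=np.int64) * pad_id

    queries = ["who wrote hamlet?", "capital of france"]
    out = np.repeat(no_retrieval, len(queries), 0).astype(np.int64)
    # batch x 1 neighbor slot x (retrieved chunk + continuation) tokens
    assert out.shape == (2, 1, 2 * chunk_size)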
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/combo_retrieval_service.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/combo_retrieval_service.yaml new file mode 100644 index 000000000000..92db18df2568 --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/combo_retrieval_service.yaml @@ -0,0 +1,18 @@ +tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer +service: + child_services: + - type: 'FaissRetrievalService' + service_ip: '0.0.0.0' + service_port: 17179 + weight: 0.5 # initial weight for child service + - type: 'DynamicFaissRetrievalService' + service_ip: '0.0.0.0' + service_port: 17180 + weight: 0.5 # initial weight for child service + port: 17181 # server port number \ No newline at end of file
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py new file mode 100644 index 000000000000..02cf21d678b1 --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +""" + +from nemo.collections.nlp.modules.common.megatron.retrieval_services.combo_retrieval_server import ComboRetrievalServer +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.core.config import hydra_runner + + +def get_tokenizer(args): + tokenizer = get_nmt_tokenizer( + library=args.library, + model_name=args.type, + tokenizer_model=args.model, + vocab_file=args.vocab_file, + merges_file=args.merge_file, + delimiter=args.delimiter, + ) + if not hasattr(tokenizer, "pad_id"): + tokenizer.add_special_tokens({'pad_token': '<pad>'}) + elif hasattr(tokenizer, "pad_id") and (tokenizer.pad_id is None or tokenizer.pad_id < 0): + tokenizer.add_special_tokens({'pad_token': '<pad>'}) + return tokenizer + + +@hydra_runner(config_path="conf", config_name="combo_retrieval_service") +def main(cfg) -> None: + tokenizer = get_tokenizer(cfg.tokenizer) + + server = ComboRetrievalServer( + tokenizer, + cfg.service.child_services + ) + server.run("0.0.0.0", cfg.service.port) + + +if __name__ == "__main__": + main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file
From 2a501957a3e662e7493c94680f9ba676cad7d4cb Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 1 Feb 2023 22:00:58 +0000 Subject: [PATCH 06/17] fix the bug Signed-off-by: Yi Dong --- .../megatron/retrieval_services/combo_retrieval_server.py | 3 +-- .../common/megatron/retrieval_services/retrieval_service.py | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py index 67a04be02d3a..060b3d5096c0 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py @@ -50,8 +50,6 @@ def __init__(self, retrieval_services, weight_container): weights = weights / weights.sum() self.weight_container[0] = weights - self.chunk_size =
self.retrieval_services[0].chunk_size - def put(self): data = request.get_json() if 'neighbors' in data: @@ -161,6 +159,7 @@ def __init__( tokenizer: TokenizerSpec, services_cfg: list, ): + self.app = Flask(__name__, static_url_path='') services = [] weights = [] for service_cfg in services_cfg: diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index 8e5e4b9b7627..e6d0bde92a68 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -95,9 +95,6 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): text = self.tokenizer.ids_to_text(q) sentence_list.append(text) query = sentence_list - if neighbors == 0: - # use padding - return np.repeat(self.no_retrieval, len(query), 0).astype(np.int64) data = {'sentences': query} data['neighbors'] = neighbors result = request_data(data, self.service_ip, self.service_port) @@ -168,7 +165,7 @@ def write_index(self, index_name: str): data = {'index_name': index_name} return request_data(data, self.service_ip, self.service_port) - def reset(self, index_name: str): + def reset(self): """ Write the dynamic index and document storage into file Args: From f4c97f160af68ca12f2ebefe6818b2b69f6baf14 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 1 Feb 2023 22:57:54 +0000 Subject: [PATCH 07/17] add retrieval service Signed-off-by: Yi Dong --- .../combo_retrieval_server.py | 1 + .../retrieval_services/retrieval_service.py | 177 +----------------- 2 files changed, 10 insertions(+), 168 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py index 060b3d5096c0..09e44d85d5b7 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py @@ -66,6 +66,7 @@ def put(self): elif 'update_weight' in data: with lock: self.update_weights(data['update_weight']) + return "success" elif 'index_name' in data: with lock: # serialize the index diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index e6d0bde92a68..47b54874d903 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -13,20 +13,14 @@ # limitations under the License. 
import abc -import base64 import logging -import pickle import threading from typing import List, Union -import faiss import numpy as np import torch -from flask import Flask, jsonify, request -from flask_restful import Api, Resource from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import MMapRetrievalIndexedDataset from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data log = logging.getLogger('retrieval') @@ -102,44 +96,13 @@ def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): return result -class DynamicFaissRetrievalService(RetrievalService): +class DynamicFaissRetrievalService(FaissRetrievalService): """ Dynamic retrieval service client class. It implements the retrieval services interface, has a simple client to add, reset and query the dynamic retrieval index. """ - def __init__( - self, - tokenizer: TokenizerSpec, - service_ip: str, - service_port: int, - ): - self.updatable = True - self.tokenizer = tokenizer - self.service_ip = service_ip - self.service_port = service_port - - def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): - """Get K-nearest neighbor chunks based on the input query - - Args: - query (Union[List[str], str, torch.Tensor]): query str, list of str or token ids in torch.Tensor type - neighbors (int): number of neighbors to query - """ - - if isinstance(query, torch.Tensor): - sentence_list = [] - for q in query: - text = self.tokenizer.ids_to_text(q) - sentence_list.append(text) - query = sentence_list - data = {'sentences': query} - data['neighbors'] = neighbors - result = request_data(data, self.service_ip, self.service_port) - result = np.array(result) - return result - def add_docs_to_index(self, query: List[str], add_eos: bool = True): """ Add documents to the Faiss index @@ -175,138 +138,16 @@ def reset(self): return request_data(data, self.service_ip, self.service_port) -class ComboRetrievalService(RetrievalService): +class ComboRetrievalService(DynamicFaissRetrievalService): """ - Top level retrieval service class. - It combines other retrieval services as a combo retrieval service. - It uses `weights` to determine the number of neighbors for each of the retrieval service members. + Combo retrieval service client class. 
+ It implements the retrieval services interface, has a simple client to add, reset, query, update weights """ - def __init__(self, retrieval_services, weights, store): - self.retrieval_services = retrieval_services - self.updatable = any([service.updatable for service in retrieval_services]) - self.store = store - weights = np.array(weights) - # normalize the weights - weights = weights / weights.sum() - store.set('weights', pickle.dumps(weights)) - self.chunk_size = self.retrieval_services[0].chunk_size - - def update_weights(self, weights): - weights = np.array(weights) - # normalize the weights - weights = weights / weights.sum() - self.store.set('weights', pickle.dumps(weights)) - - def get_knn(self, query: Union[List[str], str, torch.Tensor], neighbors): - weights = pickle.loads(self.store.get('weights')) - if neighbors == 0: - return self.retrieval_services[0].get_knn(query, 0) - total_neighbors = 0 - results = [] - for i, service in enumerate(self.retrieval_services): - k = int(neighbors * weights[i]) - if i == len(self.retrieval_services) - 1: - k = neighbors - total_neighbors - total_neighbors += k - if k == 0: - # empty, skip it - continue - result = service.get_knn(query, k) - results.append(result) - return np.concatenate(results, axis=1) - - def add_docs_to_index(self, query: List[str], add_eos: bool = True): - """ - Add documents to the Faiss index + def update_weights(self, weights: List[float]): + """ update the weights between the children services Args: - docs: List[str], list of documents that is going to be added to the index - add_eos: bool, whether add the eos in the end + weights (List[float]): weights for children services """ - output = 'success' - if not self.updatable: - return output - for i, service in enumerate(self.retrieval_services): - if service.updatable: - service.add_docs_to_index(query, add_eos) - return output - - -class ComboRetrievalResource(Resource): - """ - Dynamic Faiss Retrieval Flask resource. - The PUT method is to get KNN tokens, add new chunks, reset index. 
- """ - - def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_ip, ctx_bert_port, query_bert_ip, query_bert_port, output_filename): - super().__init__(index, tokenizer, store, query_bert_ip, query_bert_port) - self.chunk_size = chunk_size - self.stride = stride - self.pad_id = self.tokenizer.pad_id - self.ctx_bert_ip = ctx_bert_ip - self.ctx_bert_port = ctx_bert_port - self.output_filename = output_filename - - def put(self): - data = request.get_json() - if 'neighbors' in data: - sentences = data['sentences'] - # do knn query - num_neighbors = data['neighbors'] - with lock: # Need to get lock to keep multiple threads from hitting code - neighbors = self.get_knn(sentences, num_neighbors) - return jsonify(neighbors.tolist()) - elif 'reset' in data: - with lock: # Need to get lock to keep multiple threads from hitting code - self.reset() - return "success" - elif 'index_name' in data: - with lock: - # serialize the index - index = self.index - if hasattr(faiss, 'index_gpu_to_cpu'): - index = faiss.index_gpu_to_cpu(index) - faiss.write_index(index, data['index_name'] + '_' + self.output_filename + '.index') - # save the data - with open(self.output_filename + '.pkl', 'bw') as f: - pickle.dump(self.ds, f) - else: - sentences = data['sentences'] - add_eos = data['add_eos'] - # update the index - with lock: # Need to get lock to keep multiple threads from hitting code - self.add_docs_to_index(sentences, add_eos) - return "success" - - def reset(self): - self.index.reset() - self.ds.reset() - - def add_docs_to_index(self, docs: List[str], add_eos: bool = True): - """ - Add documents to the Faiss index - Args: - docs: List[str], list of documents that is going to be added to the index - add_eos: bool, whether add the eos in the end - """ - for doc in docs: - token_ids = self.tokenizer.text_to_ids(doc) - # append eos in the end - if add_eos: - token_ids.append(self.tokenizer.eos_id) - np_array = np.array(token_ids, dtype=np.int32) - padded_size = self.chunk_size - (len(np_array) % self.chunk_size) - # for retrieval database, added one more chunk in the end as padding - padded_size += self.chunk_size - np_array = np.pad(np_array, (0, padded_size), 'constant', constant_values=self.pad_id) - chunk_texts = [] - for i in range(0, len(np_array), self.stride): - if i + 2 * self.chunk_size <= len(np_array): - chunk = np_array[i : i + 2 * self.chunk_size] - self.ds.add(chunk) - chunk_texts.append(self.tokenizer.ids_to_text(chunk)) - emb = request_data(chunk_texts, self.ctx_bert_ip, self.ctx_bert_port) - emb_data = base64.b64decode(emb.encode()) - emb = pickle.loads(emb_data) - self.index.add(emb) # add vectors to the index - + data = {"update_weight": weights} + return request_data(data, self.service_ip, self.service_port) From f50fb80c4c0e152991f828261aabcc97dbb9eaa3 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 1 Feb 2023 23:06:53 +0000 Subject: [PATCH 08/17] fix updatable flag Signed-off-by: Yi Dong --- .../retrieval_services/retrieval_service.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index 47b54874d903..b01f331b45a4 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -102,6 +102,14 @@ class DynamicFaissRetrievalService(FaissRetrievalService): It 
implements the retrieval services interface, has a simple client to add, reset and query the dynamic retrieval index. """ + def __init__( + self, + tokenizer: TokenizerSpec, + service_ip: str = None, + service_port: int = None, + ): + super().__init__(tokenizer=tokenizer, service_ip=service_ip, service_port=service_port) + self.updatable = True def add_docs_to_index(self, query: List[str], add_eos: bool = True): """ @@ -143,6 +151,13 @@ class ComboRetrievalService(DynamicFaissRetrievalService): Combo retrieval service client class. It implements the retrieval services interface, has a simple client to add, reset, query, update weights """ + def __init__( + self, + tokenizer: TokenizerSpec, + service_ip: str = None, + service_port: int = None, + ): + super().__init__(tokenizer=tokenizer, service_ip=service_ip, service_port=service_port) def update_weights(self, weights: List[float]): """ update the weights between the children services From a79e92ea7dfcbdcd65916732997516305dff6480 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Wed, 1 Feb 2023 23:34:01 +0000 Subject: [PATCH 09/17] working example Signed-off-by: Yi Dong --- .../conf/megatron_retro_inference.yaml | 33 ++----------------- .../common/retro_inference_strategies.py | 29 +++------------- 2 files changed, 7 insertions(+), 55 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml b/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml index 29fb11623b1f..3160809d54d8 100644 --- a/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml +++ b/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml @@ -39,36 +39,9 @@ retrieval_service: frequent_query: False # for the current token generation, frequently update the retrieval context. 
If false, update it every 64 tokens pad_tokens: True # pad the tokens at the beginning to make it minimum of 64 tokens for retrieving at least once store_retrieved: False # whether store the retrieved documents, so it can be checked - weights: [0.5, 0.5] # weight for different retrieval services - sentence_bert: # define a few sentence bert models for different retrieval services to use - default: - devices: '0,1,2' - sentence_bert: 'all-mpnet-base-v2' - sentence_bert_batch: 4 - qa_ctx: - devices: '0,1,2' - sentence_bert: 'facebook-dpr-ctx_encoder-multiset-base' - sentence_bert_batch: 4 - qa_question: - devices: '0,1,2' - sentence_bert: 'facebook-dpr-question_encoder-multiset-base' - sentence_bert_batch: 4 - services: - - type: FaissRetrievalService - faiss_devices: '0,1,2' - faiss_index: null # the faiss index file that is used to find KNN - nprobe: 100 - retrieval_index: null - query_bert: 'default' # the bert model to encode the query str - - type: DynamicFaissRetrievalService - faiss_devices: '0,1,2' - faiss_index: null # the faiss index to load from file, if null, start from scratch - store_file: null # the retrieval service storage to load from file, if null, start from scratch - chunk_size: 64 - stride: 32 - ctx_bert: 'qa_ctx' # the bert model to encode the ctx that is used to construct the dynamic retrieval index - query_bert: 'qa_question' # the bert model to encode the query str - output_filename: 'dynamic_db' # the filename of serialized dynamic retrieval service, used for both Faiss index and data storage + combo_service: + service_ip: '0.0.0.0' + service_port: 17181 server: False # whether launch the API server port: 5555 # the port number for the inference server web_server: False # whether launch the web inference server diff --git a/nemo/collections/nlp/modules/common/retro_inference_strategies.py b/nemo/collections/nlp/modules/common/retro_inference_strategies.py index fa350ecdbc97..0199cc5038e5 100644 --- a/nemo/collections/nlp/modules/common/retro_inference_strategies.py +++ b/nemo/collections/nlp/modules/common/retro_inference_strategies.py @@ -21,12 +21,7 @@ import torch.distributed as dist from nemo.collections.nlp.modules.common.lm_utils import pad_batch -from nemo.collections.nlp.modules.common.megatron.retrieval_services.bert_service import start_sentence_bert_server -from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import ( - ComboRetrievalService, - DynamicFaissRetrievalService, - FaissRetrievalService, -) +from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import ComboRetrievalService from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy @@ -37,30 +32,14 @@ def __init__(self, model, **args): self.frequent_query = args['frequent_query'] self.pad_token_for_retrieval = args['pad_tokens'] self.store_retrieved = args['store_retrieved'] - weights = args['weights'] self.store = dist.FileStore('/tmp/filestore_eval', -1) self.store.set('neighbors', str(args['neighbors'])) self.megatron_lm_compatible = args['megatron_lm_compatible'] - # start the sentence bert server - for name in args['sentence_bert']: - conf = args['sentence_bert'][name] - start_sentence_bert_server(tokenizer=self.model.tokenizer, name=name, **conf) - services = [] - for service_conf in args['services']: - if service_conf['type'] == 'FaissRetrievalService': - del service_conf['type'] - service = FaissRetrievalService(tokenizer=self.model.tokenizer, **service_conf) - 
services.append(service) - elif service_conf['type'] == 'DynamicFaissRetrievalService': - del service_conf['type'] - service = DynamicFaissRetrievalService(tokenizer=self.model.tokenizer, **service_conf) - services.append(service) - else: - raise ValueError(f'no such service {service_conf["type"]} implemented') - self.service = ComboRetrievalService(retrieval_services=services, weights=weights, store=self.store) + combo_cfg = args['combo_service'] + self.service = ComboRetrievalService(tokenizer=self.model.tokenizer, service_ip=combo_cfg['service_ip'], service_port=combo_cfg['service_port']) self.retrieved = [] self.retrieved_text = [] - self.chunk_size = self.service.chunk_size + self.chunk_size = self.model.cfg.chunk_size def update_neighbors(self, neighbors): # dynamically change the number of neighbors during the query From 43268424bb91683d7c7fffe424d229b018cc93dd Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Thu, 2 Feb 2023 02:22:27 +0000 Subject: [PATCH 10/17] seperate text generation server Signed-off-by: Yi Dong --- .../common/retro_inference_strategies.py | 4 - .../modules/common/text_generation_server.py | 11 -- .../conf/retro_text_generation_server.yaml | 24 ++++ .../start_retro_model_service.py | 113 ++++++++++++++++++ 4 files changed, 137 insertions(+), 15 deletions(-) create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/conf/retro_text_generation_server.yaml create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py diff --git a/nemo/collections/nlp/modules/common/retro_inference_strategies.py b/nemo/collections/nlp/modules/common/retro_inference_strategies.py index 0199cc5038e5..a06b8ef43e39 100644 --- a/nemo/collections/nlp/modules/common/retro_inference_strategies.py +++ b/nemo/collections/nlp/modules/common/retro_inference_strategies.py @@ -49,10 +49,6 @@ def update_neighbors(self, neighbors): def neighbors(self): return int(self.store.get('neighbors')) - def update_weights(self, weights): - # dynamically change the weights between different retrieval services - self.service.update_weights(weights) - def tokenize_batch(self, sentences, max_len, add_BOS): """ convert the sentences into lists of tokens, pad them to the same length, add bos tokens if it is needed diff --git a/nemo/collections/nlp/modules/common/text_generation_server.py b/nemo/collections/nlp/modules/common/text_generation_server.py index 530042fa595b..4bb50f1ab527 100644 --- a/nemo/collections/nlp/modules/common/text_generation_server.py +++ b/nemo/collections/nlp/modules/common/text_generation_server.py @@ -42,7 +42,6 @@ "top_k", "top_p", "neighbors", - "weights", "repetition_penalty", "min_tokens_to_generate", ] @@ -157,14 +156,6 @@ def put(self): if neighbors < 0: return "num of neighbors must be an integer no less than 0" - weights = None - if "weights" in request.get_json(): - weights = request.get_json()["weights"] - if not (type(weights) == int or type(weights) == float): - return "weights must be a positive number less than or equal to 1.0" - if not (0.0 <= weights <= 1.0): - return "weights must be a positive number less than or equal to 1.0" - with lock: # Need to get lock to keep multiple threads from hitting code MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate extra = {} @@ -178,8 +169,6 @@ def put(self): ): if neighbors is not None: self.inference_strategy.update_neighbors(neighbors) - if weights is not None: - self.inference_strategy.update_weights([weights, 1 - weights]) output = generate( self.model, diff 
--git a/scripts/nlp_language_modeling/service_launch_scripts/conf/retro_text_generation_server.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/retro_text_generation_server.yaml new file mode 100644 index 000000000000..416bd8ae9433 --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/retro_text_generation_server.yaml @@ -0,0 +1,24 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: 16 # 16, 32, or bf16 + +# tensor_model_parallel_size: 1 +# pipeline_model_parallel_size: 1 +# pipeline_model_parallel_split_rank: 0 # used for encoder and decoder model +retro_model_file: null # RETRO nemo file path + + +########### Faiss service parameters ######## +retrieval_service: + strategy: RetroModelTextGenerationStrategy # choose customized inference strategy + neighbors: 4 + frequent_query: False # for the current token generation, frequently update the retrieval context. If false, update it every 64 tokens + pad_tokens: True # pad the tokens at the beginning to make it minimum of 64 tokens for retrieving at least once + store_retrieved: False # whether store the retrieved documents, so it can be checked + combo_service: + service_ip: '0.0.0.0' + service_port: 17181 +port: 5555 # the port number for the inference server \ No newline at end of file
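The frequent_query and pad_tokens comments in this config refer to RETRO's retrieval cadence: with frequent_query disabled, retrieval happens once per completed 64-token chunk, so a prompt shorter than one chunk would never retrieve unless pad_tokens pads it up to 64. A small sketch of that arithmetic (the helper name is illustrative, not from the patch):

    CHUNK = 64  # RETRO chunk size assumed by the comments above

    def chunk_retrievals(num_tokens: int) -> int:
        """Retrieval calls with frequent_query=False: one per completed 64-token chunk."""
        return num_tokens // CHUNK

    assert chunk_retrievals(63) == 0    # why pad_tokens=True pads short prompts to 64
    assert chunk_retrievals(130) == 2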
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py new file mode 100644 index 000000000000..43ed44b62490 --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import threading + +import torch +from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer +from torch.utils.data import DataLoader + +from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer +from nemo.collections.nlp.modules.common.text_generation_utils import generate +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.core.config import hydra_runner + +try: + from apex.transformer import parallel_state + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +""" +This is the script to launch RETRO Model text generation server. + +Usage: + Assume the model has TP=1, PP=1 + launch the text generation server from a RETRO nemo file: + python start_retro_model_service.py \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.accelerator=gpu \ + trainer.precision=16 \ + retro_model_file=path_to_retro_nemo_file \ + retrieval_service.neighbors=4 \ + retrieval_service.combo_service.service_ip=0.0.0.0 \ + retrieval_service.combo_service.service_port=17181 \ + port=5555 +""" + + +@hydra_runner(config_path="conf", config_name="retro_text_generation_server") +def main(cfg) -> None: + trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer) + + model_path = cfg.retro_model_file + + save_restore_connector = NLPSaveRestoreConnector() + + if os.path.isdir(model_path): + save_restore_connector.model_extracted_dir = model_path + + model_cfg = MegatronRetrievalModel.restore_from( + model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, + ) + + with open_dict(model_cfg): + model_cfg.precision = trainer.precision + model_cfg.sequence_parallel = False + model_cfg.activations_checkpoint_granularity = None + model_cfg.activations_checkpoint_method = None + + model = MegatronRetrievalModel.restore_from( + model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + ) + + # check whether the DDP is initialized + if parallel_state.is_unitialized(): + + def dummy(): + return + + if model.trainer.strategy.launcher is not None: + model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) + model.trainer.strategy.setup_environment() + + retrieval_service = OmegaConf.to_container(cfg.retrieval_service) + model.set_inference_config(None, retrieval_service) + + # running text generation, use inference server + if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: + server = MegatronServer(model.cuda(), inference_strategy=model.inference_strategy) + server.run("0.0.0.0", port=cfg.port) + + while True: + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + if choice[0].item() == 0: + generate(model.cuda(), strategy=model.inference_strategy) + + +if __name__ == '__main__': + main()
From 0375a67c198a4f2589e46aa809048782b8b3b59f Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Thu, 2 Feb 2023 03:03:36 +0000 Subject: [PATCH 11/17] added webserver Signed-off-by: Yi Dong --- .../megatron/retrieval_services/util.py | 16 ++ .../nlp/modules/common/megatron_web_server.py | 228 ++++++++---------- .../conf/retro_web_server.yaml | 10 + .../start_web_service.py | 29 +++ 4 files changed, 158 insertions(+), 125 deletions(-) create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/conf/retro_web_server.yaml create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py
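Once the model service above is running, the MegatronServer it starts accepts PUT requests on /generate. A possible client call, using the payload fields the server validates and the default port from retro_text_generation_server.yaml (the prompt text is illustrative):

    import json
    import requests

    payload = {
        "sentences": ["Retrieval-enhanced language models"],
        "tokens_to_generate": 32,
        "temperature": 1.0,
        "add_BOS": False,
        "top_k": 0,
        "top_p": 0.9,
        "greedy": True,
        "all_probs": False,
        "repetition_penalty": 1.2,
        "min_tokens_to_generate": 1,
        "neighbors": 4,  # forwarded to the retro inference strategy
    }
    resp = requests.put(
        "http://localhost:5555/generate",
        data=json.dumps(payload),
        headers={"Content-Type": "application/json"},
    )
    print(resp.json()["sentences"][0])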
diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py index 2e9400bb57d1..0dd607e19b31 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py @@ -28,3 +28,19 @@ def request_data(data, ip='localhost', port=None): resp = requests.put(f'http://{ip}:{port}/knn', data=json.dumps(data), headers=headers) return resp.json() + +def text_generation(data, ip='localhost', port=None): + resp = requests.put(f'http://{ip}:{port}/generate', data=json.dumps(data), headers=headers) + return resp.json() + +def convert_retrieved_to_md(retrieved): + output_str = '<table><tr><th>Query</th><th>Retrieved Doc</th></tr>' + for item in retrieved: + output_str += f'<tr><td rowspan="{len(item["neighbors"])}">{item["query"]}</td>' + for i, neighbor in enumerate(item['neighbors']): + if i == 0: + output_str += f"<td>{neighbor}</td></tr>" + else: + output_str += f"<tr><td>{neighbor}</td></tr>" + output_str += '</table>' + return output_str
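The new convert_retrieved_to_md helper renders each query and its retrieved neighbors as rows of an HTML table for the gradio demo. A quick sanity check once the patch is applied (the sample data is illustrative, and the assertions assume the table markup shown above):

    from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import convert_retrieved_to_md

    # One dict per query, as the web app passes in: the query string plus its
    # retrieved neighbor texts.
    retrieved = [
        {
            "query": "who wrote Hamlet?",
            "neighbors": ["Shakespeare wrote Hamlet.", "Hamlet is an Elizabethan tragedy."],
        }
    ]
    html = convert_retrieved_to_md(retrieved)
    assert html.startswith("<table>") and html.endswith("</table>")
    print(html)  # header row, then the query cell spanning one row per neighbor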
diff --git a/nemo/collections/nlp/modules/common/megatron_web_server.py index bd1d06107716..ecad78065e1e 100644 --- a/nemo/collections/nlp/modules/common/megatron_web_server.py +++ b/nemo/collections/nlp/modules/common/megatron_web_server.py @@ -12,30 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json - import gradio as gr -import requests - -from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import PORT_NUM_DYN - -PORT_NUM = 5555 -headers = {"Content-Type": "application/json"} - +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, text_generation, convert_retrieved_to_md +__all__ = ['RetroDemoWebApp', 'get_demo'] -def request_data(data, port_num=PORT_NUM): - resp = requests.put('http://localhost:{}/generate'.format(port_num), data=json.dumps(data), headers=headers) - output_json = resp.json() - return output_json +def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, port=5555): data = { "sentences": [prompt], "tokens_to_generate": int(token_to_gen), @@ -48,61 +32,10 @@ def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_ "repetition_penalty": repetition, "min_tokens_to_generate": int(min_tokens), } - sentences = request_data(data)['sentences'] + sentences = text_generation(data, port=port)['sentences'] return sentences[0] -def convert_retrieved_to_md(retrieved): - output_str = '<table><tr><th>Query</th><th>Retrieved Doc</th></tr>' - for item in retrieved: - output_str += f'<tr><td rowspan="{len(item["neighbors"])}">{item["query"]}</td>' - for i, neighbor in enumerate(item['neighbors']): - if i == 0: - output_str += f"<td>{neighbor}</td></tr>" - else: - output_str += f"<tr><td>{neighbor}</td></tr>" - output_str += '</table>
' - return output_str - - -def get_retro_generation( - prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, neighbors, weights -): - data = { - "sentences": [prompt], - "tokens_to_generate": int(token_to_gen), - "temperature": temp, - "add_BOS": add_BOS, - "top_k": top_k, - "top_p": top_p, - "greedy": greedy, - "all_probs": False, - "repetition_penalty": repetition, - "min_tokens_to_generate": int(min_tokens), - "neighbors": int(neighbors), - "weights": weights, - } - output_json = request_data(data) - sentences = output_json['sentences'] - retrieved = output_json['retrieved'] - return sentences[0], convert_retrieved_to_md(retrieved) - - -def add_doc(doc, add_eos): - data = { - "sentences": [doc], - "add_eos": add_eos, - } - return update_index(data) - - -def reset_index(): - data = {"reset": True} - resp = requests.put('http://localhost:{}/knn'.format(PORT_NUM_DYN), data=json.dumps(data), headers=headers) - output_json = resp.json() - return output_json - - def get_demo(share, username, password): with gr.Blocks() as demo: with gr.Row(): @@ -143,57 +76,102 @@ def get_demo(share, username, password): demo.launch(share=share, server_port=13570, server_name='0.0.0.0', auth=(username, password)) -def get_retro_demo(share, username, password): - with gr.Blocks(css="table, th, td { border: 1px solid blue; table-layout: fixed; width: 100%; }") as demo: - with gr.Row(): - with gr.Column(scale=2, width=200): - greedy_flag = gr.Checkbox(label="Greedy", value=True) - add_BOS = gr.Checkbox(label="Add BOS token", value=False) - token_to_gen = gr.Number(label='Number of Tokens to generate', value=30, type=int) - min_token_to_gen = gr.Number(label='Min number of Tokens to generate', value=1, type=int) - temperature = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, label='Temperature', step=0.1) - top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.02, value=0.9, label='Top P') - top_k = gr.Slider(minimum=0, maximum=10000, step=2, value=0, label='Top K') - repetition_penality = gr.Slider( - minimum=1.0, maximum=5.0, step=0.02, value=1.2, label='Repetition penalty' - ) - k_neighbors = gr.Slider(minimum=0, maximum=50, step=1, value=2, label='Retrieved Documents') - weights = gr.Slider( - minimum=0.0, maximum=1.0, value=1.0, label='Weight for the first Retrieval', step=0.02 - ) - add_retrival_doc = gr.Textbox(label="Add New Retrieval Doc", value="", lines=5,) - add_EOS = gr.Checkbox(label="Add EOS token to Retrieval Doc", value=False) - with gr.Row(): - add_btn = gr.Button(value="Add") - reset_btn = gr.Button(value="Reset Index") - output_status = gr.Label(value='') - add_btn.click(add_doc, inputs=[add_retrival_doc, add_EOS], outputs=[output_status]) - reset_btn.click(reset_index, inputs=[], outputs=[output_status]) - - with gr.Column(scale=1, min_width=800): - input_prompt = gr.Textbox( - label="Input", - value="Ariel was playing basketball. 1 of her shots went in the hoop. 2 of her shots did not go in the hoop. 
How many shots were there in total?", - lines=5, - ) - output_box = gr.Textbox(value="", label="Output") - btn = gr.Button(value="Submit") - output_retrieval = gr.HTML() - btn.click( - get_retro_generation, - inputs=[ - input_prompt, - greedy_flag, - add_BOS, - token_to_gen, - min_token_to_gen, - temperature, - top_p, - top_k, - repetition_penality, - k_neighbors, - weights, - ], - outputs=[output_box, output_retrieval], - ) - demo.launch(share=share, server_port=13570, server_name='0.0.0.0', auth=(username, password)) +class RetroDemoWebApp: + + def __init__(self, text_service_ip, text_service_port, combo_service_ip, combo_service_port): + self.text_service_ip = text_service_ip + self.text_service_port = text_service_port + self.combo_service_ip = combo_service_ip + self.combo_service_port = combo_service_port + + def get_retro_generation(self, prompt, greedy, add_BOS, token_to_gen, + min_tokens, temp, top_p, top_k, repetition, + neighbors, weight): + data = { + "sentences": [prompt], + "tokens_to_generate": int(token_to_gen), + "temperature": temp, + "add_BOS": add_BOS, + "top_k": top_k, + "top_p": top_p, + "greedy": greedy, + "all_probs": False, + "repetition_penalty": repetition, + "min_tokens_to_generate": int(min_tokens), + "neighbors": int(neighbors), + } + self.update_weight(weight) + output_json = text_generation(data, self.text_service_ip, self.text_service_port) + sentences = output_json['sentences'] + retrieved = output_json['retrieved'] + return sentences[0], convert_retrieved_to_md(retrieved) + + def update_weight(self, weight): + data = {"update_weight": [weight, 1.0 - weight]} + return request_data(data, self.combo_service_ip, self.combo_service_port) + + def add_doc(self, doc, add_eos): + data = { + "sentences": [doc], + "add_eos": add_eos, + } + return request_data(data, self.combo_service_ip, self.combo_service_port) + + def reset_index(self): + data = {"reset": None} + return request_data(data, self.combo_service_ip, self.combo_service_port) + + def run_demo(self, share, username, password, port): + with gr.Blocks(css="table, th, td { border: 1px solid blue; table-layout: fixed; width: 100%; }") as demo: + with gr.Row(): + with gr.Column(scale=2, width=200): + greedy_flag = gr.Checkbox(label="Greedy", value=True) + add_BOS = gr.Checkbox(label="Add BOS token", value=False) + token_to_gen = gr.Number(label='Number of Tokens to generate', value=30, type=int) + min_token_to_gen = gr.Number(label='Min number of Tokens to generate', value=1, type=int) + temperature = gr.Slider(minimum=0.0, maximum=10.0, value=1.0, label='Temperature', step=0.1) + top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.02, value=0.9, label='Top P') + top_k = gr.Slider(minimum=0, maximum=10000, step=2, value=0, label='Top K') + repetition_penality = gr.Slider( + minimum=1.0, maximum=5.0, step=0.02, value=1.2, label='Repetition penalty' + ) + k_neighbors = gr.Slider(minimum=0, maximum=50, step=1, value=2, label='Retrieved Documents') + weight = gr.Slider( + minimum=0.0, maximum=1.0, value=1.0, label='Weight for the Static Retrieval DB', step=0.02 + ) + add_retrival_doc = gr.Textbox(label="Add New Retrieval Doc", value="", lines=5,) + add_EOS = gr.Checkbox(label="Add EOS token to Retrieval Doc", value=False) + with gr.Row(): + add_btn = gr.Button(value="Add") + reset_btn = gr.Button(value="Reset Index") + output_status = gr.Label(value='') + add_btn.click(self.add_doc, inputs=[add_retrival_doc, add_EOS], outputs=[output_status]) + reset_btn.click(self.reset_index, inputs=[], outputs=[output_status]) + 
+
+                with gr.Column(scale=1, min_width=800):
+                    input_prompt = gr.Textbox(
+                        label="Input",
+                        value="Ariel was playing basketball. 1 of her shots went in the hoop. 2 of her shots did not go in the hoop. How many shots were there in total?",
+                        lines=5,
+                    )
+                    output_box = gr.Textbox(value="", label="Output")
+                    btn = gr.Button(value="Submit")
+                    output_retrieval = gr.HTML()
+                    btn.click(
+                        self.get_retro_generation,
+                        inputs=[
+                            input_prompt,
+                            greedy_flag,
+                            add_BOS,
+                            token_to_gen,
+                            min_token_to_gen,
+                            temperature,
+                            top_p,
+                            top_k,
+                            repetition_penalty,
+                            k_neighbors,
+                            weight,
+                        ],
+                        outputs=[output_box, output_retrieval],
+                    )
+        demo.launch(share=share, server_port=port, server_name='0.0.0.0', auth=(username, password))
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/conf/retro_web_server.yaml b/scripts/nlp_language_modeling/service_launch_scripts/conf/retro_web_server.yaml
new file mode 100644
index 000000000000..28599b616055
--- /dev/null
+++ b/scripts/nlp_language_modeling/service_launch_scripts/conf/retro_web_server.yaml
@@ -0,0 +1,10 @@
+
+text_service_ip: '0.0.0.0'
+text_service_port: 5555
+combo_service_ip: '0.0.0.0'
+combo_service_port: 17181
+
+share: False # whether to create a public URL
+username: test # username for the web client
+password: test2 # password for the web client
+port: 7777
\ No newline at end of file
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py
new file mode 100644
index 000000000000..8ba50931b90d
--- /dev/null
+++ b/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Launch the RETRO demo web server."""
+
+from nemo.core.config import hydra_runner
+from nemo.collections.nlp.modules.common.megatron_web_server import RetroDemoWebApp
+
+
+@hydra_runner(config_path="conf", config_name="retro_web_server")
+def main(cfg) -> None:
+
+    demo = RetroDemoWebApp(cfg.text_service_ip, cfg.text_service_port, cfg.combo_service_ip, cfg.combo_service_port)
+    demo.run_demo(cfg.share, cfg.username, cfg.password, cfg.port)
+
+
+if __name__ == "__main__":
+    main()  # noqa pylint: disable=no-value-for-parameter
\ No newline at end of file
From b084db85afa1b204768aa47bfef1a302d0366514 Mon Sep 17 00:00:00 2001
From: Yi Dong
Date: Thu, 2 Feb 2023 03:41:53 +0000
Subject: [PATCH 12/17] clean up and fix zero neighbor issue

Signed-off-by: Yi Dong
---
 .../conf/megatron_retro_inference.yaml        |  8 +--
 .../language_modeling/megatron_retro_eval.py  | 55 ++++++-------------
 .../common/retro_inference_strategies.py      | 16 ++++--
 .../start_retro_model_service.py              |  5 --
 4 files changed, 27 insertions(+), 57 deletions(-)

diff --git a/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml b/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml
index 3160809d54d8..86f019d8792e 100644
--- a/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_retro_inference.yaml
@@ -41,10 +41,4 @@ retrieval_service:
   store_retrieved: False # whether store the retrieved documents, so it can be checked
   combo_service:
     service_ip: '0.0.0.0'
-    service_port: 17181
-server: False # whether launch the API server
-port: 5555 # the port number for the inference server
-web_server: False # whether launch the web inference server
-share: False # whether create a public URL
-username: test # user name for web client
-password: test2 # password for web client
\ No newline at end of file
+    service_port: 17181
\ No newline at end of file
diff --git a/examples/nlp/language_modeling/megatron_retro_eval.py b/examples/nlp/language_modeling/megatron_retro_eval.py
index 2f1be7537b3a..d4a051b25e7b 100644
--- a/examples/nlp/language_modeling/megatron_retro_eval.py
+++ b/examples/nlp/language_modeling/megatron_retro_eval.py
@@ -13,18 +13,13 @@
 # limitations under the License.
import os -import threading -import torch from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel -from nemo.collections.nlp.modules.common.megatron_web_server import get_retro_demo -from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer -from nemo.collections.nlp.modules.common.text_generation_utils import generate from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector from nemo.core.config import hydra_runner @@ -114,41 +109,23 @@ def dummy(): retrieval_service = OmegaConf.to_container(cfg.retrieval_service) model.set_inference_config(config, retrieval_service) - # running text generation, use inference server - if cfg.server: - if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: - if cfg.web_server: - thread = threading.Thread( - target=get_retro_demo, daemon=True, args=(cfg.share, cfg.username, cfg.password) - ) - thread.start() - server = MegatronServer(model.cuda(), inference_strategy=model.inference_strategy) - server.run("0.0.0.0", port=cfg.port) - - while True: - choice = torch.cuda.LongTensor(1) - torch.distributed.broadcast(choice, 0) - if choice[0].item() == 0: - generate(model.cuda(), strategy=model.inference_strategy) + if not cfg.use_predict_method: + # First method of running text generation, call model.generate method + response = model.generate( + inputs=OmegaConf.to_container(cfg.prompts), + length_params=length_params, + sampling_params=sampling_params, + strategy=model.inference_strategy, + ) else: - - if not cfg.use_predict_method: - # First method of running text generation, call model.generate method - response = model.generate( - inputs=OmegaConf.to_container(cfg.prompts), - length_params=length_params, - sampling_params=sampling_params, - strategy=model.inference_strategy, - ) - else: - # Second method of running text generation, call trainer.predict - ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) - request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size) - response = trainer.predict(model, request_dl) - - print("***************************") - print(response) - print("***************************") + # Second method of running text generation, call trainer.predict + ds = RequestDataSet(OmegaConf.to_container(cfg.prompts)) + request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size) + response = trainer.predict(model, request_dl) + + print("***************************") + print(response) + print("***************************") if __name__ == '__main__': diff --git a/nemo/collections/nlp/modules/common/retro_inference_strategies.py b/nemo/collections/nlp/modules/common/retro_inference_strategies.py index a06b8ef43e39..1060fdc0d23c 100644 --- a/nemo/collections/nlp/modules/common/retro_inference_strategies.py +++ b/nemo/collections/nlp/modules/common/retro_inference_strategies.py @@ -202,9 +202,12 @@ def prepare_batch_at_step( retrieved_mask = torch.ones_like(retrieved, dtype=torch.bool) else: retrieved_mask = retrieved != tokenizer.pad_id - if len(retrieved) == 0: - retrieved = torch.tensor([-1] * micro_batch_size) - retrieved_mask = torch.tensor([-1] * 
micro_batch_size) + if retrieved.numel() == 0: + # add empty retrieved + retrieved = torch.tensor(self.service.get_knn(['a'], 0), device=torch.cuda.current_device()).unsqueeze(0).repeat(1, len(self.retrieved), 1, 1) + retrieved_mask = retrieved != tokenizer.pad_id + # retrieved = torch.tensor([-1] * micro_batch_size) + # retrieved_mask = torch.tensor([-1] * micro_batch_size) """Prepare batch for each of the inference steps""" # attention_mask_repeat = torch.concat([self.attention_mask for _ in range(micro_batch_size)]) @@ -343,9 +346,10 @@ def prepare_batch_at_step( retrieved_mask = torch.ones_like(retrieved, dtype=torch.bool) else: retrieved_mask = retrieved != tokenizer.pad_id - if len(retrieved) == 0: - retrieved = torch.tensor([-1] * micro_batch_size) - retrieved_mask = torch.tensor([-1] * micro_batch_size) + if retrieved.numel() == 0: + # add empty retrieved + retrieved = torch.tensor(self.service.get_knn(['a'], 0), device=torch.cuda.current_device()).unsqueeze(0).repeat(1, len(self.retrieved), 1, 1) + retrieved_mask = retrieved != tokenizer.pad_id """Prepare batch for each of the inference steps""" # attention_mask_repeat = torch.concat([self.attention_mask for _ in range(micro_batch_size)]) diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py index 43ed44b62490..6956028c0ab7 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py @@ -13,19 +13,14 @@ # limitations under the License. import os -import threading import torch -from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer -from torch.utils.data import DataLoader from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel -from nemo.collections.nlp.modules.common.megatron_web_server import get_retro_demo from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer from nemo.collections.nlp.modules.common.text_generation_utils import generate -from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector from nemo.core.config import hydra_runner From a7f63b7a2708113efe73eb292a0b1b133be33652 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Thu, 2 Feb 2023 21:28:20 +0000 Subject: [PATCH 13/17] fix the style Signed-off-by: Yi Dong --- .../combo_retrieval_server.py | 35 ++++++++----------- .../dynamic_retrieval_server.py | 25 +++++++++---- .../retrieval_services/retrieval_service.py | 17 +++------ .../static_retrieval_server.py | 6 ++-- .../megatron/retrieval_services/util.py | 3 +- .../nlp/modules/common/megatron_web_server.py | 13 ++++--- .../common/retro_inference_strategies.py | 16 +++++++-- .../start_bert_service.py | 17 ++++----- .../start_combo_retrieval_service.py | 7 ++-- .../start_dynamic_retrieval_service.py | 6 ++-- .../start_static_retrieval_service.py | 2 +- .../start_web_service.py | 4 +-- 12 files changed, 84 insertions(+), 67 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py index 09e44d85d5b7..6e6e2d045e4f 
100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py @@ -23,13 +23,17 @@ import numpy as np import torch from flask import Flask, jsonify, request -from flask_restful import Api -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import DynamicFaissRetrievalService, FaissRetrievalService -from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import FaissRetrievalResource -from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock from flask_restful import Api, Resource +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.nlp.modules.common.megatron.retrieval_services.retrieval_service import ( + DynamicFaissRetrievalService, + FaissRetrievalService, +) +from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import ( + FaissRetrievalResource, +) +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import lock, request_data weights = None @@ -156,9 +160,7 @@ class ComboRetrievalServer(object): """ def __init__( - self, - tokenizer: TokenizerSpec, - services_cfg: list, + self, tokenizer: TokenizerSpec, services_cfg: list, ): self.app = Flask(__name__, static_url_path='') services = [] @@ -167,14 +169,12 @@ def __init__( weights.append(service_cfg.weight) if service_cfg.type == 'FaissRetrievalService': service = FaissRetrievalService( - tokenizer=tokenizer, - service_ip=service_cfg.service_ip, - service_port=service_cfg.service_port) + tokenizer=tokenizer, service_ip=service_cfg.service_ip, service_port=service_cfg.service_port + ) elif service_cfg.type == 'DynamicFaissRetrievalService': service = DynamicFaissRetrievalService( - tokenizer=tokenizer, - service_ip=service_cfg.service_ip, - service_port=service_cfg.service_port) + tokenizer=tokenizer, service_ip=service_cfg.service_ip, service_port=service_cfg.service_port + ) else: raise ValueError(f'Unsupported retrieval service {service_cfg.type}') services.append(service) @@ -183,12 +183,7 @@ def __init__( api = Api(self.app) api.add_resource( - ComboRetrievalResource, - '/knn', - resource_class_args=[ - services, - self.weight_container, - ], + ComboRetrievalResource, '/knn', resource_class_args=[services, self.weight_container,], ) def run(self, url, port=None): diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py index 407ceb9ab2be..6b4f98b48d1a 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/dynamic_retrieval_server.py @@ -17,18 +17,20 @@ import pickle import threading import time -from typing import List from collections import namedtuple +from typing import List import faiss import numpy as np import torch from flask import Flask, jsonify, request from flask_restful import Api -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import FaissRetrievalResource -from 
nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import ( + FaissRetrievalResource, +) +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import lock, request_data # define this type to mimic the indexed dataset DType = namedtuple('DType', ['dtype']) @@ -67,7 +69,19 @@ class DynamicRetrievalResource(FaissRetrievalResource): The PUT method is to get KNN tokens, add new chunks, reset index. """ - def __init__(self, index, tokenizer, chunk_size, stride, store, ctx_bert_ip, ctx_bert_port, query_bert_ip, query_bert_port, output_filename): + def __init__( + self, + index, + tokenizer, + chunk_size, + stride, + store, + ctx_bert_ip, + ctx_bert_port, + query_bert_ip, + query_bert_port, + output_filename, + ): super().__init__(index, tokenizer, store, query_bert_ip, query_bert_port) self.chunk_size = chunk_size self.stride = stride @@ -214,4 +228,3 @@ def __init__( def run(self, url, port=None): threading.Thread(target=lambda: self.app.run(host=url, threaded=True, port=port)).start() - diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index b01f331b45a4..150683f03edd 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -65,10 +65,7 @@ class FaissRetrievalService(RetrievalService): """ def __init__( - self, - tokenizer: TokenizerSpec, - service_ip: str = None, - service_port: int = None, + self, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, ): self.updatable = False self.tokenizer = tokenizer @@ -102,11 +99,9 @@ class DynamicFaissRetrievalService(FaissRetrievalService): It implements the retrieval services interface, has a simple client to add, reset and query the dynamic retrieval index. """ + def __init__( - self, - tokenizer: TokenizerSpec, - service_ip: str = None, - service_port: int = None, + self, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, ): super().__init__(tokenizer=tokenizer, service_ip=service_ip, service_port=service_port) self.updatable = True @@ -151,11 +146,9 @@ class ComboRetrievalService(DynamicFaissRetrievalService): Combo retrieval service client class. 
It implements the retrieval services interface, has a simple client to add, reset, query, update weights """ + def __init__( - self, - tokenizer: TokenizerSpec, - service_ip: str = None, - service_port: int = None, + self, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, ): super().__init__(tokenizer=tokenizer, service_ip=service_ip, service_port=service_port) diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py index 276a6f3e796a..b8c443989fbe 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/static_retrieval_server.py @@ -27,7 +27,7 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.nlp.data.language_modeling.megatron.indexed_retrieval_dataset import MMapRetrievalIndexedDataset -from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, lock +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import lock, request_data class FaissRetrievalResource(Resource): @@ -133,7 +133,9 @@ def __init__( self.ds = MMapRetrievalIndexedDataset(retrieval_index) api = Api(self.app) api.add_resource( - FaissRetrievalResource, '/knn', resource_class_args=[self.index, self.tokenizer, self.ds, query_bert_ip, query_bert_port], + FaissRetrievalResource, + '/knn', + resource_class_args=[self.index, self.tokenizer, self.ds, query_bert_ip, query_bert_port], ) def run(self, url, port=None): diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py index 0dd607e19b31..e88c8df0cbbc 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/util.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import requests import json import threading +import requests headers = {"Content-Type": "application/json"} @@ -33,6 +33,7 @@ def text_generation(data, ip='localhost', port=None): resp = requests.put(f'http://{ip}:{port}/generate', data=json.dumps(data), headers=headers) return resp.json() + def convert_retrieved_to_md(retrieved): output_str = '' for item in retrieved: diff --git a/nemo/collections/nlp/modules/common/megatron_web_server.py b/nemo/collections/nlp/modules/common/megatron_web_server.py index ecad78065e1e..d63cbb744d8b 100644 --- a/nemo/collections/nlp/modules/common/megatron_web_server.py +++ b/nemo/collections/nlp/modules/common/megatron_web_server.py @@ -14,7 +14,11 @@ import gradio as gr -from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import request_data, text_generation, convert_retrieved_to_md +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import ( + convert_retrieved_to_md, + request_data, + text_generation, +) __all__ = ['RetroDemoWebApp', 'get_demo'] @@ -77,16 +81,15 @@ def get_demo(share, username, password): class RetroDemoWebApp: - def __init__(self, text_service_ip, text_service_port, combo_service_ip, combo_service_port): self.text_service_ip = text_service_ip self.text_service_port = text_service_port self.combo_service_ip = combo_service_ip self.combo_service_port = combo_service_port - def get_retro_generation(self, prompt, greedy, add_BOS, token_to_gen, - min_tokens, temp, top_p, top_k, repetition, - neighbors, weight): + def get_retro_generation( + self, prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_p, top_k, repetition, neighbors, weight + ): data = { "sentences": [prompt], "tokens_to_generate": int(token_to_gen), diff --git a/nemo/collections/nlp/modules/common/retro_inference_strategies.py b/nemo/collections/nlp/modules/common/retro_inference_strategies.py index 1060fdc0d23c..2c267fe06e64 100644 --- a/nemo/collections/nlp/modules/common/retro_inference_strategies.py +++ b/nemo/collections/nlp/modules/common/retro_inference_strategies.py @@ -36,7 +36,9 @@ def __init__(self, model, **args): self.store.set('neighbors', str(args['neighbors'])) self.megatron_lm_compatible = args['megatron_lm_compatible'] combo_cfg = args['combo_service'] - self.service = ComboRetrievalService(tokenizer=self.model.tokenizer, service_ip=combo_cfg['service_ip'], service_port=combo_cfg['service_port']) + self.service = ComboRetrievalService( + tokenizer=self.model.tokenizer, service_ip=combo_cfg['service_ip'], service_port=combo_cfg['service_port'] + ) self.retrieved = [] self.retrieved_text = [] self.chunk_size = self.model.cfg.chunk_size @@ -204,7 +206,11 @@ def prepare_batch_at_step( retrieved_mask = retrieved != tokenizer.pad_id if retrieved.numel() == 0: # add empty retrieved - retrieved = torch.tensor(self.service.get_knn(['a'], 0), device=torch.cuda.current_device()).unsqueeze(0).repeat(1, len(self.retrieved), 1, 1) + retrieved = ( + torch.tensor(self.service.get_knn(['a'], 0), device=torch.cuda.current_device()) + .unsqueeze(0) + .repeat(1, len(self.retrieved), 1, 1) + ) retrieved_mask = retrieved != tokenizer.pad_id # retrieved = torch.tensor([-1] * micro_batch_size) # retrieved_mask = torch.tensor([-1] * micro_batch_size) @@ -348,7 +354,11 @@ def prepare_batch_at_step( retrieved_mask = retrieved != tokenizer.pad_id if retrieved.numel() == 0: # add empty retrieved - retrieved = torch.tensor(self.service.get_knn(['a'], 0), device=torch.cuda.current_device()).unsqueeze(0).repeat(1, len(self.retrieved), 
1, 1) + retrieved = ( + torch.tensor(self.service.get_knn(['a'], 0), device=torch.cuda.current_device()) + .unsqueeze(0) + .repeat(1, len(self.retrieved), 1, 1) + ) retrieved_mask = retrieved != tokenizer.pad_id """Prepare batch for each of the inference steps""" diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py index 43a0f68c432b..e65c8fdcbddc 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py @@ -38,14 +38,15 @@ def get_tokenizer(args): @hydra_runner(config_path="conf", config_name="bert_service") def main(cfg) -> None: tokenizer = get_tokenizer(cfg.tokenizer) - start_sentence_bert_server(cfg.name, - cfg.sentence_bert.devices, - tokenizer, - cfg.sentence_bert.sentence_bert, - cfg.sentence_bert.sentence_bert_batch, - port=cfg.sentence_bert.port - ) + start_sentence_bert_server( + cfg.name, + cfg.sentence_bert.devices, + tokenizer, + cfg.sentence_bert.sentence_bert, + cfg.sentence_bert.sentence_bert_batch, + port=cfg.sentence_bert.port, + ) if __name__ == "__main__": - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py index 02cf21d678b1..87e2d0dbda10 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py @@ -39,12 +39,9 @@ def get_tokenizer(args): def main(cfg) -> None: tokenizer = get_tokenizer(cfg.tokenizer) - server = ComboRetrievalServer( - tokenizer, - cfg.service.child_services - ) + server = ComboRetrievalServer(tokenizer, cfg.service.child_services) server.run("0.0.0.0", cfg.service.port) if __name__ == "__main__": - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py index 60aaecbc35c8..ca2d2c9861d3 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py @@ -14,7 +14,9 @@ """ """ -from nemo.collections.nlp.modules.common.megatron.retrieval_services.dynamic_retrieval_server import DynamicRetrievalServer +from nemo.collections.nlp.modules.common.megatron.retrieval_services.dynamic_retrieval_server import ( + DynamicRetrievalServer, +) from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.core.config import hydra_runner @@ -56,4 +58,4 @@ def main(cfg) -> None: if __name__ == "__main__": - main() # noqa pylint: disable=no-value-for-parameter \ No newline at end of file + main() # noqa pylint: disable=no-value-for-parameter diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py index 936ebebbdc2f..c230efda4d40 100644 --- 
a/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py
+++ b/scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py
@@ -51,4 +51,4 @@ def main(cfg) -> None:
 
 
 if __name__ == "__main__":
-    main()  # noqa pylint: disable=no-value-for-parameter
\ No newline at end of file
+    main()  # noqa pylint: disable=no-value-for-parameter
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py
index 8ba50931b90d..77453beb417f 100644
--- a/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py
+++ b/scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py
@@ -14,8 +14,8 @@
 
 """Launch the RETRO demo web server."""
 
-from nemo.core.config import hydra_runner
 from nemo.collections.nlp.modules.common.megatron_web_server import RetroDemoWebApp
+from nemo.core.config import hydra_runner
 
 
 @hydra_runner(config_path="conf", config_name="retro_web_server")
@@ -26,4 +26,4 @@ def main(cfg) -> None:
 
 
 if __name__ == "__main__":
-    main()  # noqa pylint: disable=no-value-for-parameter
\ No newline at end of file
+    main()  # noqa pylint: disable=no-value-for-parameter
From a7f63b7a2708113efe73eb292a0b1b133be33652 Mon Sep 17 00:00:00 2001
From: Yi Dong
Date: Thu, 2 Feb 2023 21:37:21 +0000
Subject: [PATCH 14/17] add license

Signed-off-by: Yi Dong
---
 .../common/megatron/retrieval_services/__init__.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py
index e69de29bb2d1..4fc50543f1d2 100644
--- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py
+++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
From bbecfad3142dfacd953f655ad0a1aa6e798f49a3 Mon Sep 17 00:00:00 2001
From: Yi Dong
Date: Thu, 9 Feb 2023 19:12:27 +0000
Subject: [PATCH 15/17] fixed CodeQL issues

Signed-off-by: Yi Dong
---
 .../megatron/retrieval_services/combo_retrieval_server.py | 8 +-------
 .../megatron/retrieval_services/retrieval_service.py      | 7 +++----
 2 files changed, 4 insertions(+), 11 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py
index 6e6e2d045e4f..a58d77e5699f 100644
--- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py
+++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/combo_retrieval_server.py
@@ -12,11 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import base64 -import logging import pickle import threading -import time from typing import List, Union import faiss @@ -30,10 +27,7 @@ DynamicFaissRetrievalService, FaissRetrievalService, ) -from nemo.collections.nlp.modules.common.megatron.retrieval_services.static_retrieval_server import ( - FaissRetrievalResource, -) -from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import lock, request_data +from nemo.collections.nlp.modules.common.megatron.retrieval_services.util import lock weights = None diff --git a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py index 150683f03edd..fde17644260a 100644 --- a/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py +++ b/nemo/collections/nlp/modules/common/megatron/retrieval_services/retrieval_service.py @@ -65,9 +65,9 @@ class FaissRetrievalService(RetrievalService): """ def __init__( - self, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, + self, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, updatable: bool = False, ): - self.updatable = False + self.updatable = updatable self.tokenizer = tokenizer self.service_ip = service_ip self.service_port = service_port @@ -103,8 +103,7 @@ class DynamicFaissRetrievalService(FaissRetrievalService): def __init__( self, tokenizer: TokenizerSpec, service_ip: str = None, service_port: int = None, ): - super().__init__(tokenizer=tokenizer, service_ip=service_ip, service_port=service_port) - self.updatable = True + super().__init__(tokenizer=tokenizer, service_ip=service_ip, service_port=service_port, updatable=True) def add_docs_to_index(self, query: List[str], add_eos: bool = True): """ From 60d88ae1ee19771671597677c877388627284995 Mon Sep 17 00:00:00 2001 From: Yi Dong Date: Fri, 10 Feb 2023 04:09:18 +0000 Subject: [PATCH 16/17] added bash script to launch the demo Signed-off-by: Yi Dong --- .../service_launch_scripts/env_variables.sh | 33 ++++++ .../service_launch_scripts/launch_demo.sh | 108 ++++++++++++++++++ 2 files changed, 141 insertions(+) create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/env_variables.sh create mode 100644 scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh diff --git a/scripts/nlp_language_modeling/service_launch_scripts/env_variables.sh b/scripts/nlp_language_modeling/service_launch_scripts/env_variables.sh new file mode 100644 index 000000000000..390e1042021d --- /dev/null +++ b/scripts/nlp_language_modeling/service_launch_scripts/env_variables.sh @@ -0,0 +1,33 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+MERGE_FILE=
+VOCAB_FILE=
+
+BERT_DEVICES=\'0,1,2,3\'
+BERT_PORT=17190
+CONTEXT_BERT_PORT=17191
+QUERY_BERT_PORT=17192
+
+STATIC_FAISS_INDEX=
+STATIC_RETRIEVAL_DB=
+STATIC_RETRIEVAL_PORT=17179
+
+DYNAMIC_RETRIEVAL_PORT=17180
+COMBO_RETRIEVAL_PORT=17181
+
+RETRO_MODEL_PATH=
+RETRO_MODEL_PORT=5555
+WEB_PORT=7777
+PASSWORD=test2
\ No newline at end of file
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh b/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh
new file mode 100644
index 000000000000..6da7ca98d8f4
--- /dev/null
+++ b/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh
@@ -0,0 +1,108 @@
+#!/bin/bash
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+depends_on () {
+    HOST=$1
+    PORT=$2
+    STATUS=$(curl -X PUT http://$HOST:$PORT >/dev/null 2>/dev/null; echo $?)
+    while [ $STATUS -ne 0 ]
+    do
+        echo "waiting for server ($HOST:$PORT) to be up"
+        sleep 10
+        STATUS=$(curl -X PUT http://$HOST:$PORT >/dev/null 2>/dev/null; echo $?)
+    done
+    echo "server ($HOST:$PORT) is up and running"
+}
+
+load_variables() {
+    PYTHONUNBUFFERED=TRUE
+    full_path=$(realpath $0)
+    dir_path=$(dirname $full_path)
+    source $dir_path/env_variables.sh
+}
+
+# load the environment variables
+load_variables
+
+
+# launch bert model service
+python scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py \
+    tokenizer.merge_file=$MERGE_FILE \
+    tokenizer.vocab_file=$VOCAB_FILE \
+    sentence_bert.sentence_bert=all-mpnet-base-v2 \
+    sentence_bert.devices=$BERT_DEVICES \
+    sentence_bert.port=${BERT_PORT} &
+
+
+depends_on "0.0.0.0" ${BERT_PORT}
+
+# launch static retrieval service
+python scripts/nlp_language_modeling/service_launch_scripts/start_static_retrieval_service.py \
+    tokenizer.merge_file=$MERGE_FILE \
+    tokenizer.vocab_file=$VOCAB_FILE \
+    service.faiss_devices=null \
+    service.faiss_index=$STATIC_FAISS_INDEX \
+    service.retrieval_index=$STATIC_RETRIEVAL_DB \
+    service.query_bert_port=${BERT_PORT} \
+    service.port=${STATIC_RETRIEVAL_PORT} &
+
+# launch dynamic retrieval service
+python scripts/nlp_language_modeling/service_launch_scripts/start_dynamic_retrieval_service.py \
+    tokenizer.merge_file=$MERGE_FILE \
+    tokenizer.vocab_file=$VOCAB_FILE \
+    service.faiss_devices=null \
+    service.ctx_bert_port=${BERT_PORT} \
+    service.query_bert_port=${BERT_PORT} \
+    service.port=${DYNAMIC_RETRIEVAL_PORT} &
+
+depends_on "0.0.0.0" ${STATIC_RETRIEVAL_PORT}
+depends_on "0.0.0.0" ${DYNAMIC_RETRIEVAL_PORT}
+
+# launch combo service
+python scripts/nlp_language_modeling/service_launch_scripts/start_combo_retrieval_service.py \
+    tokenizer.merge_file=$MERGE_FILE \
+    tokenizer.vocab_file=$VOCAB_FILE \
+    service.child_services.0.service_port=${STATIC_RETRIEVAL_PORT} \
+    service.child_services.1.service_port=${DYNAMIC_RETRIEVAL_PORT} \
+    service.port=${COMBO_RETRIEVAL_PORT} &
+
+depends_on "0.0.0.0" ${COMBO_RETRIEVAL_PORT}
+
+# launch text generation server
+python scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py \
+    trainer.devices=1 \
+    trainer.num_nodes=1 \
+    trainer.accelerator=gpu \
+    trainer.precision=16 \
+    retro_model_file=$RETRO_MODEL_PATH \
+    retrieval_service.strategy=RetroModelTextGenerationStrategy \
+    retrieval_service.neighbors=2 \
+    retrieval_service.pad_tokens=True \
+    retrieval_service.store_retrieved=True \
+    retrieval_service.combo_service.service_port=${COMBO_RETRIEVAL_PORT} \
+    port=${RETRO_MODEL_PORT} &
+
+depends_on "0.0.0.0" $RETRO_MODEL_PORT
+
+# launch the web server
+
+python scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py \
+    text_service_port=${RETRO_MODEL_PORT} \
+    combo_service_port=${COMBO_RETRIEVAL_PORT} \
+    share=True \
+    username=test \
+    password=${PASSWORD} \
+    port=${WEB_PORT}
\ No newline at end of file
From c61db83174feea2eff3c9375131fc6ce35d1c207 Mon Sep 17 00:00:00 2001
From: Yi Dong
Date: Fri, 10 Feb 2023 14:56:32 +0000
Subject: [PATCH 17/17] clean up

Signed-off-by: Yi Dong
---
 requirements/requirements_nlp.txt                    |  2 +-
 .../service_launch_scripts/launch_demo.sh            | 10 ++++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt
index fa7f232c1a41..b1b034f4ab27 100644
--- a/requirements/requirements_nlp.txt
+++ b/requirements/requirements_nlp.txt
@@ -5,7 +5,7 @@ fasttext
 flask_restful
 ftfy
 gdown
-gradio==3.4.0
+gradio
 h5py
 ijson
 inflect
diff --git a/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh b/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh
index 6da7ca98d8f4..ae18f222f230 100644
--- a/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh
+++ b/scripts/nlp_language_modeling/service_launch_scripts/launch_demo.sh
@@ -13,6 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+clean_up() {
+    kill -- -$$
+}
 
 depends_on () {
     HOST=$1
@@ -46,7 +49,6 @@ python scripts/nlp_language_modeling/service_launch_scripts/start_bert_service.py \
     sentence_bert.devices=$BERT_DEVICES \
     sentence_bert.port=${BERT_PORT} &
 
-
 depends_on "0.0.0.0" ${BERT_PORT}
 
 # launch static retrieval service
@@ -105,4 +107,8 @@ python scripts/nlp_language_modeling/service_launch_scripts/start_web_service.py \
     share=True \
     username=test \
     password=${PASSWORD} \
-    port=${WEB_PORT}
\ No newline at end of file
+    port=${WEB_PORT}
+
+
+echo "clean up daemons: $$"
+clean_up
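
---

For reference, a minimal Python client sketch for the retrieval services this series launches. It assumes the stack from launch_demo.sh is running locally with the default ports from env_variables.sh (combo retrieval service on 17181). The payload keys mirror what RetroDemoWebApp sends through request_data; the pure KNN query shape with "sentences" and "neighbors" is inferred from the resource classes and is an assumption, not a documented API.

    import json

    import requests

    HEADERS = {"Content-Type": "application/json"}
    COMBO_URL = "http://localhost:17181/knn"  # assumed default combo service port


    def put(data):
        # the static, dynamic, and combo retrieval servers all expose PUT /knn with a JSON body
        resp = requests.put(COMBO_URL, data=json.dumps(data), headers=HEADERS)
        return resp.json()


    # rebalance static vs. dynamic index weights (mirrors RetroDemoWebApp.update_weight)
    put({"update_weight": [0.5, 0.5]})
    # add a document to the dynamic index (mirrors RetroDemoWebApp.add_doc)
    put({"sentences": ["Ariel scored 1 shot and missed 2."], "add_eos": True})
    # fetch 2 nearest chunks for a query (key names assumed from the resource classes)
    neighbors = put({"sentences": ["How many shots did Ariel take?"], "neighbors": 2})
    # clear the dynamic index (mirrors RetroDemoWebApp.reset_index)
    put({"reset": None})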
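A matching sketch for the model text service, assuming the RETRO server from start_retro_model_service.py is listening on the default port 5555. The request fields copy RetroDemoWebApp.get_retro_generation, and text_generation in retrieval_services/util.py confirms the PUT /generate endpoint along with the "sentences" and "retrieved" response keys that convert_retrieved_to_md renders as an HTML table.

    import json

    import requests

    HEADERS = {"Content-Type": "application/json"}

    data = {
        "sentences": ["Ariel was playing basketball. 1 of her shots went in the hoop. "
                      "2 of her shots did not go in the hoop. How many shots were there in total?"],
        "tokens_to_generate": 30,
        "temperature": 1.0,
        "add_BOS": False,
        "top_k": 0,
        "top_p": 0.9,
        "greedy": True,
        "all_probs": False,
        "repetition_penalty": 1.2,
        "min_tokens_to_generate": 1,
        "neighbors": 2,
    }
    resp = requests.put("http://localhost:5555/generate", data=json.dumps(data), headers=HEADERS)
    output = resp.json()
    print(output["sentences"][0])  # generated continuation
    print(output["retrieved"])     # retrieval context used for the generation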