forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Open source export and deploy modules (NVIDIA#8743)
* export and deploy modules Signed-off-by: Onur Yilmaz <[email protected]> * Add export tests Signed-off-by: Onur Yilmaz <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address PR reviews Signed-off-by: Onur Yilmaz <[email protected]> * Add try except Signed-off-by: Onur Yilmaz <[email protected]> * Moved query_llm to nlp folder Signed-off-by: Onur Yilmaz <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * removed lambada.json Signed-off-by: Onur Yilmaz <[email protected]> * Reverting the Jenkinsfile Signed-off-by: Onur Yilmaz <[email protected]> * Exclude deploy and export from the pip Signed-off-by: Onur Yilmaz <[email protected]> * Address the CodeQL issues Signed-off-by: Onur Yilmaz <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Addressing reviews Signed-off-by: Onur Yilmaz <[email protected]> * remove deploy test for now Signed-off-by: Onur Yilmaz <[email protected]> * Addressing CodeQL comments Signed-off-by: Onur Yilmaz <[email protected]> * wrap imports with try except Signed-off-by: Onur Yilmaz <[email protected]> * Add test data param and fix codeql issue Signed-off-by: Onur Yilmaz <[email protected]> --------- Signed-off-by: Onur Yilmaz <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper <[email protected]>
- Loading branch information
1 parent
c7b4e5c
commit 143e162
Showing
39 changed files
with
7,954 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
from nemo.deploy.deploy_base import DeployBase | ||
from nemo.deploy.deploy_pytriton import DeployPyTriton | ||
from nemo.deploy.triton_deployable import ITritonDeployable |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import importlib | ||
import logging | ||
from abc import ABC, abstractmethod | ||
|
||
use_pytorch_lightning = True | ||
try: | ||
from pytorch_lightning import Trainer | ||
except Exception: | ||
use_pytorch_lightning = False | ||
|
||
from nemo.deploy.triton_deployable import ITritonDeployable | ||
|
||
use_nemo = True | ||
try: | ||
from nemo.core.classes.modelPT import ModelPT | ||
except Exception: | ||
use_nemo = False | ||
|
||
|
||
LOGGER = logging.getLogger("NeMo") | ||
|
||
|
||
class DeployBase(ABC): | ||
def __init__( | ||
self, | ||
triton_model_name: str, | ||
triton_model_version: int = 1, | ||
checkpoint_path: str = None, | ||
model=None, | ||
max_batch_size: int = 128, | ||
port: int = 8000, | ||
address="0.0.0.0", | ||
allow_grpc=True, | ||
allow_http=True, | ||
streaming=False, | ||
pytriton_log_verbose=0, | ||
): | ||
self.checkpoint_path = checkpoint_path | ||
self.triton_model_name = triton_model_name | ||
self.triton_model_version = triton_model_version | ||
self.max_batch_size = max_batch_size | ||
self.model = model | ||
self.port = port | ||
self.address = address | ||
self.triton = None | ||
self.allow_grpc = allow_grpc | ||
self.allow_http = allow_http | ||
self.streaming = streaming | ||
self.pytriton_log_verbose = pytriton_log_verbose | ||
|
||
if checkpoint_path is None and model is None: | ||
raise Exception("Either checkpoint_path or model should be provided.") | ||
|
||
@abstractmethod | ||
def deploy(self): | ||
pass | ||
|
||
@abstractmethod | ||
def serve(self): | ||
pass | ||
|
||
@abstractmethod | ||
def run(self): | ||
pass | ||
|
||
@abstractmethod | ||
def stop(self): | ||
pass | ||
|
||
def _init_nemo_model(self): | ||
if self.checkpoint_path is not None: | ||
model_config = ModelPT.restore_from(self.checkpoint_path, return_config=True) | ||
module_path, class_name = DeployBase.get_module_and_class(model_config.target) | ||
cls = getattr(importlib.import_module(module_path), class_name) | ||
self.model = cls.restore_from(restore_path=self.checkpoint_path, trainer=Trainer()) | ||
self.model.freeze() | ||
|
||
# has to turn off activations_checkpoint_method for inference | ||
try: | ||
self.model.model.language_model.encoder.activations_checkpoint_method = None | ||
except AttributeError as e: | ||
LOGGER.warning(e) | ||
|
||
if self.model is None: | ||
raise Exception("There is no model to deploy.") | ||
|
||
self._is_model_deployable() | ||
|
||
def _is_model_deployable(self): | ||
if not issubclass(type(self.model), ITritonDeployable): | ||
raise Exception( | ||
"This model is not deployable to Triton." "nemo.deploy.ITritonDeployable class should be inherited" | ||
) | ||
else: | ||
return True | ||
|
||
@staticmethod | ||
def get_module_and_class(target: str): | ||
ln = target.rindex(".") | ||
return target[0:ln], target[ln + 1 : len(target)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
use_pytriton = True | ||
try: | ||
from pytriton.model_config import ModelConfig | ||
from pytriton.triton import Triton, TritonConfig | ||
except Exception: | ||
use_pytriton = False | ||
|
||
from nemo.deploy.deploy_base import DeployBase | ||
|
||
|
||
class DeployPyTriton(DeployBase): | ||
|
||
""" | ||
Deploys any models to Triton Inference Server that implements ITritonDeployable interface in nemo.deploy. | ||
Example: | ||
from nemo.deploy import DeployPyTriton, NemoQueryLLM | ||
from nemo.export import TensorRTLLM | ||
trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") | ||
trt_llm_exporter.export( | ||
nemo_checkpoint_path="/path/for/nemo/checkpoint", | ||
model_type="llama", | ||
n_gpus=1, | ||
) | ||
nm = DeployPyTriton(model=trt_llm_exporter, triton_model_name="model_name", port=8000) | ||
nm.deploy() | ||
nm.run() | ||
nq = NemoQueryLLM(url="localhost", model_name="model_name") | ||
prompts = ["hello, testing GPT inference", "another GPT inference test?"] | ||
output = nq.query_llm(prompts=prompts, max_output_len=100) | ||
print("prompts: ", prompts) | ||
print("") | ||
print("output: ", output) | ||
print("") | ||
prompts = ["Give me some info about Paris", "Do you think Londan is a good city to visit?", "What do you think about Rome?"] | ||
output = nq.query_llm(prompts=prompts, max_output_len=250) | ||
print("prompts: ", prompts) | ||
print("") | ||
print("output: ", output) | ||
print("") | ||
""" | ||
|
||
def __init__( | ||
self, | ||
triton_model_name: str, | ||
triton_model_version: int = 1, | ||
checkpoint_path: str = None, | ||
model=None, | ||
max_batch_size: int = 128, | ||
port: int = 8000, | ||
address="0.0.0.0", | ||
allow_grpc=True, | ||
allow_http=True, | ||
streaming=False, | ||
pytriton_log_verbose=0, | ||
): | ||
""" | ||
A nemo checkpoint or model is expected for serving on Triton Inference Server. | ||
Args: | ||
triton_model_name (str): Name for the service | ||
triton_model_version(int): Version for the service | ||
checkpoint_path (str): path of the nemo file | ||
model (ITritonDeployable): A model that implements the ITritonDeployable from nemo.deploy import ITritonDeployable | ||
max_batch_size (int): max batch size | ||
port (int) : port for the Triton server | ||
address (str): http address for Triton server to bind. | ||
""" | ||
|
||
super().__init__( | ||
triton_model_name=triton_model_name, | ||
triton_model_version=triton_model_version, | ||
checkpoint_path=checkpoint_path, | ||
model=model, | ||
max_batch_size=max_batch_size, | ||
port=port, | ||
address=address, | ||
allow_grpc=allow_grpc, | ||
allow_http=allow_http, | ||
streaming=streaming, | ||
pytriton_log_verbose=pytriton_log_verbose, | ||
) | ||
|
||
def deploy(self): | ||
|
||
""" | ||
Deploys any models to Triton Inference Server. | ||
""" | ||
|
||
self._init_nemo_model() | ||
|
||
try: | ||
if self.streaming: | ||
# TODO: can't set allow_http=True due to a bug in pytriton, will fix in latest pytriton | ||
triton_config = TritonConfig( | ||
log_verbose=self.pytriton_log_verbose, | ||
allow_grpc=self.allow_grpc, | ||
allow_http=self.allow_http, | ||
grpc_address=self.address, | ||
) | ||
self.triton = Triton(config=triton_config) | ||
self.triton.bind( | ||
model_name=self.triton_model_name, | ||
model_version=self.triton_model_version, | ||
infer_func=self.model.triton_infer_fn_streaming, | ||
inputs=self.model.get_triton_input, | ||
outputs=self.model.get_triton_output, | ||
config=ModelConfig(decoupled=True), | ||
) | ||
else: | ||
triton_config = TritonConfig( | ||
http_address=self.address, | ||
http_port=self.port, | ||
allow_grpc=self.allow_grpc, | ||
allow_http=self.allow_http, | ||
) | ||
self.triton = Triton(config=triton_config) | ||
self.triton.bind( | ||
model_name=self.triton_model_name, | ||
model_version=self.triton_model_version, | ||
infer_func=self.model.triton_infer_fn, | ||
inputs=self.model.get_triton_input, | ||
outputs=self.model.get_triton_output, | ||
config=ModelConfig(max_batch_size=self.max_batch_size), | ||
) | ||
except Exception as e: | ||
self.triton = None | ||
print(e) | ||
|
||
def serve(self): | ||
|
||
""" | ||
Starts serving the model and waits for the requests | ||
""" | ||
|
||
if self.triton is None: | ||
raise Exception("deploy should be called first.") | ||
|
||
try: | ||
self.triton.serve() | ||
except Exception as e: | ||
self.triton = None | ||
print(e) | ||
|
||
def run(self): | ||
|
||
""" | ||
Starts serving the model asynchronously. | ||
""" | ||
|
||
if self.triton is None: | ||
raise Exception("deploy should be called first.") | ||
|
||
self.triton.run() | ||
|
||
def stop(self): | ||
""" | ||
Stops serving the model. | ||
""" | ||
|
||
if self.triton is None: | ||
raise Exception("deploy should be called first.") | ||
|
||
self.triton.stop() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
use_query_llm = True | ||
try: | ||
from nemo.deploy.nlp.query_llm import NemoQueryLLM | ||
except Exception: | ||
use_query_llm = False |
Oops, something went wrong.