From 241b82e2bcafb2b98281b4dedc39b232bac66c43 Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Wed, 22 Nov 2023 13:54:55 +0700 Subject: [PATCH 1/8] init --- .../model_serving/examples/vllm/Readme.md | 12 +++ .../examples/vllm/example/1/__init__.py | 0 .../examples/vllm/example/1/inference.py | 57 ++++++++++++++ .../examples/vllm/example/1/model.py | 74 +++++++++++++++++++ .../examples/vllm/example/1/test.py | 64 ++++++++++++++++ .../examples/vllm/example/1/weights/keep | 0 .../examples/vllm/example/config.pbtxt | 20 +++++ .../examples/vllm/example/requirements.txt | 5 ++ 8 files changed, 232 insertions(+) create mode 100644 clarifai/models/model_serving/examples/vllm/Readme.md create mode 100644 clarifai/models/model_serving/examples/vllm/example/1/__init__.py create mode 100644 clarifai/models/model_serving/examples/vllm/example/1/inference.py create mode 100644 clarifai/models/model_serving/examples/vllm/example/1/model.py create mode 100644 clarifai/models/model_serving/examples/vllm/example/1/test.py create mode 100644 clarifai/models/model_serving/examples/vllm/example/1/weights/keep create mode 100644 clarifai/models/model_serving/examples/vllm/example/config.pbtxt create mode 100644 clarifai/models/model_serving/examples/vllm/example/requirements.txt diff --git a/clarifai/models/model_serving/examples/vllm/Readme.md b/clarifai/models/model_serving/examples/vllm/Readme.md new file mode 100644 index 00000000..9f180775 --- /dev/null +++ b/clarifai/models/model_serving/examples/vllm/Readme.md @@ -0,0 +1,12 @@ +## vLLM (text-to-text) Example + +These can be used on the fly with minimal or no changes to test deploy vLLM models to the Clarifai platform. See the required files section for each model below. + +### Prerequisites: +* weights: Input your local weights or download them from huggingface to `./example/1/weights`. +Example download from huggingface: +``` +huggingface-cli download {MODEL_ID} --local-dir ./example/1/weights --local-dir-use-symlinks False --exclude {EXCLUDED FILE TYPES} +``` +* requirements.txt: update your requirements. +* inference.py: update LLM() paramters. It is recommended to use `gpu_memory_utilization=0.7`. diff --git a/clarifai/models/model_serving/examples/vllm/example/1/__init__.py b/clarifai/models/model_serving/examples/vllm/example/1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clarifai/models/model_serving/examples/vllm/example/1/inference.py b/clarifai/models/model_serving/examples/vllm/example/1/inference.py new file mode 100644 index 00000000..e27cc6d4 --- /dev/null +++ b/clarifai/models/model_serving/examples/vllm/example/1/inference.py @@ -0,0 +1,57 @@ +# This file contains boilerplate code to allow users write their model +# inference code that will then interact with the Triton Inference Server +# Python backend to serve end user requests. 
+# The module name, module path, class name & get_predictions() method names MUST be maintained as is +# but other methods may be added within the class as deemed fit provided +# they are invoked within the main get_predictions() inference method +# if they play a role in any step of model inference +"""User model inference script.""" + +import os +from pathlib import Path + +import numpy as np +from vllm import LLM, SamplingParams + +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config + +config = get_model_config(ModelTypes.text_to_text) + + +class InferenceModel: + """User model inference class.""" + + def __init__(self) -> None: + """ + Load inference time artifacts that are called frequently .e.g. models, tokenizers, etc. + in this method so they are loaded only once for faster inference. + """ + self.base_path: Path = os.path.dirname(__file__) + path = os.path.join(self.base_path, "weights") + self.model = LLM( + model=path, + dtype="float16", + gpu_memory_utilization=0.7, + swap_space=1, + #quantization="awq" + ) + + @config.inference.wrap_func + def get_predictions(self, input_data, **kwargs): + """ + Main model inference method. + + Args: + ----- + input_data: A single input data item to predict on. + Input data can be an image or text, etc depending on the model type. + + Returns: + -------- + One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + """ + sampling_params = SamplingParams(**kwargs) + output = self.model.generate(input_data, sampling_params) + generated_text = np.asarray(output[0].outputs[0].text, dtype=object) + + return config.inference.return_type(generated_text) diff --git a/clarifai/models/model_serving/examples/vllm/example/1/model.py b/clarifai/models/model_serving/examples/vllm/example/1/model.py new file mode 100644 index 00000000..36b54b37 --- /dev/null +++ b/clarifai/models/model_serving/examples/vllm/example/1/model.py @@ -0,0 +1,74 @@ +# Copyright 2023 Clarifai, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Triton inference server Python Backend Model.""" + +import os +import sys + +try: + import triton_python_backend_utils as pb_utils +except ModuleNotFoundError: + pass +from google.protobuf import text_format +from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters + + +class TritonPythonModel: + """ + Triton Python BE Model. + """ + + def initialize(self, args): + """ + Triton server init. 
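+    Strips "/1/model.py" from the model_repository path, imports the local
+    inference.py module to build an InferenceModel instance, and reads the
+    declared input tensor names from config.pbtxt.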
+ """ + args["model_repository"] = args["model_repository"].replace("/1/model.py", "") + sys.path.append(os.path.dirname(__file__)) + from inference import InferenceModel + + self.inference_obj = InferenceModel() + + # Read input_name from config file + self.config_msg = ModelConfig() + with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: + cfg = f.read() + text_format.Merge(cfg, self.config_msg) + self.input_names = [inp.name for inp in self.config_msg.input] + + def execute(self, requests): + """ + Serve model inference requests. + """ + responses = [] + + for request in requests: + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + + responses.append(inference_response) + + return responses diff --git a/clarifai/models/model_serving/examples/vllm/example/1/test.py b/clarifai/models/model_serving/examples/vllm/example/1/test.py new file mode 100644 index 00000000..1d7ff13d --- /dev/null +++ b/clarifai/models/model_serving/examples/vllm/example/1/test.py @@ -0,0 +1,64 @@ +import logging +import os +import unittest + +from clarifai.models.model_serving.models.default_test import DefaultTestInferenceModel + + +class CustomTestInferenceModel(DefaultTestInferenceModel): + """ + Run this file to test your implementation of InferenceModel in inference.py with default tests of Triton configuration and its output values based on basic predefined inputs + If you want to write custom testcase or just test output value. + Please follow these instrucitons: + 1. Name your test function with prefix "test" so that pytest can execute + 2. In order to obtain output of InferenceModel, call `self.triton_get_predictions(input_data)`. + 3. If your input is `image` and you have set custom size of it when building model repository, + call `self.preprocess(image)` to obtain correct resized input + 4. Run this test by calling + ```bash + pytest ./your_triton_folder/1/test.py + #to see std output + pytest --log-cli-level=INFO -s ./your_triton_folder/1/test.py + ``` + + ### Examples: + + test text-to-image output + ``` + def test_text_to_image_output(self): + text = "Test text" + output = self.triton_get_predictions(text) + image = output.image # uint8 np.ndarray image + #show or save + ``` + + test visual-classifier output + ``` + def test_visual_classifier(self): + image = cv2.imread("your/local/image.jpg") # Keep in mind of format of image (BGR or RGB) + output = self.triton_get_predictions(image) + scores = output.predicted_scores # np.ndarray + #process scores to get class id and its score + logger.info(result) + """ + + # Insert your inference parameters json path here + # or insert a dictionary of your_parameter_name and value, e.g dict(x=1.5, y="text", c=True) + # or Leave it as "" if you don't have it. + inference_parameters = "" + + ########### Initialization. 
Do not change it ########### + __test__ = True + + def setUp(self) -> None: + logging.info("Initializing...") + model_type = "text-to-text" # your model type + self.intitialize( + model_type, + repo_version_dir=os.path.dirname(__file__), + is_instance_kind_gpu=True, + inference_parameters=self.inference_parameters) + + ######################################################## + + +if __name__ == '__main__': + unittest.main() diff --git a/clarifai/models/model_serving/examples/vllm/example/1/weights/keep b/clarifai/models/model_serving/examples/vllm/example/1/weights/keep new file mode 100644 index 00000000..e69de29b diff --git a/clarifai/models/model_serving/examples/vllm/example/config.pbtxt b/clarifai/models/model_serving/examples/vllm/example/config.pbtxt new file mode 100644 index 00000000..396a38d4 --- /dev/null +++ b/clarifai/models/model_serving/examples/vllm/example/config.pbtxt @@ -0,0 +1,20 @@ +name: "text-generation" +max_batch_size: 1 +input { + name: "text" + data_type: TYPE_STRING + dims: 1 +} +output { + name: "text" + data_type: TYPE_STRING + dims: 1 +} +instance_group { + count: 1 + kind: KIND_GPU +} +dynamic_batching { + max_queue_delay_microseconds: 500 +} +backend: "python" diff --git a/clarifai/models/model_serving/examples/vllm/example/requirements.txt b/clarifai/models/model_serving/examples/vllm/example/requirements.txt new file mode 100644 index 00000000..bb6110e5 --- /dev/null +++ b/clarifai/models/model_serving/examples/vllm/example/requirements.txt @@ -0,0 +1,5 @@ +clarifai +tritonclient[all] +transformers==4.34.1 +torch==2.0.1 +vllm==0.2.1.post1 From 53b362eacb087e73e89b370f0b2db583a0c7721e Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Fri, 24 Nov 2023 17:24:46 +0700 Subject: [PATCH 2/8] init --- .../model_serving/models/default_test.py | 7 +- .../models/model_serving/models/inference.py | 2 +- .../model_serving/models/model_types.py | 71 +++++++++++-------- 3 files changed, 49 insertions(+), 31 deletions(-) diff --git a/clarifai/models/model_serving/models/default_test.py b/clarifai/models/model_serving/models/default_test.py index 0848f66c..84cfb382 100644 --- a/clarifai/models/model_serving/models/default_test.py +++ b/clarifai/models/model_serving/models/default_test.py @@ -155,7 +155,7 @@ def _is_integer(x): inputs = [self.preprocess["image"](inp) for inp in PREDEFINED_IMAGES] else: inputs = PREDEFINED_TEXTS - outputs = [self.triton_get_predictions(inp) for inp in inputs] + outputs = self.triton_get_predictions(inputs) # Test for specific model type: # 1. length of output array vs config @@ -259,9 +259,12 @@ def _is_integer(x): input_texts = PREDEFINED_TEXTS def _assert(input_data): + batch_inputs = [] for group in zip_longest(*input_data.values()): _input = dict(zip(input_data, group)) - output = self.triton_get_predictions(input_data=_input) + batch_inputs.append(_input) + outputs = self.triton_get_predictions(input_data=batch_inputs) + for output in outputs: self.assertEqual( type(output), EmbeddingOutput, f"Output type must be `EmbeddingOutput`, but got {type(output)}") diff --git a/clarifai/models/model_serving/models/inference.py b/clarifai/models/model_serving/models/inference.py index 3103affd..ca718e49 100644 --- a/clarifai/models/model_serving/models/inference.py +++ b/clarifai/models/model_serving/models/inference.py @@ -25,7 +25,7 @@ def __init__(self) -> None: #self.model: Callable = #Add relevant model type decorator to the method below (see docs/model_types for ref.) 
- def get_predictions(self, input_data, **kwargs): + def get_predictions(self, input_data: list, **kwargs): """ Main model inference method. diff --git a/clarifai/models/model_serving/models/model_types.py b/clarifai/models/model_serving/models/model_types.py index 8515caef..5e651f60 100644 --- a/clarifai/models/model_serving/models/model_types.py +++ b/clarifai/models/model_serving/models/model_types.py @@ -39,11 +39,12 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): out_bboxes = [] out_labels = [] out_scores = [] - for item in input_data: - preds = func(self, item, *args, **kwargs) - out_bboxes.append(preds.predicted_bboxes) - out_labels.append(preds.predicted_labels) - out_scores.append(preds.predicted_scores) + input_data = [each for each in input_data] + preds = func(self, input_data, *args, **kwargs) + for pred in preds: + out_bboxes.append(pred.predicted_bboxes) + out_labels.append(pred.predicted_labels) + out_scores.append(pred.predicted_scores) if len(out_bboxes) < 1 or len(out_labels) < 1: out_tensor_bboxes = pb_utils.Tensor("predicted_bboxes", np.zeros((0, 4), dtype=np.float32)) @@ -76,9 +77,11 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): Format predictions and return clarifai compatible output. """ out_scores = [] - for item in input_data: - preds = func(self, item, *args, **kwargs) - out_scores.append(preds.predicted_scores) + input_data = [each for each in input_data] + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + out_scores.append(pred.predicted_scores) out_tensor_scores = pb_utils.Tensor("softmax_predictions", np.asarray(out_scores, dtype=np.float32)) @@ -101,9 +104,10 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): """ out_scores = [] input_data = [in_elem[0].decode() for in_elem in input_data] - for item in input_data: - preds = func(self, item, *args, **kwargs) - out_scores.append(preds.predicted_scores) + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + out_scores.append(pred.predicted_scores) out_tensor_scores = pb_utils.Tensor("softmax_predictions", np.asarray(out_scores, dtype=np.float32)) @@ -128,9 +132,10 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): """ out_text = [] input_data = [in_elem[0].decode() for in_elem in input_data] - for item in input_data: - preds = func(self, item, *args, **kwargs) - out_text.append(preds.predicted_text) + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + out_text.append(pred.predicted_text) out_text_tensor = pb_utils.Tensor("text", np.asarray(out_text, dtype=object)) inference_response = pb_utils.InferenceResponse(output_tensors=[out_text_tensor]) @@ -153,9 +158,10 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): """ out_embeddings = [] input_data = [in_elem[0].decode() for in_elem in input_data] - for item in input_data: - preds = func(self, item, *args, **kwargs) - out_embeddings.append(preds.embedding_vector) + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + out_embeddings.append(pred.embedding_vector) out_embed_tensor = pb_utils.Tensor("embeddings", np.asarray(out_embeddings, dtype=np.float32)) inference_response = pb_utils.InferenceResponse(output_tensors=[out_embed_tensor]) @@ -177,9 +183,11 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): Format predictions and return clarifai compatible output. 
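+
+    With this change the wrapped function receives the whole batch as a list and
+    must return one output object per input. A rough sketch of a compatible user
+    method (the `self.model.encode` helper here is hypothetical):
+
+    >>> def get_predictions(self, input_data: list, **kwargs):
+    ...   return [EmbeddingOutput(embedding_vector=v) for v in self.model.encode(input_data)]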
""" out_embeddings = [] - for item in input_data: - preds = func(self, item, *args, **kwargs) - out_embeddings.append(preds.embedding_vector) + input_data = [each for each in input_data] + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + out_embeddings.append(pred.embedding_vector) out_embed_tensor = pb_utils.Tensor("embeddings", np.asarray(out_embeddings, dtype=np.float32)) inference_response = pb_utils.InferenceResponse(output_tensors=[out_embed_tensor]) @@ -200,9 +208,11 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): Format predictions and return clarifai compatible output. """ masks = [] - for item in input_data: - preds = func(self, item, *args, **kwargs) - masks.append(preds.predicted_mask) + input_data = [each for each in input_data] + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + masks.append(pred.predicted_mask) out_mask_tensor = pb_utils.Tensor("predicted_mask", np.asarray(masks, dtype=np.int64)) inference_response = pb_utils.InferenceResponse(output_tensors=[out_mask_tensor]) @@ -224,9 +234,10 @@ def parse_predictions(self, input_data: np.ndarray, *args, **kwargs): """ gen_images = [] input_data = [in_elem[0].decode() for in_elem in input_data] - for item in input_data: - preds = func(self, item, *args, **kwargs) - gen_images.append(preds.image) + preds = func(self, input_data, *args, **kwargs) + + for pred in preds: + gen_images.append(pred.image) out_image_tensor = pb_utils.Tensor("image", np.asarray(gen_images, dtype=np.uint8)) inference_response = pb_utils.InferenceResponse(output_tensors=[out_image_tensor]) @@ -248,14 +259,18 @@ def parse_predictions(self, input_data: Dict[str, np.ndarray], *args, **kwargs): Format predictions and return clarifai compatible output. 
""" out_embeddings = [] + model_input_data = [] for group in zip_longest(*input_data.values()): _input_data = dict(zip(input_data, group)) for k, v in _input_data.items(): # decode np.object to string if isinstance(v, np.ndarray) and v.dtype == np.object_: _input_data.update({k: v[0].decode()}) - preds = func(self, _input_data, *args, **kwargs) - out_embeddings.append(preds.embedding_vector) + model_input_data.append(_input_data) + + preds = func(self, model_input_data, *args, **kwargs) + for pred in preds: + out_embeddings.append(pred.embedding_vector) out_embed_tensor = pb_utils.Tensor("embeddings", np.asarray(out_embeddings, dtype=np.float32)) inference_response = pb_utils.InferenceResponse(output_tensors=[out_embed_tensor]) From 9de21622df789335502474dd72d46feaea21a812 Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Tue, 28 Nov 2023 00:26:33 +0700 Subject: [PATCH 3/8] update --- clarifai/models/api.py | 2 +- .../models/model_serving/cli/deploy_cli.py | 2 +- .../model_serving/model_config/config.py | 3 ++ .../models/model_serving/models/inference.py | 16 ++++-- .../models/model_serving/models/output.py | 3 +- .../model_serving/pb_model_repository.py | 54 +++++++++++-------- 6 files changed, 51 insertions(+), 29 deletions(-) diff --git a/clarifai/models/api.py b/clarifai/models/api.py index 3f476ad0..3182ce49 100644 --- a/clarifai/models/api.py +++ b/clarifai/models/api.py @@ -18,8 +18,8 @@ from google.protobuf.json_format import MessageToDict from google.protobuf.struct_pb2 import Struct, Value -from clarifai.auth.helper import ClarifaiAuthHelper from clarifai.client import create_stub +from clarifai.client.auth.helper import ClarifaiAuthHelper def _make_default_value_proto(dtype, value): diff --git a/clarifai/models/model_serving/cli/deploy_cli.py b/clarifai/models/model_serving/cli/deploy_cli.py index ccf83883..ae2a4bd6 100644 --- a/clarifai/models/model_serving/cli/deploy_cli.py +++ b/clarifai/models/model_serving/cli/deploy_cli.py @@ -13,7 +13,7 @@ """Commandline interface for model upload utils.""" import argparse -from clarifai.auth.helper import ClarifaiAuthHelper +from clarifai.client.auth.helper import ClarifaiAuthHelper from clarifai.models.api import Models from clarifai.models.model_serving.model_config import MODEL_TYPES, get_model_config from clarifai.models.model_serving.model_config.inference_parameter import InferParamManager diff --git a/clarifai/models/model_serving/model_config/config.py b/clarifai/models/model_serving/model_config/config.py index 7b836e29..40cd17ed 100644 --- a/clarifai/models/model_serving/model_config/config.py +++ b/clarifai/models/model_serving/model_config/config.py @@ -284,8 +284,11 @@ def get_model_config(model_type: str) -> ModelConfigClass: ModelConfigClass ### Example: + >>> from clarifai.models.model_serving.models.output import ClassifierOutput + >>> from clarifai.models.model_serving.model_config import get_model_config, ModelTypes >>> cfg = get_model_config(ModelTypes.text_classifier) >>> custom_triton_config = cfg.make_triton_model_config(**kwargs) + >>> cfg.inference.return_type is ClassifierOutput # True """ diff --git a/clarifai/models/model_serving/models/inference.py b/clarifai/models/model_serving/models/inference.py index ca718e49..28965aaa 100644 --- a/clarifai/models/model_serving/models/inference.py +++ b/clarifai/models/model_serving/models/inference.py @@ -10,6 +10,11 @@ import os from pathlib import Path +from clarifai.models.model_serving.model_config import ( # noqa # pylint: disable=unused-import + ModelTypes, 
get_model_config) + +#config = get_model_config("clarifai-model-type") + class InferenceModel: """User model inference class.""" @@ -24,19 +29,22 @@ def __init__(self) -> None: #self.checkpoint_path: Path = os.path.join(self.base_path, "your checkpoint filename/path") #self.model: Callable = - #Add relevant model type decorator to the method below (see docs/model_types for ref.) - def get_predictions(self, input_data: list, **kwargs): + #@config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. Refer to the README/docs """ + # Delete/Comment out line below and add your inference code raise NotImplementedError() diff --git a/clarifai/models/model_serving/models/output.py b/clarifai/models/model_serving/models/output.py index 57624012..012ad0ec 100644 --- a/clarifai/models/model_serving/models/output.py +++ b/clarifai/models/model_serving/models/output.py @@ -67,12 +67,13 @@ class TextOutput: """ Takes model text predictions """ - predicted_text: np.ndarray + predicted_text: str def __post_init__(self): """ Validate input upon initialization. """ + self.predicted_text = np.array(self.predicted_text, dtype=object) assert self.predicted_text.ndim == 0, \ f"All predictions must be 0-dimensional, Got text-dims: {self.predicted_text.ndim} instead." diff --git a/clarifai/models/model_serving/pb_model_repository.py b/clarifai/models/model_serving/pb_model_repository.py index dcb9406a..e42cd6e8 100644 --- a/clarifai/models/model_serving/pb_model_repository.py +++ b/clarifai/models/model_serving/pb_model_repository.py @@ -17,7 +17,7 @@ import inspect import os from pathlib import Path -from typing import Callable, Type +from typing import Type from .model_config import Serializer, TritonModelConfig from .models import inference, pb_model, test @@ -32,24 +32,26 @@ def __init__(self, model_config: Type[TritonModelConfig]): self.model_config = model_config self.config_proto = Serializer(model_config) - def _module_to_file(self, module_name: Callable, filename: str, destination_dir: str) -> None: + def _module_to_file(self, module, file_path: str, func: callable = None): """ Write Python Module to file. Args: ----- - module_name: Python module name to write to file - filename: Name of the file to write to destination_dir - destination_dir: Directory to save the generated triton model file. - + module_name: Python module name to write to file + file_path: Path of file to write module code into. + func: A function to process code of module. It contains only 1 argument, text of module. 
If it is None, then only save text to `file_path` Returns: -------- - None + None """ - module_path: Path = os.path.join(destination_dir, filename) - source_code = inspect.getsource(module_name) - with open(module_path, "w") as pb_model: - pb_model.write(source_code) + source_code = inspect.getsource(module) + with open(file_path, "w") as fp: + # change model type + if func: + source_code = func(source_code) + # write it to file + fp.write(source_code) def build_repository(self, repository_dir: Path = os.curdir): """ @@ -80,22 +82,30 @@ def build_repository(self, repository_dir: Path = os.curdir): continue # gen requirements with open(os.path.join(repository_path, "requirements.txt"), "w") as f: - f.write("clarifai>9.5.3\ntritonclient[all]") # for model upload utils + f.write("clarifai>9.10.3\ntritonclient[all]") # for model upload utils if not os.path.isdir(model_version_path): os.mkdir(model_version_path) if not os.path.exists(os.path.join(model_version_path, "__init__.py")): with open(os.path.join(model_version_path, "__init__.py"), "w"): pass - # generate model.py & inference.py modules - self._module_to_file(pb_model, filename="model.py", destination_dir=model_version_path) - self._module_to_file(inference, filename="inference.py", destination_dir=model_version_path) + # generate model.py + model_py_path = os.path.join(model_version_path, "model.py") + self._module_to_file(pb_model, model_py_path, func=None) + + # generate inference.py + def inference_insert_model_type_func(x): + x = x.replace("""#config = get_model_config("clarifai-model-type")""", + f"""config = get_model_config("{self.model_config.model_type}")""") + x = x.replace("#@config.inference.wrap_func", "@config.inference.wrap_func") + return x + + inference_py_path = os.path.join(model_version_path, "inference.py") + self._module_to_file(inference, inference_py_path, inference_insert_model_type_func) + # generate test.py + def insert_model_type_func(x): + return x.replace("clarifai-model-type", self.model_config.model_type) + custom_test_path = os.path.join(model_version_path, "test.py") - test_source_code = inspect.getsource(test) - with open(custom_test_path, "w") as fp: - # change model type - test_source_code = test_source_code.replace("clarifai-model-type", - self.model_config.model_type) - # write it to file - fp.write(test_source_code) + self._module_to_file(test, custom_test_path, insert_model_type_func) From 4c88e5b81b643c999ce2f34fa6c7a0557a2f7036 Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Tue, 28 Nov 2023 12:33:39 +0700 Subject: [PATCH 4/8] update examples --- .../examples/image_classification/README.md | 5 +- .../age_vit/1/inference.py | 32 +++++--- .../image_classification/age_vit/1/model.py | 23 ++++-- .../age_vit/requirements.txt | 2 +- .../examples/multimodal_embedder/README.md | 12 +++ .../multimodal_embedder/clip/1/__init__.py | 0 .../multimodal_embedder/clip/1/inference.py | 66 +++++++++++++++++ .../multimodal_embedder/clip/1/model.py | 74 +++++++++++++++++++ .../multimodal_embedder/clip/1/test.py | 64 ++++++++++++++++ .../multimodal_embedder/clip/config.pbtxt | 29 ++++++++ .../multimodal_embedder/clip/requirements.txt | 4 + .../examples/text_classification/README.md | 5 +- .../xlm-roberta/1/inference.py | 31 +++++--- .../xlm-roberta/1/model.py | 23 ++++-- .../xlm-roberta/config.pbtxt | 2 +- .../xlm-roberta/requirements.txt | 2 +- .../examples/text_embedding/README.md | 3 + .../instructor-xl/1/__init__.py | 0 .../instructor-xl/1/inference.py | 63 ++++++++++++++++ 
.../text_embedding/instructor-xl/1/model.py | 74 +++++++++++++++++++ .../text_embedding/instructor-xl/1/test.py | 64 ++++++++++++++++ .../text_embedding/instructor-xl/config.pbtxt | 20 +++++ .../instructor-xl/requirements.txt | 9 +++ .../examples/text_to_image/README.md | 7 +- .../text_to_image/sd-v1.5/1/inference.py | 20 +++-- .../examples/text_to_image/sd-v1.5/1/model.py | 22 +++++- .../text_to_image/sd-v1.5/requirements.txt | 2 +- .../examples/text_to_text/README.md | 6 +- .../bart-summarize/1/inference.py | 34 ++++++--- .../text_to_text/bart-summarize/1/model.py | 22 +++++- .../bart-summarize/requirements.txt | 2 +- .../examples/visual_detection/README.md | 6 +- .../visual_detection/yolov5x/1/inference.py | 55 ++++++++------ .../visual_detection/yolov5x/1/model.py | 23 ++++-- .../visual_detection/yolov5x/config.pbtxt | 2 +- .../visual_detection/yolov5x/requirements.txt | 2 +- .../examples/visual_embedding/README.md | 5 +- .../visual_embedding/vit-base/1/inference.py | 15 ++-- .../visual_embedding/vit-base/1/model.py | 22 +++++- .../vit-base/requirements.txt | 2 +- .../examples/visual_segmentation/README.md | 5 +- .../segformer-b2/1/inference.py | 25 ++++--- .../segformer-b2/1/model.py | 22 +++++- .../segformer-b2/requirements.txt | 2 +- .../examples/vllm/example/1/inference.py | 11 ++- 45 files changed, 784 insertions(+), 135 deletions(-) create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/README.md create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/clip/1/__init__.py create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/clip/1/inference.py create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/clip/1/model.py create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/clip/1/test.py create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/clip/config.pbtxt create mode 100644 clarifai/models/model_serving/examples/multimodal_embedder/clip/requirements.txt create mode 100644 clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/__init__.py create mode 100644 clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/inference.py create mode 100644 clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/model.py create mode 100644 clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/test.py create mode 100644 clarifai/models/model_serving/examples/text_embedding/instructor-xl/config.pbtxt create mode 100644 clarifai/models/model_serving/examples/text_embedding/instructor-xl/requirements.txt diff --git a/clarifai/models/model_serving/examples/image_classification/README.md b/clarifai/models/model_serving/examples/image_classification/README.md index 67ca11fa..85e85c4c 100644 --- a/clarifai/models/model_serving/examples/image_classification/README.md +++ b/clarifai/models/model_serving/examples/image_classification/README.md @@ -6,4 +6,7 @@ These can be used on the fly with minimal or no changes to test deploy image cla Required files to run tests locally: - * Download the [model checkpoint from huggingface](https://huggingface.co/nateraw/vit-age-classifier/tree/main) and store it under `age_vit/1/vit-age-classifier/` + * Download the [model checkpoint from huggingface](https://huggingface.co/nateraw/vit-age-classifier/tree/main) and store it under `age_vit/1/checkpoint/` + ``` + huggingface-cli download nateraw/vit-age-classifier --local-dir age_vit/1/checkpoint/ --local-dir-use-symlinks False + 
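+  # requires the huggingface_hub package for the huggingface-cli tool (e.g. pip install -U huggingface_hub)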
``` diff --git a/clarifai/models/model_serving/examples/image_classification/age_vit/1/inference.py b/clarifai/models/model_serving/examples/image_classification/age_vit/1/inference.py index d05e3b0d..91bc4953 100644 --- a/clarifai/models/model_serving/examples/image_classification/age_vit/1/inference.py +++ b/clarifai/models/model_serving/examples/image_classification/age_vit/1/inference.py @@ -13,11 +13,13 @@ import torch from scipy.special import softmax -from transformers import ViTFeatureExtractor, ViTForImageClassification +from transformers import AutoImageProcessor, ViTForImageClassification -from clarifai.models.model_serving.models.model_types import visual_classifier +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import ClassifierOutput +config = get_model_config(ModelTypes.visual_classifier) + class InferenceModel: """User model inference class.""" @@ -28,29 +30,35 @@ def __init__(self) -> None: in this method so they are loaded only once for faster inference. """ self.base_path: Path = os.path.dirname(__file__) - self.huggingface_model_path: Path = os.path.join(self.base_path, "vit-age-classifier") - self.transforms = ViTFeatureExtractor.from_pretrained(self.huggingface_model_path) + self.huggingface_model_path: Path = os.path.join(self.base_path, "checkpoint") + self.transforms = AutoImageProcessor.from_pretrained(self.huggingface_model_path) self.model: Callable = ViTForImageClassification.from_pretrained(self.huggingface_model_path) self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - @visual_classifier - def get_predictions(self, input_data) -> ClassifierOutput: + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. 
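+    For this ViT age classifier that is one ClassifierOutput of softmax scores
+    per input image.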
Refer to the README/docs """ # Transform image and pass it to the model inputs = self.transforms(input_data, return_tensors='pt') - output = self.model(**inputs) - pred_scores = softmax( - output[0][0].detach().numpy()) # alt: softmax(output.logits[0].detach().numpy()) + with torch.no_grad(): + preds = self.model(**inputs).logits + outputs = [] + for pred in preds: + pred_scores = softmax( + pred.detach().numpy()) # alt: softmax(output.logits[0].detach().numpy()) + outputs.append(ClassifierOutput(predicted_scores=pred_scores)) - return ClassifierOutput(predicted_scores=pred_scores) + return outputs diff --git a/clarifai/models/model_serving/examples/image_classification/age_vit/1/model.py b/clarifai/models/model_serving/examples/image_classification/age_vit/1/model.py index 1cacacd2..36b54b37 100644 --- a/clarifai/models/model_serving/examples/image_classification/age_vit/1/model.py +++ b/clarifai/models/model_serving/examples/image_classification/age_vit/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -37,14 +38,13 @@ def initialize(self, args): from inference import InferenceModel self.inference_obj = InferenceModel() - self.device = "cuda:0" if "GPU" in args["model_instance_kind"] else "cpu" # Read input_name from config file self.config_msg = ModelConfig() with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -53,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/image_classification/age_vit/requirements.txt b/clarifai/models/model_serving/examples/image_classification/age_vit/requirements.txt index 4552c2a7..7379fba3 100644 --- a/clarifai/models/model_serving/examples/image_classification/age_vit/requirements.txt +++ b/clarifai/models/model_serving/examples/image_classification/age_vit/requirements.txt @@ -1,4 +1,4 @@ -clarifai>9.5.3 # for model upload features +clarifai>9.10.5 tritonclient[all] torch==1.13.1 transformers==4.30.2 diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/README.md b/clarifai/models/model_serving/examples/multimodal_embedder/README.md new file mode 100644 index 00000000..32779373 --- /dev/null +++ 
b/clarifai/models/model_serving/examples/multimodal_embedder/README.md @@ -0,0 +1,12 @@ +## Image Classification Triton Model Examples + +These can be used on the fly with minimal or no changes to test deploy image classification models to the Clarifai platform. See the required files section for each model below. + +* ### [VIT Age Classifier](./clip/) + + Required files to run tests locally: + + * Download the [model checkpoint from huggingface](https://huggingface.co/openai/clip-vit-base-patch32) and store it under `clip/1/checkpoint/` + ``` + huggingface-cli download openai/clip-vit-base-patch32 --local-dir clip/1/checkpoint/ --local-dir-use-symlinks False --exclude *.msgpack *.h5 + ``` diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/__init__.py b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/inference.py b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/inference.py new file mode 100644 index 00000000..a8baf099 --- /dev/null +++ b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/inference.py @@ -0,0 +1,66 @@ +# This file contains boilerplate code to allow users write their model +# inference code that will then interact with the Triton Inference Server +# Python backend to serve end user requests. +# The module name, module path, class name & get_predictions() method names MUST be maintained as is +# but other methods may be added within the class as deemed fit provided +# they are invoked within the main get_predictions() inference method +# if they play a role in any step of model inference +"""User model inference script.""" + +import os +from pathlib import Path + +import torch +from transformers import CLIPModel, CLIPProcessor +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config + +config = get_model_config(ModelTypes.multimodal_embedder) + + +class InferenceModel: + """User model inference class.""" + + def __init__(self) -> None: + """ + Load inference time artifacts that are called frequently .e.g. models, tokenizers, etc. + in this method so they are loaded only once for faster inference. + """ + self.base_path: Path = os.path.dirname(__file__) + ## sample model loading code: + #self.checkpoint_path: Path = os.path.join(self.base_path, "your checkpoint filename/path") + #self.model: Callable = + self.model = CLIPModel.from_pretrained(os.path.join(self.base_path, "checkpoint")) + self.model.eval() + #self.text_model = CLIPTextModel.from_pretrained(os.path.join(self.base_path, "openai/clip-vit-base-patch32")) + self.processor = CLIPProcessor.from_pretrained(os.path.join(self.base_path, "checkpoint")) + + #Add relevant model type decorator to the method below (see docs/model_types for ref.) + @config.inference.wrap_func + def get_predictions(self, input_data, **kwargs): + """ + Main model inference method. + + Args: + ----- + input_data: A single input data item to predict on. + Input data can be an image or text, etc depending on the model type. + + Returns: + -------- + One of the clarifai.models.model_serving.models.output types. 
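+    For this CLIP example that is one EmbeddingOutput per item: when an item
+    carries text the text tower (get_text_features) embeds it, otherwise the
+    image tower (get_image_features) is used.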
Refer to the README/docs + """ + outputs = [] + for inp in input_data: + image, text = inp["image"], inp["text"] + with torch.no_grad(): + inputs = self.processor(text=text, images=image, return_tensors="pt", padding=True) + if text is not None: + inputs = self.processor(text=text, return_tensors="pt", padding=True) + embeddings = self.model.get_text_features(**inputs) + else: + inputs = self.processor(images=image, return_tensors="pt", padding=True) + embeddings = self.model.get_image_features(**inputs) + embeddings = embeddings.squeeze().cpu().numpy() + outputs.append(config.inference.return_type(embedding_vector=embeddings)) + + return outputs diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/model.py b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/model.py new file mode 100644 index 00000000..36b54b37 --- /dev/null +++ b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/model.py @@ -0,0 +1,74 @@ +# Copyright 2023 Clarifai, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Triton inference server Python Backend Model.""" + +import os +import sys + +try: + import triton_python_backend_utils as pb_utils +except ModuleNotFoundError: + pass +from google.protobuf import text_format +from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters + + +class TritonPythonModel: + """ + Triton Python BE Model. + """ + + def initialize(self, args): + """ + Triton server init. + """ + args["model_repository"] = args["model_repository"].replace("/1/model.py", "") + sys.path.append(os.path.dirname(__file__)) + from inference import InferenceModel + + self.inference_obj = InferenceModel() + + # Read input_name from config file + self.config_msg = ModelConfig() + with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: + cfg = f.read() + text_format.Merge(cfg, self.config_msg) + self.input_names = [inp.name for inp in self.config_msg.input] + + def execute(self, requests): + """ + Serve model inference requests. 
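+    Per request, inference parameters are parsed, the declared input tensors are
+    gathered (the raw numpy batch for single-input models, a dict keyed by input
+    name otherwise), and InferenceModel.get_predictions() builds the
+    InferenceResponse appended to the returned list.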
+ """ + responses = [] + + for request in requests: + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + + responses.append(inference_response) + + return responses diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/test.py b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/test.py new file mode 100644 index 00000000..1874c676 --- /dev/null +++ b/clarifai/models/model_serving/examples/multimodal_embedder/clip/1/test.py @@ -0,0 +1,64 @@ +import logging +import os +import unittest + +from clarifai.models.model_serving.models.default_test import DefaultTestInferenceModel + + +class CustomTestInferenceModel(DefaultTestInferenceModel): + """ + Run this file to test your implementation of InferenceModel in inference.py with default tests of Triton configuration and its output values based on basic predefined inputs + If you want to write custom testcase or just test output value. + Please follow these instrucitons: + 1. Name your test function with prefix "test" so that pytest can execute + 2. In order to obtain output of InferenceModel, call `self.triton_get_predictions(input_data)`. + 3. If your input is `image` and you have set custom size of it when building model repository, + call `self.preprocess(image)` to obtain correct resized input + 4. Run this test by calling + ```bash + pytest ./your_triton_folder/1/test.py + #to see std output + pytest --log-cli-level=INFO -s ./your_triton_folder/1/test.py + ``` + + ### Examples: + + test text-to-image output + ``` + def test_text_to_image_output(self): + text = "Test text" + output = self.triton_get_predictions(text) + image = output.image # uint8 np.ndarray image + #show or save + ``` + + test visual-classifier output + ``` + def test_visual_classifier(self): + image = cv2.imread("your/local/image.jpg") # Keep in mind of format of image (BGR or RGB) + output = self.triton_get_predictions(image) + scores = output.predicted_scores # np.ndarray + #process scores to get class id and its score + logger.info(result) + """ + + # Insert your inference parameters json path here + # or insert a dictionary of your_parameter_name and value, e.g dict(x=1.5, y="text", c=True) + # or Leave it as "" if you don't have it. + inference_parameters = "" + + ########### Initialization. 
Do not change it ########### + __test__ = True + + def setUp(self) -> None: + logging.info("Initializing...") + model_type = "multimodal-embedder" # your model type + self.intitialize( + model_type, + repo_version_dir=os.path.dirname(__file__), + is_instance_kind_gpu=True, + inference_parameters=self.inference_parameters) + + ######################################################## + + +if __name__ == '__main__': + unittest.main() diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/clip/config.pbtxt b/clarifai/models/model_serving/examples/multimodal_embedder/clip/config.pbtxt new file mode 100644 index 00000000..6878506a --- /dev/null +++ b/clarifai/models/model_serving/examples/multimodal_embedder/clip/config.pbtxt @@ -0,0 +1,29 @@ +name: "clip" +max_batch_size: 1 +input { + name: "image" + data_type: TYPE_UINT8 + dims: -1 + dims: -1 + dims: 3 + optional: true +} +input { + name: "text" + data_type: TYPE_STRING + dims: 1 + optional: true +} +output { + name: "embeddings" + data_type: TYPE_FP32 + dims: -1 +} +instance_group { + count: 1 + kind: KIND_GPU +} +dynamic_batching { + max_queue_delay_microseconds: 500 +} +backend: "python" diff --git a/clarifai/models/model_serving/examples/multimodal_embedder/clip/requirements.txt b/clarifai/models/model_serving/examples/multimodal_embedder/clip/requirements.txt new file mode 100644 index 00000000..8b05c758 --- /dev/null +++ b/clarifai/models/model_serving/examples/multimodal_embedder/clip/requirements.txt @@ -0,0 +1,4 @@ +clarifai>9.10.5 +tritonclient[all] +transformers==4.34.1 +torch==2.0.1 diff --git a/clarifai/models/model_serving/examples/text_classification/README.md b/clarifai/models/model_serving/examples/text_classification/README.md index c5a1f9a8..b3c3555d 100644 --- a/clarifai/models/model_serving/examples/text_classification/README.md +++ b/clarifai/models/model_serving/examples/text_classification/README.md @@ -6,4 +6,7 @@ These can be used on the fly with minimal or no changes to test deploy text clas Required files to run tests locally: - * Download the [model checkpoint & sentencepiece bpe model from huggingface](https://huggingface.co/cardiffnlp/twitter-xlm-roberta-base-sentiment/tree/main) and store it under `xlm-roberta/1/twitter-xlm-roberta-base-sentiment/` + * Download the [model checkpoint](https://huggingface.co/cardiffnlp/twitter-xlm-roberta-base-sentiment/tree/main) and store it under `xlm-roberta/1/checkpoint/` + ``` + huggingface-cli download cardiffnlp/twitter-xlm-roberta-base-sentiment --local-dir xlm-roberta/1/checkpoint/ --local-dir-use-symlinks False + ``` diff --git a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/inference.py b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/inference.py index 952dc64d..43d2cf67 100644 --- a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/inference.py +++ b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/inference.py @@ -15,9 +15,11 @@ from scipy.special import softmax from transformers import AutoModelForSequenceClassification, AutoTokenizer -from clarifai.models.model_serving.models.model_types import text_classifier +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import ClassifierOutput +config = get_model_config(ModelTypes.text_classifier) + class InferenceModel: """User model inference class.""" @@ -28,28 +30,33 @@ def __init__(self) -> None: in this method so they are 
loaded only once for faster inference. """ self.base_path: Path = os.path.dirname(__file__) - self.checkpoint_path: Path = os.path.join(self.base_path, "twitter-xlm-roberta-base-sentiment") + self.checkpoint_path: Path = os.path.join(self.base_path, "checkpoint") self.model: Callable = AutoModelForSequenceClassification.from_pretrained(self.checkpoint_path) self.tokenizer: Callable = AutoTokenizer.from_pretrained(self.checkpoint_path) self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - @text_classifier - def get_predictions(self, input_data) -> ClassifierOutput: + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. Refer to the README/docs """ - encoded_input = self.tokenizer(input_data, return_tensors='pt') - output = self.model(**encoded_input) - scores = output[0][0].detach().numpy() - scores = softmax(scores) - - return ClassifierOutput(predicted_scores=scores) + outputs = [] + for inp in input_data: + encoded_input = self.tokenizer(inp, return_tensors='pt') + output = self.model(**encoded_input) + scores = output[0][0].detach().numpy() + scores = softmax(scores) + outputs.append(ClassifierOutput(predicted_scores=scores)) + + return outputs diff --git a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/model.py b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/model.py index 1cacacd2..36b54b37 100644 --- a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/model.py +++ b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -37,14 +38,13 @@ def initialize(self, args): from inference import InferenceModel self.inference_obj = InferenceModel() - self.device = "cuda:0" if "GPU" in args["model_instance_kind"] else "cpu" # Read input_name from config file self.config_msg = ModelConfig() with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -53,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in 
self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/config.pbtxt b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/config.pbtxt index 75d3ddeb..136be683 100644 --- a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/config.pbtxt +++ b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/config.pbtxt @@ -1,5 +1,5 @@ name: "xlm-roberta" -max_batch_size: 1 +max_batch_size: 2 input { name: "text" data_type: TYPE_STRING diff --git a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt index 4552c2a7..7379fba3 100644 --- a/clarifai/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt +++ b/clarifai/models/model_serving/examples/text_classification/xlm-roberta/requirements.txt @@ -1,4 +1,4 @@ -clarifai>9.5.3 # for model upload features +clarifai>9.10.5 tritonclient[all] torch==1.13.1 transformers==4.30.2 diff --git a/clarifai/models/model_serving/examples/text_embedding/README.md b/clarifai/models/model_serving/examples/text_embedding/README.md index 0482ce72..0ee8e4ec 100644 --- a/clarifai/models/model_serving/examples/text_embedding/README.md +++ b/clarifai/models/model_serving/examples/text_embedding/README.md @@ -7,3 +7,6 @@ These can be used on the fly with minimal or no changes to test deploy text embe Requirements to run tests locally: * Download/Clone the [huggingface model](https://huggingface.co/hkunlp/instructor-xl) into the **instructor-xl/1/** directory then start the triton server. + ``` + huggingface-cli download hkunlp/instructor-xl --local-dir instructor-xl/1/checkpoint/sentence_transformers/hkunlp_instructor-xl --local-dir-use-symlinks False + ``` diff --git a/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/__init__.py b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/inference.py b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/inference.py new file mode 100644 index 00000000..e9130842 --- /dev/null +++ b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/inference.py @@ -0,0 +1,63 @@ +# This file contains boilerplate code to allow users write their model +# inference code that will then interact with the Triton Inference Server +# Python backend to serve end user requests. 
+# The module name, module path, class name & get_predictions() method names MUST be maintained as is +# but other methods may be added within the class as deemed fit provided +# they are invoked within the main get_predictions() inference method +# if they play a role in any step of model inference +"""User model inference script.""" + +import os +from pathlib import Path + +# Set up env for huggingface +ROOT_PATH = os.path.join(os.path.dirname(__file__)) +PIPELINE_PATH = os.path.join(ROOT_PATH, 'checkpoint') + +os.environ["TRANSFORMERS_OFFLINE"] = "1" # noqa +os.environ['TRANSFORMERS_CACHE'] = PIPELINE_PATH # noqa +os.environ['TORCH_HOME'] = PIPELINE_PATH + +import torch # noqa +from InstructorEmbedding import INSTRUCTOR # noqa + +from clarifai.models.model_serving.model_config import ( # noqa # pylint: disable=unused-import + ModelTypes, get_model_config) + +config = get_model_config("text-embedder") + + +class InferenceModel: + """User model inference class.""" + + def __init__(self) -> None: + """ + Load inference time artifacts that are called frequently .e.g. models, tokenizers, etc. + in this method so they are loaded only once for faster inference. + """ + self.base_path: Path = os.path.dirname(__file__) + ## sample model loading code: + #self.checkpoint_path: Path = os.path.join(self.base_path, "your checkpoint filename/path") + #self.model: Callable = + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.model = INSTRUCTOR('hkunlp/instructor-xl') + + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: + """ + Main model inference method. + + Args: + ----- + input_data: A list of input data item to predict on. + Input data can be an image or text, etc depending on the model type. + + **kwargs: your inference parameters. + + Returns: + -------- + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. Refer to the README/docs + """ + batch_preds = self.model.encode(input_data, device=self.device) + + return [config.inference.return_type(each) for each in batch_preds] diff --git a/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/model.py b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/model.py new file mode 100644 index 00000000..36b54b37 --- /dev/null +++ b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/model.py @@ -0,0 +1,74 @@ +# Copyright 2023 Clarifai, Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Triton inference server Python Backend Model.""" + +import os +import sys + +try: + import triton_python_backend_utils as pb_utils +except ModuleNotFoundError: + pass +from google.protobuf import text_format +from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters + + +class TritonPythonModel: + """ + Triton Python BE Model. + """ + + def initialize(self, args): + """ + Triton server init. 
+ """ + args["model_repository"] = args["model_repository"].replace("/1/model.py", "") + sys.path.append(os.path.dirname(__file__)) + from inference import InferenceModel + + self.inference_obj = InferenceModel() + + # Read input_name from config file + self.config_msg = ModelConfig() + with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: + cfg = f.read() + text_format.Merge(cfg, self.config_msg) + self.input_names = [inp.name for inp in self.config_msg.input] + + def execute(self, requests): + """ + Serve model inference requests. + """ + responses = [] + + for request in requests: + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + + responses.append(inference_response) + + return responses diff --git a/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/test.py b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/test.py new file mode 100644 index 00000000..5c8bcd90 --- /dev/null +++ b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/1/test.py @@ -0,0 +1,64 @@ +import logging +import os +import unittest + +from clarifai.models.model_serving.models.default_test import DefaultTestInferenceModel + + +class CustomTestInferenceModel(DefaultTestInferenceModel): + """ + Run this file to test your implementation of InferenceModel in inference.py with default tests of Triton configuration and its output values based on basic predefined inputs + If you want to write custom testcase or just test output value. + Please follow these instrucitons: + 1. Name your test function with prefix "test" so that pytest can execute + 2. In order to obtain output of InferenceModel, call `self.triton_get_predictions(input_data)`. + 3. If your input is `image` and you have set custom size of it when building model repository, + call `self.preprocess(image)` to obtain correct resized input + 4. Run this test by calling + ```bash + pytest ./your_triton_folder/1/test.py + #to see std output + pytest --log-cli-level=INFO -s ./your_triton_folder/1/test.py + ``` + + ### Examples: + + test text-to-image output + ``` + def test_text_to_image_output(self): + text = "Test text" + output = self.triton_get_predictions(text) + image = output.image # uint8 np.ndarray image + #show or save + ``` + + test visual-classifier output + ``` + def test_visual_classifier(self): + image = cv2.imread("your/local/image.jpg") # Keep in mind of format of image (BGR or RGB) + output = self.triton_get_predictions(image) + scores = output.predicted_scores # np.ndarray + #process scores to get class id and its score + logger.info(result) + """ + + # Insert your inference parameters json path here + # or insert a dictionary of your_parameter_name and value, e.g dict(x=1.5, y="text", c=True) + # or Leave it as "" if you don't have it. + inference_parameters = "" + + ########### Initialization. 
Do not change it ########### + __test__ = True + + def setUp(self) -> None: + logging.info("Initializing...") + model_type = "text-embedder" # your model type + self.intitialize( + model_type, + repo_version_dir=os.path.dirname(__file__), + is_instance_kind_gpu=True, + inference_parameters=self.inference_parameters) + + ######################################################## + + +if __name__ == '__main__': + unittest.main() diff --git a/clarifai/models/model_serving/examples/text_embedding/instructor-xl/config.pbtxt b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/config.pbtxt new file mode 100644 index 00000000..ddcc5e64 --- /dev/null +++ b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/config.pbtxt @@ -0,0 +1,20 @@ +name: "instructor-xl" +max_batch_size: 1 +input { + name: "text" + data_type: TYPE_STRING + dims: 1 +} +output { + name: "embeddings" + data_type: TYPE_FP32 + dims: -1 +} +instance_group { + count: 1 + kind: KIND_GPU +} +dynamic_batching { + max_queue_delay_microseconds: 500 +} +backend: "python" diff --git a/clarifai/models/model_serving/examples/text_embedding/instructor-xl/requirements.txt b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/requirements.txt new file mode 100644 index 00000000..39122e53 --- /dev/null +++ b/clarifai/models/model_serving/examples/text_embedding/instructor-xl/requirements.txt @@ -0,0 +1,9 @@ +clarifai>9.10.5 +tritonclient[all] +torch==1.13.1 +accelerate==0.20.3 +transformers==4.30.1 +scipy==1.10.1 +einops==0.6.1 +InstructorEmbedding==1.0.1 +sentence_transformers>=2.2.0 diff --git a/clarifai/models/model_serving/examples/text_to_image/README.md b/clarifai/models/model_serving/examples/text_to_image/README.md index d3f73bbf..d3ca8e45 100644 --- a/clarifai/models/model_serving/examples/text_to_image/README.md +++ b/clarifai/models/model_serving/examples/text_to_image/README.md @@ -4,6 +4,7 @@ These can be used on the fly with minimal or no changes to test deploy text to i * ### [sd-v1.5 (Stable-Diffusion-v1.5)](./sd-v1.5/) - Requirements to run tests locally: - - * Download/Clone the [huggingface model](https://huggingface.co/runwayml/stable-diffusion-v1-5) into the **sd-v1.5/1/** directory then start the triton server. + * Download the [model checkpoint](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) and store it under `sd-v1.5/1/checkpoint` + ``` + huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir sd-v1.5/1/checkpoint --local-dir-use-symlinks False + ``` diff --git a/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/inference.py b/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/inference.py index d89cd26d..ab6dab42 100644 --- a/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/inference.py +++ b/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/inference.py @@ -14,9 +14,11 @@ import torch from diffusers import StableDiffusionPipeline -from clarifai.models.model_serving.models.model_types import text_to_image +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import ImageOutput +config = get_model_config(ModelTypes.text_to_image) + class InferenceModel: """User model inference class.""" @@ -27,14 +29,14 @@ def __init__(self) -> None: in this method so they are loaded only once for faster inference. 
""" self.base_path: Path = os.path.dirname(__file__) - self.huggingface_model_path = os.path.join(self.base_path, "stable-diffusion-v1-5") + self.huggingface_model_path = os.path.join(self.base_path, "checkpoint") self.device = "cuda" if torch.cuda.is_available() else "cpu" self.pipeline = StableDiffusionPipeline.from_pretrained( self.huggingface_model_path, torch_dtype=torch.float16) self.pipeline = self.pipeline.to(self.device) - @text_to_image - def get_predictions(self, input_data): + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs): """ Main model inference method. @@ -47,6 +49,10 @@ def get_predictions(self, input_data): -------- One of the clarifai.models.model_serving.models.output types. Refer to the README/docs """ - out_image = self.pipeline(input_data).images[0] - out_image = np.asarray(out_image) - return ImageOutput(image=out_image) + outputs = [] + for inp in input_data: + out_image = self.pipeline(inp).images[0] + out_image = np.asarray(out_image) + outputs.append(ImageOutput(image=out_image)) + + return outputs diff --git a/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/model.py b/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/model.py index 5c9a230f..36b54b37 100644 --- a/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/model.py +++ b/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -43,7 +44,7 @@ def initialize(self, args): with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -52,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt b/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt index a8671dfd..5aa055a5 100644 --- a/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt +++ b/clarifai/models/model_serving/examples/text_to_image/sd-v1.5/requirements.txt @@ -1,4 +1,4 @@ -clarifai>9.5.3 +clarifai>9.10.5 tritonclient[all] torch==1.13.1 transformers==4.30.2 diff --git 
a/clarifai/models/model_serving/examples/text_to_text/README.md b/clarifai/models/model_serving/examples/text_to_text/README.md index 054fc894..dd3bc3ad 100644 --- a/clarifai/models/model_serving/examples/text_to_text/README.md +++ b/clarifai/models/model_serving/examples/text_to_text/README.md @@ -6,5 +6,7 @@ These can be used on the fly with minimal or no changes to test deploy all model Requirements to run tests locally: - * Download/Clone the [huggingface model](https://huggingface.co/com3dian/Bart-large-paper2slides-summarizer) and store it under the **bart-summarize/1/** directory. - * Rename the downloaded folder to **bart-large-summarizer** OR change the **self.huggingface_model_path** attribute in the [inference.py script](./bart-summarize/1/inference.py) to match the folder name + * Download/Clone the [huggingface model](https://huggingface.co/com3dian/Bart-large-paper2slides-summarizer) and store it under the **bart-summarize/1/checkpoint** directory. + ``` + huggingface-cli download com3dian/Bart-large-paper2slides-summarizer --local-dir bart-summarize/1/checkpoint --local-dir-use-symlinks False --exclude *.safetensors + ``` diff --git a/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/inference.py b/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/inference.py index fb5d412e..946c8989 100644 --- a/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/inference.py +++ b/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/inference.py @@ -9,12 +9,14 @@ import os from pathlib import Path -import numpy as np + from transformers import pipeline -from clarifai.models.model_serving.models.model_types import text_to_text +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import TextOutput +config = get_model_config(ModelTypes.text_to_text) + class InferenceModel: """User model inference class.""" @@ -25,23 +27,33 @@ def __init__(self) -> None: in this method so they are loaded only once for faster inference. """ self.base_path: Path = os.path.dirname(__file__) - self.huggingface_model_path = os.path.join(self.base_path, "bart-large-summarizer") + self.huggingface_model_path = os.path.join(self.base_path, "checkpoint") self.pipeline = pipeline("summarization", model=self.huggingface_model_path) - @text_to_text - def get_predictions(self, input_data): + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ - Generates summaries of input text. + Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. 
Refer to the README/docs """ - summary = self.pipeline(input_data, max_length=50, min_length=30, do_sample=False) - generated_text = np.array([summary[0]['summary_text']], dtype=object) - return TextOutput(predicted_text=generated_text) + # convert to top_k to int + outputs = [] + top_k = int(kwargs.get("top_k", 50)) + kwargs.pop("top_k", top_k) + + summaries = self.pipeline(input_data, **kwargs) + for summary in summaries: + generated_text = summary['summary_text'] + outputs.append(TextOutput(predicted_text=generated_text)) + + return outputs diff --git a/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/model.py b/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/model.py index 5c9a230f..36b54b37 100644 --- a/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/model.py +++ b/clarifai/models/model_serving/examples/text_to_text/bart-summarize/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -43,7 +44,7 @@ def initialize(self, args): with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -52,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt b/clarifai/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt index 70d06fbc..956d2414 100644 --- a/clarifai/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt +++ b/clarifai/models/model_serving/examples/text_to_text/bart-summarize/requirements.txt @@ -1,4 +1,4 @@ -clarifai>9.5.3 +clarifai>9.10.5 tritonclient[all] torch==1.13.1 transformers==4.30.2 diff --git a/clarifai/models/model_serving/examples/visual_detection/README.md b/clarifai/models/model_serving/examples/visual_detection/README.md index 9c51587e..67287b6c 100644 --- a/clarifai/models/model_serving/examples/visual_detection/README.md +++ b/clarifai/models/model_serving/examples/visual_detection/README.md @@ -6,6 +6,10 @@ These can be used on the fly with minimal or no changes to test deploy visual de Required files (not included here due to upload size limits): - * Download the yolov5x folder from above. 
* Download the `Yolov5 repo` and the `yolov5-x checkpoint` and store them under the `1/` directory of the yolov5x folder. + ``` + cd yolov5x/1/ + git clone https://github.com/ultralytics/yolov5.git + wget -O model.pt https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5x.pt + ``` * zip and test deploy to your Clarifai app diff --git a/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/inference.py b/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/inference.py index 95788018..12339932 100644 --- a/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/inference.py +++ b/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/inference.py @@ -14,9 +14,11 @@ import numpy as np import torch -from clarifai.models.model_serving.models.model_types import visual_detector +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import VisualDetectorOutput +config = get_model_config(ModelTypes.visual_detector) + class InferenceModel: """User model inference class.""" @@ -36,37 +38,44 @@ def __init__(self) -> None: source='local') self.device = "cuda:0" if torch.cuda.is_available() else "cpu" - @visual_detector - def get_predictions(self, input_data) -> VisualDetectorOutput: + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. 
Refer to the README/docs """ - preds = self.model(input_data) max_bbox_count = 300 # max allowed detected bounding boxes per image - preds = preds.xyxy[0].cpu().numpy() - labels = [[pred[5]] for pred in preds] - scores = [[pred[4]] for pred in preds] - h, w, _ = input_data.shape # input image shape - bboxes = [[x[1] / h, x[0] / w, x[3] / h, x[2] / w] - for x in preds] # normalize the bboxes to [0,1] - if len(bboxes) != 0: - bboxes = np.concatenate((bboxes, np.zeros((max_bbox_count - len(bboxes), 4)))) - scores = np.concatenate((scores, np.zeros((max_bbox_count - len(scores), 1)))) - labels = np.concatenate((labels, np.zeros((max_bbox_count - len(labels), 1), - dtype=np.int32))) - else: - bboxes = np.zeros((max_bbox_count, 4), dtype=np.float32) - scores = np.zeros((max_bbox_count, 1), dtype=np.float32) - labels = np.zeros((max_bbox_count, 1), dtype=np.int32) + outputs = [] + predictions = self.model(input_data) + for inp_data, preds in zip(input_data, predictions.xyxy): + preds = preds.cpu().numpy() + labels = [[pred[5]] for pred in preds] + scores = [[pred[4]] for pred in preds] + h, w, _ = inp_data.shape # input image shape + bboxes = [[x[1] / h, x[0] / w, x[3] / h, x[2] / w] + for x in preds] # normalize the bboxes to [0,1] + if len(bboxes) != 0: + bboxes = np.concatenate((bboxes, np.zeros((max_bbox_count - len(bboxes), 4)))) + scores = np.concatenate((scores, np.zeros((max_bbox_count - len(scores), 1)))) + labels = np.concatenate((labels, np.zeros( + (max_bbox_count - len(labels), 1), dtype=np.int32))) + else: + bboxes = np.zeros((max_bbox_count, 4), dtype=np.float32) + scores = np.zeros((max_bbox_count, 1), dtype=np.float32) + labels = np.zeros((max_bbox_count, 1), dtype=np.int32) + + outputs.append( + VisualDetectorOutput( + predicted_bboxes=bboxes, predicted_labels=labels, predicted_scores=scores)) - return VisualDetectorOutput( - predicted_bboxes=bboxes, predicted_labels=labels, predicted_scores=scores) + return outputs diff --git a/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/model.py b/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/model.py index 1cacacd2..36b54b37 100644 --- a/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/model.py +++ b/clarifai/models/model_serving/examples/visual_detection/yolov5x/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -37,14 +38,13 @@ def initialize(self, args): from inference import InferenceModel self.inference_obj = InferenceModel() - self.device = "cuda:0" if "GPU" in args["model_instance_kind"] else "cpu" # Read input_name from config file self.config_msg = ModelConfig() with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -53,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = 
pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/visual_detection/yolov5x/config.pbtxt b/clarifai/models/model_serving/examples/visual_detection/yolov5x/config.pbtxt index aec97104..972cac27 100644 --- a/clarifai/models/model_serving/examples/visual_detection/yolov5x/config.pbtxt +++ b/clarifai/models/model_serving/examples/visual_detection/yolov5x/config.pbtxt @@ -1,4 +1,4 @@ -name: "yolov5_test" +name: "yolov5x" max_batch_size: 1 input { name: "image" diff --git a/clarifai/models/model_serving/examples/visual_detection/yolov5x/requirements.txt b/clarifai/models/model_serving/examples/visual_detection/yolov5x/requirements.txt index 3d9e50ad..c3719ddd 100644 --- a/clarifai/models/model_serving/examples/visual_detection/yolov5x/requirements.txt +++ b/clarifai/models/model_serving/examples/visual_detection/yolov5x/requirements.txt @@ -1,6 +1,6 @@ # YOLOv5 requirements tritonclient[all] -clarifai>9.5.3 # for model upload features +clarifai>9.10.5 matplotlib>=3.2.2 opencv-python>=4.1.1 Pillow>=7.1.2 diff --git a/clarifai/models/model_serving/examples/visual_embedding/README.md b/clarifai/models/model_serving/examples/visual_embedding/README.md index 5e4a3e5b..570908ac 100644 --- a/clarifai/models/model_serving/examples/visual_embedding/README.md +++ b/clarifai/models/model_serving/examples/visual_embedding/README.md @@ -6,4 +6,7 @@ These can be used on the fly with minimal or no changes to test deploy visual em Requirements to run tests locally: - * Download/Clone the [huggingface model](https://huggingface.co/google/vit-base-patch16-224) into the **vit-base/1/** directory then start the triton server. + * Download the [model checkpoint & sentencepiece bpe model from huggingface](https://huggingface.co/google/vit-base-patch16-224/tree/main) and store it under `vit-base/1/checkpoint` + ``` + huggingface-cli download google/vit-base-patch16-224 --local-dir vit-base/1/checkpoint --local-dir-use-symlinks False --exclude *.msgpack *.h5 *.safetensors + ``` diff --git a/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/inference.py b/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/inference.py index 29353783..e636c5aa 100644 --- a/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/inference.py +++ b/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/inference.py @@ -13,9 +13,11 @@ import torch from transformers import AutoModel, ViTImageProcessor -from clarifai.models.model_serving.models.model_types import visual_embedder +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import EmbeddingOutput +config = get_model_config(ModelTypes.visual_embedder) + class InferenceModel: """User model inference class.""" @@ -26,11 +28,11 @@ def __init__(self) -> None: in this method so they are loaded only once for faster inference. 
""" self.base_path: Path = os.path.dirname(__file__) - self.huggingface_model_path = os.path.join(self.base_path, "vit-base-patch16-224") + self.huggingface_model_path = os.path.join(self.base_path, "checkpoint") self.processor = ViTImageProcessor.from_pretrained(self.huggingface_model_path) self.model = AutoModel.from_pretrained(self.huggingface_model_path) - @visual_embedder + @config.inference.wrap_func def get_predictions(self, input_data): """ Main model inference method. @@ -44,8 +46,11 @@ def get_predictions(self, input_data): -------- One of the clarifai.models.model_serving.models.output types. Refer to the README/docs """ + outputs = [] inputs = self.processor(images=input_data, return_tensors="pt") with torch.no_grad(): - embedding_vector = self.model(**inputs).last_hidden_state[:, 0].cpu().numpy() + embedding_vectors = self.model(**inputs).last_hidden_state[:, 0].cpu().numpy() + for embedding_vector in embedding_vectors: + outputs.append(EmbeddingOutput(embedding_vector=embedding_vector)) - return EmbeddingOutput(embedding_vector=embedding_vector[0]) + return outputs diff --git a/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/model.py b/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/model.py index 5c9a230f..36b54b37 100644 --- a/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/model.py +++ b/clarifai/models/model_serving/examples/visual_embedding/vit-base/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -43,7 +44,7 @@ def initialize(self, args): with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -52,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/visual_embedding/vit-base/requirements.txt b/clarifai/models/model_serving/examples/visual_embedding/vit-base/requirements.txt index eab1ee71..5f1e82a0 100644 --- a/clarifai/models/model_serving/examples/visual_embedding/vit-base/requirements.txt +++ b/clarifai/models/model_serving/examples/visual_embedding/vit-base/requirements.txt @@ -1,4 +1,4 @@ -clarifai>9.5.3 +clarifai>9.10.5 tritonclient[all] torch==1.13.1 transformers==4.30.2 diff --git 
a/clarifai/models/model_serving/examples/visual_segmentation/README.md b/clarifai/models/model_serving/examples/visual_segmentation/README.md index 2fd49f53..b6eb2bcc 100644 --- a/clarifai/models/model_serving/examples/visual_segmentation/README.md +++ b/clarifai/models/model_serving/examples/visual_segmentation/README.md @@ -6,4 +6,7 @@ These can be used on the fly with minimal or no changes to test deploy visual se Requirements to run tests locally: - * Download/Clone the [huggingface model](https://huggingface.co/mattmdjaga/segformer_b2_clothes) into the **segformer-b2/1/** directory then start the triton server. + * Download/Clone the [huggingface model](https://huggingface.co/mattmdjaga/segformer_b2_clothes) into the **segformer-b2/1/checkpoint** directory then start the triton server. + ``` + huggingface-cli download mattmdjaga/segformer_b2_clothes --local-dir segformer-b2/1/checkpoint --local-dir-use-symlinks False --exclude *.safetensors optimizer.pt + ``` diff --git a/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/inference.py b/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/inference.py index 004cc666..f5580ce5 100644 --- a/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/inference.py +++ b/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/inference.py @@ -13,9 +13,11 @@ import torch from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor -from clarifai.models.model_serving.models.model_types import visual_segmenter +from clarifai.models.model_serving.model_config import ModelTypes, get_model_config from clarifai.models.model_serving.models.output import MasksOutput +config = get_model_config(ModelTypes.visual_segmenter) + class InferenceModel: """User model inference class.""" @@ -26,30 +28,35 @@ def __init__(self) -> None: in this method so they are loaded only once for faster inference. """ self.base_path: Path = os.path.dirname(__file__) - self.huggingface_model_path = os.path.join(self.base_path, "segformer_b2_clothes") + self.huggingface_model_path = os.path.join(self.base_path, "checkpoint") #self.labels_path = os.path.join(Path(self.base_path).parents[0], "labels.txt") self.processor = SegformerImageProcessor.from_pretrained(self.huggingface_model_path) self.model = AutoModelForSemanticSegmentation.from_pretrained(self.huggingface_model_path) - @visual_segmenter - def get_predictions(self, input_data): + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. 
Refer to the README/docs """ + outputs = [] + inputs = self.processor(images=input_data, return_tensors="pt") with torch.no_grad(): output = self.model(**inputs) - logits = output.logits.cpu() - mask = logits.argmax(dim=1)[0].numpy() + for logit in logits: + mask = logit.argmax(dim=0).numpy() + outputs.append(MasksOutput(predicted_mask=mask)) - return MasksOutput(predicted_mask=mask) + return outputs diff --git a/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/model.py b/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/model.py index 5c9a230f..36b54b37 100644 --- a/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/model.py +++ b/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/1/model.py @@ -21,6 +21,7 @@ pass from google.protobuf import text_format from tritonclient.grpc.model_config_pb2 import ModelConfig +from clarifai.models.model_serving.model_config.inference_parameter import parse_req_parameters class TritonPythonModel: @@ -43,7 +44,7 @@ def initialize(self, args): with open(os.path.join(args["model_repository"], "config.pbtxt"), "r") as f: cfg = f.read() text_format.Merge(cfg, self.config_msg) - self.input_name = [inp.name for inp in self.config_msg.input][0] + self.input_names = [inp.name for inp in self.config_msg.input] def execute(self, requests): """ @@ -52,9 +53,22 @@ def execute(self, requests): responses = [] for request in requests: - in_batch = pb_utils.get_input_tensor_by_name(request, self.input_name) - in_batch = in_batch.as_numpy() - inference_response = self.inference_obj.get_predictions(in_batch) + parameters = request.parameters() + parameters = parse_req_parameters(parameters) if parameters else {} + + if len(self.input_names) == 1: + in_batch = pb_utils.get_input_tensor_by_name(request, self.input_names[0]) + in_batch = in_batch.as_numpy() + inference_response = self.inference_obj.get_predictions(in_batch, **parameters) + else: + multi_in_batch_dict = {} + for input_name in self.input_names: + in_batch = pb_utils.get_input_tensor_by_name(request, input_name) + in_batch = in_batch.as_numpy() if in_batch is not None else [] + multi_in_batch_dict.update({input_name: in_batch}) + + inference_response = self.inference_obj.get_predictions(multi_in_batch_dict, **parameters) + responses.append(inference_response) return responses diff --git a/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt b/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt index 25f5d9ab..2f45cf6e 100644 --- a/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt +++ b/clarifai/models/model_serving/examples/visual_segmentation/segformer-b2/requirements.txt @@ -1,5 +1,5 @@ torch==1.13.1 -clarifai>9.5.3 +clarifai>9.10.5 tritonclient[all] transformers==4.30.2 Pillow==10.0.0 diff --git a/clarifai/models/model_serving/examples/vllm/example/1/inference.py b/clarifai/models/model_serving/examples/vllm/example/1/inference.py index e27cc6d4..e5506e83 100644 --- a/clarifai/models/model_serving/examples/vllm/example/1/inference.py +++ b/clarifai/models/model_serving/examples/vllm/example/1/inference.py @@ -10,7 +10,6 @@ import os from pathlib import Path -import numpy as np from vllm import LLM, SamplingParams from clarifai.models.model_serving.model_config import ModelTypes, get_model_config @@ -37,7 +36,7 @@ def __init__(self) -> None: ) @config.inference.wrap_func - def get_predictions(self, 
input_data, **kwargs): + def get_predictions(self, input_data: list, **kwargs): """ Main model inference method. @@ -48,10 +47,10 @@ def get_predictions(self, input_data, **kwargs): Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. Refer to the README/docs """ sampling_params = SamplingParams(**kwargs) - output = self.model.generate(input_data, sampling_params) - generated_text = np.asarray(output[0].outputs[0].text, dtype=object) + preds = self.model.generate(input_data, sampling_params) + outputs = [config.inference.return_type(each.outputs[0].text) for each in preds] - return config.inference.return_type(generated_text) + return outputs From fd022747f5c01ab9a3d7330cfff3eb1fe2769f22 Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Tue, 28 Nov 2023 12:41:09 +0700 Subject: [PATCH 5/8] update doc --- clarifai/models/model_serving/README.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/clarifai/models/model_serving/README.md b/clarifai/models/model_serving/README.md index f3e38955..a23e219e 100644 --- a/clarifai/models/model_serving/README.md +++ b/clarifai/models/model_serving/README.md @@ -81,7 +81,6 @@ A generated triton model repository looks as illustrated in the directory tree a | `config.pbtxt` | Contains the triton model configuration used by the triton inference server to guide inference requests processing. | | `requirements.txt` | Contains dependencies needed by a user model to successfully make predictions.| | `labels.txt` | Contains labels listed one per line, a model is trained to predict. The order of labels should match the model predicted class indexes. | -| `triton_conda.yaml` | Contains dependencies available in pre-configured execution environment. | | `1/inference.py` | The inference script where users write their inference code. | | `1/model.py` | The triton python backend model file run to serve inference requests. | | `1/test.py` | Contains some predefined tests in order to test inference implementation and dependencies locally. | @@ -97,7 +96,11 @@ This script is composed of a single class that contains a default init method an import os from pathlib import Path -from typing import Callable + +from clarifai.models.model_serving.model_config import (ModelTypes, get_model_config) + +config = get_model_config("clarifai-model-type") # Input your model type + class InferenceModel: """User model inference class.""" @@ -112,29 +115,32 @@ class InferenceModel: #self.checkpoint_path: Path = os.path.join(self.base_path, "your checkpoint filename/path") #self.model: Callable = - #Add relevant model type decorator to the method below (see docs/model_types for ref.) - def get_predictions(self, input_data, **kwargs): + @config.inference.wrap_func + def get_predictions(self, input_data: list, **kwargs) -> list: """ Main model inference method. Args: ----- - input_data: A single input data item to predict on. + input_data: A list of input data item to predict on. Input data can be an image or text, etc depending on the model type. + **kwargs: your inference parameters. + Returns: -------- - One of the clarifai.models.model_serving.models.output types. Refer to the README/docs + List of one of the `clarifai.models.model_serving.models.output types` or `config.inference.return_type(your_output)`. 
Refer to the README/docs """ + # Delete/Comment out line below and add your inference code raise NotImplementedError() ``` - `__init__()` used for one-time loading of inference time artifacts such as models, tokenizers, etc that are frequently called during inference to improve inference speed. -- `get_predictions()` takes an input data item whose type depends on the task the model solves, & returns predictions for an input data item. +- `get_predictions()` takes a list of input data items whose type depends on the task the model solves, & returns list of predictions. -`get_predictions()` should return any of the output types defined under [output](docs/output.md) and the predict function MUST be decorated with a task corresponding [model type decorator](docs/model_types.md). The model type decorators are responsible for passing input request batches for prediction and formatting the resultant predictions into triton inference responses. +`get_predictions()` should return a list of any of the output types defined under [output](docs/output.md) and the predict function MUST be decorated with a task corresponding [@config.inference.wrap_func](docs/model_types.md). The model type decorators are responsible for passing input request batches for prediction and formatting the resultant predictions into triton inference responses. Additional methods can be added to this script's `Infer` class by the user as deemed necessary for their model inference provided they are invoked inside `get_predictions()` if used at inference time. From db8c11b93a3224dea91022ce4fd0f85c4e1df651 Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Tue, 28 Nov 2023 12:42:47 +0700 Subject: [PATCH 6/8] update clarifai version --- clarifai/models/model_serving/pb_model_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clarifai/models/model_serving/pb_model_repository.py b/clarifai/models/model_serving/pb_model_repository.py index e42cd6e8..ef4530d4 100644 --- a/clarifai/models/model_serving/pb_model_repository.py +++ b/clarifai/models/model_serving/pb_model_repository.py @@ -82,7 +82,7 @@ def build_repository(self, repository_dir: Path = os.curdir): continue # gen requirements with open(os.path.join(repository_path, "requirements.txt"), "w") as f: - f.write("clarifai>9.10.3\ntritonclient[all]") # for model upload utils + f.write("clarifai>9.10.4\ntritonclient[all]") # for model upload utils if not os.path.isdir(model_version_path): os.mkdir(model_version_path) From d329119ec59fd7658e1a972e518162b959850d07 Mon Sep 17 00:00:00 2001 From: phatvo9 Date: Tue, 28 Nov 2023 12:47:56 +0700 Subject: [PATCH 7/8] enable infer param description from_kwargs --- .../models/model_serving/model_config/inference_parameter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clarifai/models/model_serving/model_config/inference_parameter.py b/clarifai/models/model_serving/model_config/inference_parameter.py index e3b495e1..b3e25199 100644 --- a/clarifai/models/model_serving/model_config/inference_parameter.py +++ b/clarifai/models/model_serving/model_config/inference_parameter.py @@ -59,7 +59,7 @@ def from_kwargs(cls, **kwargs): _type = InferParamType.NUMBER else: raise TypeError(f"Unsupported type {type(v)} of argument {k}, support {InferParamType}") - param = InferParam(path=k, field_type=_type, default_value=v, description="") + param = InferParam(path=k, field_type=_type, default_value=v, description=k) params.append(param) return cls(params=params) From 16128222ea320c40860ccea5b42657afaa93de4a Mon 
Sep 17 00:00:00 2001
From: phatvo9
Date: Tue, 28 Nov 2023 14:30:32 +0700
Subject: [PATCH 8/8] update infer param doc

---
 clarifai/models/model_serving/docs/inference_parameters.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/clarifai/models/model_serving/docs/inference_parameters.md b/clarifai/models/model_serving/docs/inference_parameters.md
index 02adbb41..01199e55 100644
--- a/clarifai/models/model_serving/docs/inference_parameters.md
+++ b/clarifai/models/model_serving/docs/inference_parameters.md
@@ -83,8 +83,8 @@ ipm.export("your_file.json")
 ```
 ##### 2.2. Shorten
-`NOTE`: in this way `description` field will be set as empty aka "".
-*You need to modify* `description` in order to be able to upload the settings to Clarifai.
+`NOTE`: in this way the `description` field of each parameter is set to its key in the dictionary,
+so you no longer need to modify `description` before uploading the settings to Clarifai.
 
 `NOTE`: in this way `ENCRYPTED_STRING` type must be defined with "_" prefix
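
For reference, here is a minimal sketch of the "Shorten" workflow that the last two patches touch. It is not part of the patch series: the class name `InferParamManager`, the example parameter names, and the output filename are assumptions for illustration; only `from_kwargs()`, `export()`, `InferParamType`, the default `description=k` behavior, and the `_` prefix rule for `ENCRYPTED_STRING` appear in the diffs above.

```python
# Hypothetical usage sketch of the "Shorten" way to declare inference parameters.
# Assumption: the manager class exported by inference_parameter.py is InferParamManager;
# parameter names and the JSON filename below are illustrative only.
from clarifai.models.model_serving.model_config.inference_parameter import InferParamManager

# Each keyword becomes one inference parameter. With PATCH 7/8 applied, a
# parameter's description defaults to its key, so the exported settings can be
# uploaded to Clarifai without filling in every description by hand.
ipm = InferParamManager.from_kwargs(
    temperature=0.7,             # number
    max_length=50,               # number
    prompt_prefix="Summarize:",  # string
    do_sample=True,              # boolean
    _api_key="YOUR_SECRET",      # "_" prefix -> ENCRYPTED_STRING
)

# Write the parameters to a JSON file; its path can then be referenced from the
# `inference_parameters` field in each example's test.py.
ipm.export("inference_parameters.json")
```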