diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e0d9702..6c368034 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,4 +42,4 @@ repos: rev: 4.0.1 hooks: - id: flake8 - args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401'] + args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401,mii/grpc_related/proto/modelresponse_pb2.py:F821'] diff --git a/examples/local/fill-mask-example.py b/examples/local/fill-mask-example.py index 27f94aca..01b58e30 100644 --- a/examples/local/fill-mask-example.py +++ b/examples/local/fill-mask-example.py @@ -1,9 +1,19 @@ import mii +import argparse -# roberta -name = "roberta-base" -name = "bert-base-cased" +parser = argparse.ArgumentParser() +parser.add_argument("-q", "--query", action="store_true", help="query") +args = parser.parse_args() -print(f"Deploying {name}...") +name = "bert-base-uncased" +mask = "[MASK]" -mii.deploy(task='fill-mask', model=name, deployment_name=name + "_deployment") +if not args.query: + print(f"Deploying {name}...") + mii.deploy(task='fill-mask', model=name, deployment_name=name + "_deployment") +else: + print(f"Querying {name}...") + generator = mii.mii_query_handle(name + "_deployment") + result = generator.query({'query': f"Hello I'm a {mask} model."}) + print(result.response) + print("time_taken:", result.time_taken) diff --git a/examples/local/fill-mask-query-example.py b/examples/local/fill-mask-query-example.py deleted file mode 100644 index 8b36af1f..00000000 --- a/examples/local/fill-mask-query-example.py +++ /dev/null @@ -1,17 +0,0 @@ -import mii - -# roberta -name = "roberta-base" -name = "roberta-large" -mask = "" -# bert -name = "bert-base-cased" -mask = "[MASK]" -print(f"Querying {name}...") - -generator = mii.mii_query_handle(name + "_deployment") - -result = generator.query({'query': "Hello I'm a " + mask + " model."}) -print(result.response) -print("time_taken:", result.time_taken) -print("model_time_taken:", result.model_time_taken) diff --git a/examples/local/txt2img-example.py b/examples/local/txt2img-example.py new file mode 100644 index 00000000..24f90dec --- /dev/null +++ b/examples/local/txt2img-example.py @@ -0,0 +1,39 @@ +import os +import mii +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("-q", "--query", action="store_true", help="query") +args = parser.parse_args() + +if not args.query: + mii_configs = { + "tensor_parallel": + 1, + "dtype": + "fp16", + "hf_auth_token": + os.environ.get("HF_AUTH_TOKEN", + "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"), + "port_number": + 50050 + } + mii.deploy(task='text-to-image', + model="CompVis/stable-diffusion-v1-4", + deployment_name="sd_deploy", + mii_config=mii_configs) + print( + "\nText to image model deployment complete! 
To use this deployment, run the following command: python txt2img-example.py --query\n" + ) +else: + generator = mii.mii_query_handle("sd_deploy") + result = generator.query({ + 'query': + ["a panda in space with a rainbow", + "a soda can on top a snowy mountain"] + }) + from PIL import Image + for idx, img_bytes in enumerate(result.images): + size = (result.size_w, result.size_h) + img = Image.frombytes(result.mode, size, img_bytes) + img.save(f"test-{idx}.png") diff --git a/mii/config.py b/mii/config.py index b7144ef4..56812959 100644 --- a/mii/config.py +++ b/mii/config.py @@ -11,6 +11,7 @@ class MIIConfig(BaseModel): checkpoint_dict: Union[dict, None] = None deploy_rank: Union[int, List[int]] = -1 torch_dist_port: int = 29500 + hf_auth_token: str = None replace_with_kernel_inject: bool = True profile_model_time: bool = False diff --git a/mii/constants.py b/mii/constants.py index 1248be49..1c990d46 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -17,6 +17,7 @@ class Tasks(enum.Enum): FILL_MASK = 4 TOKEN_CLASSIFICATION = 5 CONVERSATIONAL = 6 + TEXT2IMG = 7 TEXT_GENERATION_NAME = 'text-generation' @@ -25,22 +26,26 @@ class Tasks(enum.Enum): FILL_MASK_NAME = 'fill-mask' TOKEN_CLASSIFICATION_NAME = 'token-classification' CONVERSATIONAL_NAME = 'conversational' +TEXT2IMG_NAME = "text-to-image" class ModelProvider(enum.Enum): HUGGING_FACE = 1 ELEUTHER_AI = 2 HUGGING_FACE_LLM = 3 + DIFFUSERS = 4 MODEL_PROVIDER_NAME_HF = "hugging-face" MODEL_PROVIDER_NAME_EA = "eleuther-ai" MODEL_PROVIDER_NAME_HF_LLM = "hugging-face-llm" +MODEL_PROVIDER_NAME_DIFFUSERS = "diffusers" MODEL_PROVIDER_MAP = { MODEL_PROVIDER_NAME_HF: ModelProvider.HUGGING_FACE, MODEL_PROVIDER_NAME_EA: ModelProvider.ELEUTHER_AI, MODEL_PROVIDER_NAME_HF_LLM: ModelProvider.HUGGING_FACE_LLM, + MODEL_PROVIDER_NAME_DIFFUSERS: ModelProvider.DIFFUSERS } SUPPORTED_MODEL_TYPES = { @@ -52,6 +57,7 @@ class ModelProvider(enum.Enum): 'opt': ModelProvider.HUGGING_FACE, 'gpt-neox': ModelProvider.ELEUTHER_AI, 'bloom': ModelProvider.HUGGING_FACE_LLM, + 'stable-diffusion': ModelProvider.DIFFUSERS } SUPPORTED_TASKS = [ @@ -60,7 +66,8 @@ class ModelProvider(enum.Enum): QUESTION_ANSWERING_NAME, FILL_MASK_NAME, TOKEN_CLASSIFICATION_NAME, - CONVERSATIONAL_NAME + CONVERSATIONAL_NAME, + TEXT2IMG_NAME ] REQUIRED_KEYS_PER_TASK = { @@ -74,7 +81,8 @@ class ModelProvider(enum.Enum): ['text', 'conversation_id', 'past_user_inputs', - 'generated_responses'] + 'generated_responses'], + TEXT2IMG_NAME: ["query"] } MODEL_NAME_KEY = 'model_name' @@ -98,3 +106,5 @@ class ModelProvider(enum.Enum): MII_DEBUG_BRANCH_DEFAULT = "main" MII_MODEL_PATH_DEFAULT = "/tmp/mii_models" + +GRPC_MAX_MSG_SIZE = 2**30 # 1GB diff --git a/mii/deployment.py b/mii/deployment.py index a83955d0..8a4f19fb 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -81,7 +81,14 @@ def deploy(task, if enable_deepspeed: mii.utils.check_if_task_and_model_is_supported(task, model) - logger.info(f"*************DeepSpeed Optimizations: {enable_deepspeed}*************") + if enable_deepspeed: + logger.info( + f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" + ) + else: + logger.info( + f"************* DeepSpeed Optimizations not enabled. 
Please use enable_deepspeed to get better performance *************" + ) # In local deployments use default path if no model path set if model_path is None and deployment_type == DeploymentType.LOCAL: diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 8e7d86bd..08dff915 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -10,7 +10,9 @@ import sys import time +from torch import autocast from transformers import Conversation +from mii.constants import GRPC_MAX_MSG_SIZE class ModelResponse(modelresponse_pb2_grpc.ModelResponseServicer): @@ -66,6 +68,36 @@ def GeneratorReply(self, request, context): model_time_taken=model_time) return val + def Txt2ImgReply(self, request, context): + query_kwargs = self._unpack_proto_query_kwargs(request.query_kwargs) + + # unpack grpc list into py-list + request = [r for r in request.request] + + start = time.time() + with autocast("cuda"): + response = self.inference_pipeline(request, **query_kwargs) + end = time.time() + + images_bytes = [] + nsfw_content_detected = [] + response_count = len(response.images) + for i in range(response_count): + img = response.images[i] + img_bytes = img.tobytes() + images_bytes.append(img_bytes) + nsfw_content_detected.append(response.nsfw_content_detected[i]) + img_mode = response.images[0].mode + img_size_w, img_size_h = response.images[0].size + + val = modelresponse_pb2.ImageReply(images=images_bytes, + nsfw_content_detected=nsfw_content_detected, + mode=img_mode, + size_w=img_size_w, + size_h=img_size_h, + time_taken=end - start) + return val + def ClassificationReply(self, request, context): query_kwargs = self._unpack_proto_query_kwargs(request.query_kwargs) start = time.time() @@ -131,7 +163,11 @@ def ConversationalReply(self, request, context): def serve(inference_pipeline, port): - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10), + options=[('grpc.max_send_message_length', + GRPC_MAX_MSG_SIZE), + ('grpc.max_receive_message_length', + GRPC_MAX_MSG_SIZE)]) modelresponse_pb2_grpc.add_ModelResponseServicer_to_server( ModelResponse(inference_pipeline), server) diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index 81107b7d..b941fc17 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -21,15 +21,14 @@ option objc_class_prefix = "HLW";*/ package modelresponse; -// The greeting service definition. 
service ModelResponse { - // Sends a greeting rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {} rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {} rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {} rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {} rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {} rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {} + rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {} } message Value { @@ -84,3 +83,12 @@ message ConversationReply { float time_taken = 4; float model_time_taken = 5; } + +message ImageReply { + repeated bytes images = 1; + repeated bool nsfw_content_detected = 2; + string mode = 3; + int64 size_w = 4; + int64 size_h = 5; + float time_taken = 6; +} diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 9d5aefb9..f479dc98 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,175 +1,20 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 
\x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\x32\xba\x04\n\rModelResponse\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 
\x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\x8a\x05\n\rModelResponse\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' ) -_VALUE = DESCRIPTOR.message_types_by_name['Value'] -_SINGLESTRINGREQUEST = DESCRIPTOR.message_types_by_name['SingleStringRequest'] -_SINGLESTRINGREQUEST_QUERYKWARGSENTRY = _SINGLESTRINGREQUEST.nested_types_by_name[ - 'QueryKwargsEntry'] -_MULTISTRINGREQUEST = DESCRIPTOR.message_types_by_name['MultiStringRequest'] -_MULTISTRINGREQUEST_QUERYKWARGSENTRY = _MULTISTRINGREQUEST.nested_types_by_name[ - 'QueryKwargsEntry'] -_SINGLESTRINGREPLY = DESCRIPTOR.message_types_by_name['SingleStringReply'] -_MULTISTRINGREPLY = DESCRIPTOR.message_types_by_name['MultiStringReply'] -_QAREQUEST = DESCRIPTOR.message_types_by_name['QARequest'] -_QAREQUEST_QUERYKWARGSENTRY = _QAREQUEST.nested_types_by_name['QueryKwargsEntry'] -_CONVERSATIONREQUEST = DESCRIPTOR.message_types_by_name['ConversationRequest'] -_CONVERSATIONREQUEST_QUERYKWARGSENTRY = _CONVERSATIONREQUEST.nested_types_by_name[ - 'QueryKwargsEntry'] -_CONVERSATIONREPLY = DESCRIPTOR.message_types_by_name['ConversationReply'] -Value = _reflection.GeneratedProtocolMessageType( - 'Value', - (_message.Message, - ), - { - 'DESCRIPTOR': _VALUE, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.Value) - }) -_sym_db.RegisterMessage(Value) - -SingleStringRequest = _reflection.GeneratedProtocolMessageType( - 'SingleStringRequest', - (_message.Message, - ), - { - 'QueryKwargsEntry': - _reflection.GeneratedProtocolMessageType( - 'QueryKwargsEntry', - (_message.Message, - ), - { - 'DESCRIPTOR': _SINGLESTRINGREQUEST_QUERYKWARGSENTRY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.SingleStringRequest.QueryKwargsEntry) - }), - 'DESCRIPTOR': - _SINGLESTRINGREQUEST, - '__module__': - 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.SingleStringRequest) - }) -_sym_db.RegisterMessage(SingleStringRequest) -_sym_db.RegisterMessage(SingleStringRequest.QueryKwargsEntry) - -MultiStringRequest = _reflection.GeneratedProtocolMessageType( - 'MultiStringRequest', - (_message.Message, - ), - { - 'QueryKwargsEntry': - _reflection.GeneratedProtocolMessageType( - 'QueryKwargsEntry', - (_message.Message, - ), - { - 'DESCRIPTOR': _MULTISTRINGREQUEST_QUERYKWARGSENTRY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.MultiStringRequest.QueryKwargsEntry) - }), - 'DESCRIPTOR': - _MULTISTRINGREQUEST, - '__module__': - 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.MultiStringRequest) - }) 
-_sym_db.RegisterMessage(MultiStringRequest) -_sym_db.RegisterMessage(MultiStringRequest.QueryKwargsEntry) - -SingleStringReply = _reflection.GeneratedProtocolMessageType( - 'SingleStringReply', - (_message.Message, - ), - { - 'DESCRIPTOR': _SINGLESTRINGREPLY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.SingleStringReply) - }) -_sym_db.RegisterMessage(SingleStringReply) - -MultiStringReply = _reflection.GeneratedProtocolMessageType( - 'MultiStringReply', - (_message.Message, - ), - { - 'DESCRIPTOR': _MULTISTRINGREPLY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.MultiStringReply) - }) -_sym_db.RegisterMessage(MultiStringReply) - -QARequest = _reflection.GeneratedProtocolMessageType( - 'QARequest', - (_message.Message, - ), - { - 'QueryKwargsEntry': - _reflection.GeneratedProtocolMessageType( - 'QueryKwargsEntry', - (_message.Message, - ), - { - 'DESCRIPTOR': _QAREQUEST_QUERYKWARGSENTRY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.QARequest.QueryKwargsEntry) - }), - 'DESCRIPTOR': - _QAREQUEST, - '__module__': - 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.QARequest) - }) -_sym_db.RegisterMessage(QARequest) -_sym_db.RegisterMessage(QARequest.QueryKwargsEntry) - -ConversationRequest = _reflection.GeneratedProtocolMessageType( - 'ConversationRequest', - (_message.Message, - ), - { - 'QueryKwargsEntry': - _reflection.GeneratedProtocolMessageType( - 'QueryKwargsEntry', - (_message.Message, - ), - { - 'DESCRIPTOR': _CONVERSATIONREQUEST_QUERYKWARGSENTRY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.ConversationRequest.QueryKwargsEntry) - }), - 'DESCRIPTOR': - _CONVERSATIONREQUEST, - '__module__': - 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.ConversationRequest) - }) -_sym_db.RegisterMessage(ConversationRequest) -_sym_db.RegisterMessage(ConversationRequest.QueryKwargsEntry) - -ConversationReply = _reflection.GeneratedProtocolMessageType( - 'ConversationReply', - (_message.Message, - ), - { - 'DESCRIPTOR': _CONVERSATIONREPLY, - '__module__': 'modelresponse_pb2' - # @@protoc_insertion_point(class_scope:modelresponse.ConversationReply) - }) -_sym_db.RegisterMessage(ConversationReply) - -_MODELRESPONSE = DESCRIPTOR.services_by_name['ModelResponse'] +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -205,6 +50,8 @@ _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_end = 323 _CONVERSATIONREPLY._serialized_start = 1163 _CONVERSATIONREPLY._serialized_end = 1308 - _MODELRESPONSE._serialized_start = 1311 - _MODELRESPONSE._serialized_end = 1881 + _IMAGEREPLY._serialized_start = 1310 + _IMAGEREPLY._serialized_end = 1435 + _MODELRESPONSE._serialized_start = 1438 + _MODELRESPONSE._serialized_end = 2088 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 808a0908..71f5b48f 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -6,8 +6,7 @@ class ModelResponseStub(object): - """The greeting service definition. - """ + """Missing associated documentation comment in .proto file.""" def __init__(self, channel): """Constructor. 
@@ -44,14 +43,17 @@ def __init__(self, channel): request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, response_deserializer=modelresponse__pb2.ConversationReply.FromString, ) + self.Txt2ImgReply = channel.unary_unary( + '/modelresponse.ModelResponse/Txt2ImgReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ImageReply.FromString, + ) class ModelResponseServicer(object): - """The greeting service definition. - """ + """Missing associated documentation comment in .proto file.""" def GeneratorReply(self, request, context): - """Sends a greeting - """ + """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') @@ -86,6 +88,12 @@ def ConversationalReply(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def Txt2ImgReply(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { @@ -125,6 +133,12 @@ def add_ModelResponseServicer_to_server(servicer, server): request_deserializer=modelresponse__pb2.ConversationRequest.FromString, response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, ), + 'Txt2ImgReply': + grpc.unary_unary_rpc_method_handler( + servicer.Txt2ImgReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', rpc_method_handlers) @@ -133,8 +147,7 @@ def add_ModelResponseServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. class ModelResponse(object): - """The greeting service definition. 
- """ + """Missing associated documentation comment in .proto file.""" @staticmethod def GeneratorReply(request, target, @@ -290,3 +303,29 @@ def ConversationalReply(request, wait_for_ready, timeout, metadata) + + @staticmethod + def Txt2ImgReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/modelresponse.ModelResponse/Txt2ImgReply', + modelresponse__pb2.MultiStringRequest.SerializeToString, + modelresponse__pb2.ImageReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) diff --git a/mii/models/load_models.py b/mii/models/load_models.py index c480b21d..3d176e81 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -22,8 +22,14 @@ def load_models(task_name, global generator world_size = int(os.getenv('WORLD_SIZE', '1')) - #TODO: pass in mii_config to fetch dtype, training_mp_size, and other params - ds_kwargs = {"checkpoint": None, "mpu": None, "args": None, "training_mp_size": 1} + ds_kwargs = { + "checkpoint": None, + "mpu": None, + "args": None, + "training_mp_size": 1, + "replace_with_kernel_inject": mii_config.replace_with_kernel_inject + } + if provider == mii.constants.ModelProvider.HUGGING_FACE: from mii.models.providers.huggingface import hf_provider inference_pipeline = hf_provider(model_path, model_name, task_name, mii_config) @@ -49,6 +55,14 @@ def load_models(task_name, if "enable_qkv_quantization" in inspect.signature( deepspeed.init_inference).parameters: ds_kwargs["enable_qkv_quantization"] = True + elif provider == mii.constants.ModelProvider.DIFFUSERS: + from mii.models.providers.diffusers import diffusers_provider + assert not mii_config.enable_cuda_graph, "Diffusers models do no support Cuda Graphs (yet)" + inference_pipeline = diffusers_provider(model_path, + model_name, + task_name, + mii_config) + ds_kwargs["replace_with_kernel_inject"] = False #not supported yet else: raise ValueError(f"Unknown model provider {provider}") @@ -56,16 +70,19 @@ def load_models(task_name, f"> --------- MII Settings: {ds_optimize=}, replace_with_kernel_inject={mii_config.replace_with_kernel_inject}, enable_cuda_graph={mii_config.enable_cuda_graph} " ) if ds_optimize: - inference_pipeline.model = deepspeed.init_inference( - inference_pipeline.model, - mp_size=world_size, - dtype=mii_config.torch_dtype(), - replace_with_kernel_inject=mii_config.replace_with_kernel_inject, - replace_method='auto', - enable_cuda_graph=mii_config.enable_cuda_graph, - **ds_kwargs) + engine = deepspeed.init_inference(getattr(inference_pipeline, + "model", + inference_pipeline), + mp_size=world_size, + dtype=mii_config.torch_dtype(), + replace_method='auto', + enable_cuda_graph=mii_config.enable_cuda_graph, + **ds_kwargs) if mii_config.profile_model_time: - inference_pipeline.model.profile_model_time() + engine.profile_model_time() + if hasattr(inference_pipeline, "model"): + inference_pipeline.model = engine + elif ds_zero: ds_config = DeepSpeedConfig(ds_config_path) #TODO: don't read ds-config from disk, we should pass this around as a dict instead diff --git a/mii/models/providers/diffusers.py b/mii/models/providers/diffusers.py new file mode 100644 index 00000000..daf9e5f5 --- /dev/null +++ b/mii/models/providers/diffusers.py @@ -0,0 +1,10 @@ +import os + + +def diffusers_provider(model_path, model_name, 
+    from diffusers import DiffusionPipeline
+    local_rank = int(os.getenv('LOCAL_RANK', '0'))
+    pipeline = DiffusionPipeline.from_pretrained(model_name,
+                                                 use_auth_token=mii_config.hf_auth_token)
+    pipeline = pipeline.to(f"cuda:{local_rank}")
+    return pipeline
diff --git a/mii/server_client.py b/mii/server_client.py
index 254a7c7b..2a89eeb8 100644
--- a/mii/server_client.py
+++ b/mii/server_client.py
@@ -184,6 +184,8 @@ def _initialize_service(self,
             provider = mii.constants.MODEL_PROVIDER_NAME_EA
         elif ("bigscience/bloom" == model_name) or ("microsoft/bloom" in model_name):
             provider = mii.constants.MODEL_PROVIDER_NAME_HF_LLM
+        elif self.task == mii.Tasks.TEXT2IMG:
+            provider = mii.constants.MODEL_PROVIDER_NAME_DIFFUSERS
         else:
             provider = mii.constants.MODEL_PROVIDER_NAME_HF
         server_args_str += f" --provider {provider}"
@@ -211,12 +213,27 @@ def create_config_from_dict(tmpdir, config_dict):
             )
             server_args_str += f" --ds-config {ds_config_path}"
         cmd = f'{ds_launch_str} {launch_str} {server_args_str}'.split(" ")
-        logger.info(f"multi-gpu deepspeed launch: {cmd}")
+        printable_config = f"task-name {mii.utils.get_task_name(self.task)} model {model_name} model-path {model_path} port {self.port_number} provider {provider}"
+        logger.info(f"MII using multi-gpu deepspeed launcher:\n" +
+                    self.print_helper(printable_config))
         mii_env = os.environ.copy()
         mii_env["TRANSFORMERS_CACHE"] = model_path
         process = subprocess.Popen(cmd, env=mii_env)
         return process
 
+    def print_helper(self, args):
+        # convert to list
+        args = args.split(" ")
+        # convert to dict
+        dct = {args[i]: args[i + 1] for i in range(0, len(args), 2)}
+        printable_string = ""
+        printable_string += " " + "-" * 60 + "\n"
+        for k, v in dct.items():
+            dots = "." * (29 - len(k))
+            printable_string += f" {k} {dots} {v} \n"
+        printable_string += " " + "-" * 60
+        return printable_string
+
     def _initialize_grpc_client(self):
         channels = []
         for i in range(self.num_gpus):
@@ -279,8 +296,16 @@ async def _request_async_response(self, stub_id, request_dict, query_kwargs):
                 generated_responses=request_dict['generated_responses'],
                 query_kwargs=proto_kwargs))
 
+        elif self.task == mii.Tasks.TEXT2IMG:
+            # convert to batch of queries if they are not already
+            if not isinstance(request_dict['query'], list):
+                request_dict['query'] = [request_dict['query']]
+            req = modelresponse_pb2.MultiStringRequest(request=request_dict['query'],
+                                                       query_kwargs=proto_kwargs)
+            response = await self.stubs[stub_id].Txt2ImgReply(req)
+
         else:
-            assert False, "unknown task"
+            raise ValueError(f"unknown task: {self.task}")
         return response
 
     def _request_response(self, request_dict, query_kwargs):
@@ -305,6 +330,9 @@ def _request_response(self, request_dict, query_kwargs):
         elif self.task == mii.Tasks.CONVERSATIONAL:
             response = self.model(["", request_dict['query']],
                                   **query_kwargs)
+        elif self.task == mii.Tasks.TEXT2IMG:
+            response = self.model(request_dict['query'], **query_kwargs)
+
         else:
             raise NotImplementedError(f"task is not supported: {self.task}")
         end = time.time()
diff --git a/mii/utils.py b/mii/utils.py
index ebe441e1..36704c15 100644
--- a/mii/utils.py
+++ b/mii/utils.py
@@ -9,19 +9,18 @@
 
 from huggingface_hub import HfApi
 
-from mii.constants import (
-    CONVERSATIONAL_NAME,
-    FILL_MASK_NAME,
-    MII_CACHE_PATH,
-    MII_CACHE_PATH_DEFAULT,
-    TEXT_GENERATION_NAME,
-    TEXT_CLASSIFICATION_NAME,
-    QUESTION_ANSWERING_NAME,
-    TOKEN_CLASSIFICATION_NAME,
-    SUPPORTED_MODEL_TYPES,
-    ModelProvider,
-    REQUIRED_KEYS_PER_TASK,
-)
+from mii.constants import (CONVERSATIONAL_NAME,
+                           FILL_MASK_NAME,
+                           MII_CACHE_PATH,
+                           MII_CACHE_PATH_DEFAULT,
+                           TEXT_GENERATION_NAME,
+                           TEXT_CLASSIFICATION_NAME,
+                           QUESTION_ANSWERING_NAME,
+                           TOKEN_CLASSIFICATION_NAME,
+                           SUPPORTED_MODEL_TYPES,
+                           ModelProvider,
+                           REQUIRED_KEYS_PER_TASK,
+                           TEXT2IMG_NAME)
 
 from mii.constants import Tasks
 
@@ -45,6 +44,9 @@ def get_task_name(task):
     if task == Tasks.CONVERSATIONAL:
         return CONVERSATIONAL_NAME
 
+    if task == Tasks.TEXT2IMG:
+        return TEXT2IMG_NAME
+
     raise ValueError(f"Unknown Task {task}")
 
 
@@ -67,6 +69,9 @@ def get_task(task_name):
     if task_name == CONVERSATIONAL_NAME:
         return Tasks.CONVERSATIONAL
 
+    if task_name == TEXT2IMG_NAME:
+        return Tasks.TEXT2IMG
+
     assert False, f"Unknown Task {task_name}"
 
 
@@ -95,6 +100,8 @@ def _get_supported_models_name(task):
     elif provider == ModelProvider.ELEUTHER_AI:
         if task_name == TEXT_GENERATION_NAME:
             models = [model_type]
+    elif provider == ModelProvider.DIFFUSERS:
+        models = _get_hf_models_by_type(model_type, task_name)
     supported_models.extend(models)
     if not supported_models:
         raise ValueError(f"Task {task} not supported")
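
Usage sketch (not part of the diff): the snippet below strings together the deploy and query paths this PR adds for the new 'text-to-image' task, modeled on examples/local/txt2img-example.py above. The deployment name "sd_deploy", the HF_AUTH_TOKEN environment variable, the port, and the prompt text are assumptions carried over from that example, not requirements of the API.

import os
import mii
from PIL import Image

# Step 1 (run once): deploy Stable Diffusion behind a local MII gRPC server.
mii.deploy(task="text-to-image",
           model="CompVis/stable-diffusion-v1-4",
           deployment_name="sd_deploy",
           mii_config={
               "tensor_parallel": 1,
               "dtype": "fp16",
               "hf_auth_token": os.environ["HF_AUTH_TOKEN"],
               "port_number": 50050
           })

# Step 2 (typically a separate invocation, as in the example script): query the
# deployment. Prompts travel as a MultiStringRequest, and the ImageReply message
# added in this PR returns raw image bytes plus the mode and size needed to
# rebuild PIL images on the client side.
generator = mii.mii_query_handle("sd_deploy")
result = generator.query({"query": ["a panda in space with a rainbow"]})
for idx, img_bytes in enumerate(result.images):
    img = Image.frombytes(result.mode, (result.size_w, result.size_h), img_bytes)
    img.save(f"sample-{idx}.png")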