2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -42,4 +42,4 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401,mii/grpc_related/proto/modelresponse_pb2.py:F821']
20 changes: 15 additions & 5 deletions examples/local/fill-mask-example.py
@@ -1,9 +1,19 @@
import mii
import argparse

# roberta
name = "roberta-base"
name = "bert-base-cased"
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--query", action="store_true", help="query")
args = parser.parse_args()

print(f"Deploying {name}...")
name = "bert-base-uncased"
mask = "[MASK]"

mii.deploy(task='fill-mask', model=name, deployment_name=name + "_deployment")
if not args.query:
print(f"Deploying {name}...")
mii.deploy(task='fill-mask', model=name, deployment_name=name + "_deployment")
else:
print(f"Querying {name}...")
generator = mii.mii_query_handle(name + "_deployment")
result = generator.query({'query': f"Hello I'm a {mask} model."})
print(result.response)
print("time_taken:", result.time_taken)
17 changes: 0 additions & 17 deletions examples/local/fill-mask-query-example.py

This file was deleted.

39 changes: 39 additions & 0 deletions examples/local/txt2img-example.py
@@ -0,0 +1,39 @@
import os
import mii
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-q", "--query", action="store_true", help="query")
args = parser.parse_args()

if not args.query:
mii_configs = {
"tensor_parallel":
1,
"dtype":
"fp16",
"hf_auth_token":
os.environ.get("HF_AUTH_TOKEN",
"hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"),
"port_number":
50050
}
mii.deploy(task='text-to-image',
model="CompVis/stable-diffusion-v1-4",
deployment_name="sd_deploy",
mii_config=mii_configs)
print(
"\nText to image model deployment complete! To use this deployment, run the following command: python txt2img-example.py --query\n"
)
else:
generator = mii.mii_query_handle("sd_deploy")
result = generator.query({
'query':
["a panda in space with a rainbow",
"a soda can on top a snowy mountain"]
})
from PIL import Image
for idx, img_bytes in enumerate(result.images):
size = (result.size_w, result.size_h)
img = Image.frombytes(result.mode, size, img_bytes)
img.save(f"test-{idx}.png")
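The ImageReply message added later in this diff also returns per-image NSFW flags that this example does not check; a minimal hedged sketch of honoring them when saving:

# Hedged sketch: skip any image the pipeline's safety checker flagged.
from PIL import Image

def save_images(result, prefix="sd"):
    size = (result.size_w, result.size_h)
    for idx, img_bytes in enumerate(result.images):
        if result.nsfw_content_detected[idx]:
            print(f"Skipping image {idx}: flagged by the safety checker")
            continue
        Image.frombytes(result.mode, size, img_bytes).save(f"{prefix}-{idx}.png")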
1 change: 1 addition & 0 deletions mii/config.py
@@ -11,6 +11,7 @@ class MIIConfig(BaseModel):
checkpoint_dict: Union[dict, None] = None
deploy_rank: Union[int, List[int]] = -1
torch_dist_port: int = 29500
hf_auth_token: str = None
replace_with_kernel_inject: bool = True
profile_model_time: bool = False

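A minimal sketch of how the new hf_auth_token field gets exercised, mirroring the txt2img example above (the environment variable is a placeholder):

# Hedged sketch: pass a Hugging Face access token through mii_config so gated
# checkpoints such as CompVis/stable-diffusion-v1-4 can be downloaded.
import os
import mii

mii.deploy(task="text-to-image",
           model="CompVis/stable-diffusion-v1-4",
           deployment_name="sd_deploy",
           mii_config={"dtype": "fp16",
                       "hf_auth_token": os.environ["HF_AUTH_TOKEN"]})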
14 changes: 12 additions & 2 deletions mii/constants.py
@@ -17,6 +17,7 @@ class Tasks(enum.Enum):
FILL_MASK = 4
TOKEN_CLASSIFICATION = 5
CONVERSATIONAL = 6
TEXT2IMG = 7


TEXT_GENERATION_NAME = 'text-generation'
@@ -25,22 +26,26 @@ class Tasks(enum.Enum):
FILL_MASK_NAME = 'fill-mask'
TOKEN_CLASSIFICATION_NAME = 'token-classification'
CONVERSATIONAL_NAME = 'conversational'
TEXT2IMG_NAME = "text-to-image"


class ModelProvider(enum.Enum):
HUGGING_FACE = 1
ELEUTHER_AI = 2
HUGGING_FACE_LLM = 3
DIFFUSERS = 4


MODEL_PROVIDER_NAME_HF = "hugging-face"
MODEL_PROVIDER_NAME_EA = "eleuther-ai"
MODEL_PROVIDER_NAME_HF_LLM = "hugging-face-llm"
MODEL_PROVIDER_NAME_DIFFUSERS = "diffusers"

MODEL_PROVIDER_MAP = {
MODEL_PROVIDER_NAME_HF: ModelProvider.HUGGING_FACE,
MODEL_PROVIDER_NAME_EA: ModelProvider.ELEUTHER_AI,
MODEL_PROVIDER_NAME_HF_LLM: ModelProvider.HUGGING_FACE_LLM,
MODEL_PROVIDER_NAME_DIFFUSERS: ModelProvider.DIFFUSERS
}

SUPPORTED_MODEL_TYPES = {
@@ -52,6 +57,7 @@ class ModelProvider(enum.Enum):
'opt': ModelProvider.HUGGING_FACE,
'gpt-neox': ModelProvider.ELEUTHER_AI,
'bloom': ModelProvider.HUGGING_FACE_LLM,
'stable-diffusion': ModelProvider.DIFFUSERS
}

SUPPORTED_TASKS = [
@@ -60,7 +66,8 @@ class ModelProvider(enum.Enum):
QUESTION_ANSWERING_NAME,
FILL_MASK_NAME,
TOKEN_CLASSIFICATION_NAME,
CONVERSATIONAL_NAME
CONVERSATIONAL_NAME,
TEXT2IMG_NAME
]

REQUIRED_KEYS_PER_TASK = {
@@ -74,7 +81,8 @@ class ModelProvider(enum.Enum):
['text',
'conversation_id',
'past_user_inputs',
'generated_responses']
'generated_responses'],
TEXT2IMG_NAME: ["query"]
}

MODEL_NAME_KEY = 'model_name'
@@ -98,3 +106,5 @@ class ModelProvider(enum.Enum):
MII_DEBUG_BRANCH_DEFAULT = "main"

MII_MODEL_PATH_DEFAULT = "/tmp/mii_models"

GRPC_MAX_MSG_SIZE = 2**30 # 1GB
9 changes: 8 additions & 1 deletion mii/deployment.py
@@ -81,7 +81,14 @@ def deploy(task,
if enable_deepspeed:
mii.utils.check_if_task_and_model_is_supported(task, model)

logger.info(f"*************DeepSpeed Optimizations: {enable_deepspeed}*************")
if enable_deepspeed:
logger.info(
f"************* MII is using DeepSpeed Optimizations to accelerate your model *************"
)
else:
logger.info(
f"************* DeepSpeed Optimizations not enabled. Please use enable_deepspeed to get better performance *************"
)

# In local deployments use default path if no model path set
if model_path is None and deployment_type == DeploymentType.LOCAL:
38 changes: 37 additions & 1 deletion mii/grpc_related/modelresponse_server.py
@@ -10,7 +10,9 @@
import sys
import time

from torch import autocast
from transformers import Conversation
from mii.constants import GRPC_MAX_MSG_SIZE


class ModelResponse(modelresponse_pb2_grpc.ModelResponseServicer):
@@ -66,6 +68,36 @@ def GeneratorReply(self, request, context):
model_time_taken=model_time)
return val

def Txt2ImgReply(self, request, context):
query_kwargs = self._unpack_proto_query_kwargs(request.query_kwargs)

# unpack grpc list into py-list
request = [r for r in request.request]

start = time.time()
with autocast("cuda"):
response = self.inference_pipeline(request, **query_kwargs)
end = time.time()

images_bytes = []
nsfw_content_detected = []
response_count = len(response.images)
for i in range(response_count):
img = response.images[i]
img_bytes = img.tobytes()
images_bytes.append(img_bytes)
nsfw_content_detected.append(response.nsfw_content_detected[i])
img_mode = response.images[0].mode
img_size_w, img_size_h = response.images[0].size

val = modelresponse_pb2.ImageReply(images=images_bytes,
nsfw_content_detected=nsfw_content_detected,
mode=img_mode,
size_w=img_size_w,
size_h=img_size_h,
time_taken=end - start)
return val

def ClassificationReply(self, request, context):
query_kwargs = self._unpack_proto_query_kwargs(request.query_kwargs)
start = time.time()
@@ -131,7 +163,11 @@ def ConversationalReply(self, request, context):


def serve(inference_pipeline, port):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
options=[('grpc.max_send_message_length',
GRPC_MAX_MSG_SIZE),
('grpc.max_receive_message_length',
GRPC_MAX_MSG_SIZE)])
modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(
ModelResponse(inference_pipeline),
server)
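The larger server-side message limit matters because image replies are big; a hedged client-side sketch (the matching client change is not shown in the loaded portion of this diff) would raise the same caps on the query channel:

# Hedged sketch: a single 512x512 RGB image is ~786 KB of raw bytes, so a
# multi-image reply can blow past gRPC's default 4 MB limit unless the
# channel is opened with matching options.
import grpc
from mii.constants import GRPC_MAX_MSG_SIZE

channel = grpc.insecure_channel(
    "localhost:50050",
    options=[("grpc.max_send_message_length", GRPC_MAX_MSG_SIZE),
             ("grpc.max_receive_message_length", GRPC_MAX_MSG_SIZE)])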
12 changes: 10 additions & 2 deletions mii/grpc_related/proto/modelresponse.proto
@@ -21,15 +21,14 @@ option objc_class_prefix = "HLW";*/

package modelresponse;

// The greeting service definition.
service ModelResponse {
// Sends a greeting
rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {}
rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {}
rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {}
rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {}
rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {}
rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {}
rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {}
}

message Value {
@@ -84,3 +83,12 @@ message ConversationReply {
float time_taken = 4;
float model_time_taken = 5;
}

message ImageReply {
repeated bytes images = 1;
repeated bool nsfw_content_detected = 2;
string mode = 3;
int64 size_w = 4;
int64 size_h = 5;
float time_taken = 6;
}
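After editing modelresponse.proto, the Python stubs need to be regenerated with grpcio-tools; a hedged sketch of that step (assumed workflow, run from the repository root; it produces the modelresponse_pb2.py file the new flake8 per-file-ignore points at):

# Hedged sketch: regenerate the gRPC stubs from the updated proto definition.
from grpc_tools import protoc

protoc.main([
    "grpc_tools.protoc",
    "--proto_path=mii/grpc_related/proto",
    "--python_out=mii/grpc_related/proto",
    "--grpc_python_out=mii/grpc_related/proto",
    "modelresponse.proto",
])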