1 change: 1 addition & 0 deletions .github/workflows/cpu.yml
@@ -29,6 +29,7 @@ jobs:

      - name: Install MII
        run: |
          pip install git+https://github.com/microsoft/DeepSpeed.git
          pip install .[dev,local]

      - name: Unit tests
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -37,3 +37,9 @@ repos:
      --check-filenames,
      --check-hidden
    ]

- repo: https://github.com/pycqa/flake8
  rev: 4.0.1
  hooks:
  - id: flake8
    args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401']
21 changes: 21 additions & 0 deletions examples/aml/text-generation-bloom.py
@@ -0,0 +1,21 @@
import mii

mii_configs = {
    "dtype": "fp16",
    "tensor_parallel": 8,
    "port_number": 50050,
    "checkpoint_dict": {
        "checkpoints": [f'bloom-mp_0{i}.pt' for i in range(0, 8)],
        "parallelization": "tp",
        "version": 1.0,
        "type": "BLOOM"
    }
}
name = "bigscience/bloom"

mii.deploy(task='text-generation',
           model=name,
           deployment_name=name + "_deployment",
           deployment_type=mii.constants.DeploymentType.AML,
           mii_config=mii_configs)
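For reference, the checkpoints comprehension above expands to one shard filename per tensor-parallel rank; a quick sanity check in a Python shell:

print([f'bloom-mp_0{i}.pt' for i in range(0, 8)])
# ['bloom-mp_00.pt', 'bloom-mp_01.pt', ..., 'bloom-mp_07.pt']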
12 changes: 12 additions & 0 deletions examples/aml/text-generation-bloom350m-example.py
@@ -0,0 +1,12 @@
import mii

mii_configs = {
    "tensor_parallel": 1,
    "dtype": "fp16",
    "aml_model_path": "models/bloom-350m"
}
mii.deploy(task='text-generation',
           model="bigscience/bloom-350m",
           deployment_name="bloom350m_deployment",
           deployment_type=mii.constants.DeploymentType.AML,
           mii_config=mii_configs)
4 changes: 0 additions & 4 deletions examples/local/conversational-query-example.py
@@ -1,7 +1,3 @@
from email import generator
import os
import grpc

import mii

# gpt2
3 changes: 0 additions & 3 deletions examples/local/fill-mask-query-example.py
@@ -1,6 +1,3 @@
import os
import grpc

import mii

# roberta
3 changes: 0 additions & 3 deletions examples/local/question-answering-query-example.py
@@ -1,6 +1,3 @@
import os
import grpc

import mii

name = "deepset/roberta-large-squad2"
3 changes: 0 additions & 3 deletions examples/local/text-classification-query-example.py
@@ -1,6 +1,3 @@
import os
import grpc

import mii

# gpt2
15 changes: 13 additions & 2 deletions examples/local/text-generation-bloom-example.py
@@ -1,10 +1,21 @@
import mii

mii_configs = {"dtype": "fp16", "tensor_parallel": 8}
mii_configs = {
    "dtype": "fp16",
    "tensor_parallel": 8,
    "port_number": 50950,
    "checkpoint_dict": {
        "checkpoints": [f'bloom-mp_0{i}.pt' for i in range(0, 8)],
        "parallelization": "tp",
        "version": 1.0,
        "type": "BLOOM"
    }
}
name = "bigscience/bloom"

mii.deploy(task='text-generation',
           model=name,
           deployment_name=name + "_deployment",
           local_model_path="/tmp/huggingface/transformers/",
           model_path="/data/bloom-mp",
           mii_config=mii_configs)
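Once the local gRPC server is running, the deployment can be queried through a handle. A minimal sketch, assuming the deployment name from this example and the standard text-generation payload (the prompt string is a placeholder):

import mii

# Attach to the running local deployment by name.
generator = mii.mii_query_handle("bigscience/bloom" + "_deployment")

# Text-generation requests carry a "query" field (see REQUIRED_KEYS_PER_TASK).
result = generator.query({'query': ["DeepSpeed is"]})
print(result)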
3 changes: 0 additions & 3 deletions examples/local/token-classification-query-example.py
@@ -1,6 +1,3 @@
import os
import grpc

import mii

# roberta
7 changes: 1 addition & 6 deletions mii/__init__.py
@@ -1,13 +1,8 @@
import enum
import grpc
from .server_client import MIIServerClient, mii_query_handle
from .deployment import deploy
from .terminate import terminate
from .config import MIIConfig
from .constants import DeploymentType, Tasks

from .utils import get_model_path, import_score_file, set_model_path
from .utils import setup_task, get_task, get_task_name, check_if_task_and_model_is_supported, check_if_task_and_model_is_valid
from .config import MIIConfig
from .grpc_related.proto import modelresponse_pb2_grpc
from .grpc_related.proto import modelresponse_pb2
from .models.load_models import load_models
17 changes: 16 additions & 1 deletion mii/config.py
@@ -1,19 +1,34 @@
import torch
from pydantic import BaseModel, validator, ValidationError
from typing import Union
from pydantic import BaseModel, validator


class MIIConfig(BaseModel):
    tensor_parallel: int = 1
    port_number: int = 50050
    dtype: str = "float"
    enable_cuda_graph: bool = False
    checkpoint_dict: Union[dict, None] = None

    @validator('dtype')
    def dtype_valid(cls, value):
        # parse dtype value to determine torch dtype
        MIIConfig._torch_dtype(value)
        return value.lower()

    @validator('checkpoint_dict')
    def checkpoint_dict_valid(cls, value):
        if value is None:
            return value
        if value.get('base_dir', ''):
            raise ValueError(
                "please unset 'base_dir' it will be set w.r.t. the deployment 'model_path'"
            )
        for k in ['checkpoints', 'parallelization', 'version', 'type']:
            if not value.get(k, ''):
                raise ValueError(f"Missing key={k} in checkpoint_dict")
        return value

    @staticmethod
    def _torch_dtype(value):
        value = value.lower()
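To illustrate the new validators, a short sketch of constructing a config directly (hypothetical values; assumes "fp16" is an accepted dtype string):

from mii.config import MIIConfig

# The dtype validator parses the string and normalizes it to lower case.
config = MIIConfig(tensor_parallel=8, dtype="FP16")
assert config.dtype == "fp16"

# A checkpoint_dict missing required keys ('parallelization', 'version',
# 'type') fails validation with a pydantic ValidationError.
try:
    MIIConfig(checkpoint_dict={"checkpoints": ["bloom-mp_00.pt"]})
except Exception as err:
    print(err)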
16 changes: 16 additions & 0 deletions mii/constants.py
@@ -62,12 +62,28 @@ class ModelProvider(enum.Enum):
    CONVERSATIONAL_NAME
]

REQUIRED_KEYS_PER_TASK = {
    TEXT_GENERATION_NAME: ["query"],
    TEXT_CLASSIFICATION_NAME: ["query"],
    QUESTION_ANSWERING_NAME: ["context", "question"],
    FILL_MASK_NAME: ["query"],
    TOKEN_CLASSIFICATION_NAME: ["query"],
    CONVERSATIONAL_NAME: ['text', 'conversation_id', 'past_user_inputs', 'generated_responses']
}

MODEL_NAME_KEY = 'model_name'
TASK_NAME_KEY = 'task_name'
MODEL_PATH_KEY = 'model_path'

ENABLE_DEEPSPEED_KEY = 'ds_optimize'
ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero'
DEEPSPEED_CONFIG_KEY = 'ds_config'
CHECKPOINT_KEY = "checkpoint"

MII_CACHE_PATH = "MII_CACHE_PATH"
MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache"
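As an illustration of how this table can be consumed, a hypothetical validation helper (check_request is invented for this sketch, not part of MII):

from mii.constants import REQUIRED_KEYS_PER_TASK, QUESTION_ANSWERING_NAME

def check_request(task_name, request):
    # Reject payloads that lack any key the task requires.
    missing = [k for k in REQUIRED_KEYS_PER_TASK[task_name] if k not in request]
    if missing:
        raise ValueError(f"{task_name} request is missing keys: {missing}")

check_request(QUESTION_ANSWERING_NAME,
              {"question": "What is DeepSpeed?",
               "context": "DeepSpeed is a deep learning optimization library."})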
88 changes: 29 additions & 59 deletions mii/deployment.py
@@ -1,56 +1,20 @@
'''
Copyright 2022 The Microsoft DeepSpeed Team
'''
import os
import torch

import mii
from mii import utils
from mii.constants import DeploymentType
from mii.utils import logger, log_levels
from mii.models.utils import download_model_and_get_path
import pprint
import inspect


def create_score_file(deployment_name,
                      task,
                      model_name,
                      ds_optimize,
                      ds_zero,
                      ds_config,
                      mii_configs):
    config_dict = {}
    config_dict[mii.constants.TASK_NAME_KEY] = mii.get_task_name(task)
    config_dict[mii.constants.MODEL_NAME_KEY] = model_name
    config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize
    config_dict[mii.constants.MII_CONFIGS_KEY] = mii_configs.dict()
    config_dict[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY] = ds_zero
    config_dict[mii.constants.DEEPSPEED_CONFIG_KEY] = ds_config

    if len(mii.__path__) > 1:
        logger.warning(
            f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior"
        )

    with open(os.path.join(mii.__path__[0], "models/generic_model/score.py"), "r") as fd:
        score_src = fd.read()

    # update score file w. global config dict
    source_with_config = utils.debug_score_preamble()
    source_with_config += f"{score_src}\n"
    source_with_config += f"configs = {pprint.pformat(config_dict,indent=4)}"

    with open(utils.generated_score_path(deployment_name), "w") as fd:
        fd.write(source_with_config)
        fd.write("\n")

from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT
from .utils import logger
from .models.score import create_score_file, generated_score_path


def deploy(task,
           model,
           deployment_name,
           deployment_type=DeploymentType.LOCAL,
           local_model_path=None,
           model_path=None,
           enable_deepspeed=True,
           enable_zero=False,
           ds_config=None,
@@ -74,8 +74,8 @@ def deploy(task,
    *``LOCAL`` uses a gRPC server to create a local deployment; querying the model must be done by creating a query handle using
    ``mii.mii_query_handle`` and posting queries using the ``mii_request_handle.query`` API,

    local_model_path: Optional: Local folder where the model checkpoints are available.
    This should be provided if you want to use your own checkpoint instead of the default open-source checkpoints for the supported models for `LOCAL` deployment.
    model_path: Optional: In LOCAL deployments this is the local path where model checkpoints are available. In AML deployments this
    is an optional path relative to ``AZURE_MODEL_DIR`` for the deployment.

    enable_deepspeed: Optional: Defaults to True. Use this flag to enable or disable DeepSpeed-Inference optimizations

@@ -84,7 +84,7 @@ def deploy(task,
    ds_config: Optional: Defaults to None. Use this to specify the DeepSpeed configuration when enabling DeepSpeed-ZeRO inference

    force_register_model: Optional: Defaults to False. For AML deployments, set it to True if you want to re-register your model
    with the same ``aml_model_tags`` using checkpoints from ``local_model_path``.
    with the same ``aml_model_tags`` using checkpoints from ``model_path``.

    mii_config: Optional: Dictionary specifying optimization and deployment configurations that should override defaults in ``mii.config.MIIConfig``.
    ``mii_config`` is forward-looking, intended to support extensions to the optimization strategies of DeepSpeed-Inference as MII evolves.
@@ -102,29 +102,66 @@ def deploy(task,
    assert (mii_config.torch_dtype() == torch.float), "MII Config Error: MII dtype and ZeRO dtype must match"
    assert not (enable_deepspeed and enable_zero), "MII Config Error: DeepSpeed and ZeRO cannot both be enabled, select only one"

    task = mii.get_task(task)
    mii.check_if_task_and_model_is_valid(task, model)
    task = mii.utils.get_task(task)
    mii.utils.check_if_task_and_model_is_valid(task, model)
    if enable_deepspeed:
        mii.check_if_task_and_model_is_supported(task, model)
        mii.utils.check_if_task_and_model_is_supported(task, model)

    logger.info(f"*************DeepSpeed Optimizations: {enable_deepspeed}*************")

    create_score_file(deployment_name,
                      task,
                      model,
                      enable_deepspeed,
                      enable_zero,
                      ds_config,
                      mii_config)
    # In local deployments use default path if no model path set
    if model_path is None and deployment_type == DeploymentType.LOCAL:
        model_path = MII_MODEL_PATH_DEFAULT
    elif model_path is None and deployment_type == DeploymentType.AML:
        model_path = None

    create_score_file(deployment_name=deployment_name,
                      task=task,
                      model_name=model,
                      ds_optimize=enable_deepspeed,
                      ds_zero=enable_zero,
                      ds_config=ds_config,
                      mii_config=mii_config,
                      model_path=model_path)

    if deployment_type == DeploymentType.AML:
        print(f"Score file created at {utils.generated_score_path(deployment_name)}")
        print(f"Score file created at {generated_score_path(deployment_name)}")
    elif deployment_type == DeploymentType.LOCAL:
        return _deploy_local(deployment_name, local_model_path=local_model_path)
        return _deploy_local(deployment_name, model_path=model_path)
    else:
        raise Exception(f"Unknown deployment type: {deployment_type}")


def _deploy_local(deployment_name, local_model_path=None):
    mii.set_model_path(local_model_path)
    mii.import_score_file(deployment_name).init()
def _deploy_local(deployment_name, model_path):
    mii.utils.import_score_file(deployment_name).init()
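With this refactor, omitting model_path in a LOCAL deployment falls back to MII_MODEL_PATH_DEFAULT from mii.constants before the score file is written; a minimal sketch (model and deployment name are placeholders):

import mii

# No model_path given: deploy() substitutes MII_MODEL_PATH_DEFAULT
# for LOCAL deployments.
mii.deploy(task='text-generation',
           model='gpt2',
           deployment_name='gpt2_deployment')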
2 changes: 2 additions & 0 deletions mii/models/__init__.py
@@ -0,0 +1,2 @@
from .score import create_score_file
from .load_models import load_models
27 changes: 0 additions & 27 deletions mii/models/gpt2/score.py

This file was deleted.
