diff --git a/.github/workflows/cpu.yml b/.github/workflows/cpu.yml index 01ae8210..743e243d 100644 --- a/.github/workflows/cpu.yml +++ b/.github/workflows/cpu.yml @@ -29,6 +29,7 @@ jobs: - name: Install MII run: | + pip install git+https://github.com/microsoft/DeepSpeed.git pip install .[dev,local] - name: Unit tests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43407779..5e0d9702 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,3 +37,9 @@ repos: --check-filenames, --check-hidden ] + +- repo: https://github.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + args: ['--ignore=E,F403,F405,F541,F841,W', '--select=E9,F,W6', '--per-file-ignores=__init__.py:F401'] diff --git a/examples/aml/text-generation-bloom.py b/examples/aml/text-generation-bloom.py new file mode 100644 index 00000000..bd36098f --- /dev/null +++ b/examples/aml/text-generation-bloom.py @@ -0,0 +1,21 @@ +import mii + +mii_configs = { + "dtype": "fp16", + "tensor_parallel": 8, + "port_number": 50050, + "checkpoint_dict": { + "checkpoints": [f'bloom-mp_0{i}.pt' for i in range(0, + 8)], + "parallelization": "tp", + "version": 1.0, + "type": "BLOOM" + } +} +name = "bigscience/bloom" + +mii.deploy(task='text-generation', + model=name, + deployment_name=name + "_deployment", + deployment_type=mii.constants.DeploymentType.AML, + mii_config=mii_configs) diff --git a/examples/aml/text-generation-bloom350m-example.py b/examples/aml/text-generation-bloom350m-example.py new file mode 100644 index 00000000..85195598 --- /dev/null +++ b/examples/aml/text-generation-bloom350m-example.py @@ -0,0 +1,12 @@ +import mii + +mii_configs = { + "tensor_parallel": 1, + "dtype": "fp16", + "aml_model_path": "models/bloom-350m" +} +mii.deploy(task='text-generation', + model="bigscience/bloom-350m", + deployment_name="bloom350m_deployment", + deployment_type=mii.constants.DeploymentType.AML, + mii_config=mii_configs) diff --git a/examples/local/conversational-query-example.py b/examples/local/conversational-query-example.py index 09f316cd..c18a026c 100644 --- a/examples/local/conversational-query-example.py +++ b/examples/local/conversational-query-example.py @@ -1,7 +1,3 @@ -from email import generator -import os -import grpc - import mii # gpt2 diff --git a/examples/local/fill-mask-query-example.py b/examples/local/fill-mask-query-example.py index d6dcbc02..aa8b1019 100644 --- a/examples/local/fill-mask-query-example.py +++ b/examples/local/fill-mask-query-example.py @@ -1,6 +1,3 @@ -import os -import grpc - import mii # roberta diff --git a/examples/local/question-answering-query-example.py b/examples/local/question-answering-query-example.py index 8f1e6c0c..61530a55 100644 --- a/examples/local/question-answering-query-example.py +++ b/examples/local/question-answering-query-example.py @@ -1,6 +1,3 @@ -import os -import grpc - import mii name = "deepset/roberta-large-squad2" diff --git a/examples/local/text-classification-query-example.py b/examples/local/text-classification-query-example.py index 5e936054..7e387b94 100644 --- a/examples/local/text-classification-query-example.py +++ b/examples/local/text-classification-query-example.py @@ -1,6 +1,3 @@ -import os -import grpc - import mii # gpt2 diff --git a/examples/local/text-generation-bloom-example.py b/examples/local/text-generation-bloom-example.py index f6625e1b..8ada208f 100644 --- a/examples/local/text-generation-bloom-example.py +++ b/examples/local/text-generation-bloom-example.py @@ -1,10 +1,21 @@ import mii -mii_configs = {"dtype": 
"fp16", "tensor_parallel": 8} +mii_configs = { + "dtype": "fp16", + "tensor_parallel": 8, + "port_number": 50950, + "checkpoint_dict": { + "checkpoints": [f'bloom-mp_0{i}.pt' for i in range(0, + 8)], + "parallelization": "tp", + "version": 1.0, + "type": "BLOOM" + } +} name = "bigscience/bloom" mii.deploy(task='text-generation', model=name, deployment_name=name + "_deployment", - local_model_path="/tmp/huggingface/transformers/", + model_path="/data/bloom-mp", mii_config=mii_configs) diff --git a/examples/local/token-classification-query-example.py b/examples/local/token-classification-query-example.py index 08226a1b..a0a69c56 100644 --- a/examples/local/token-classification-query-example.py +++ b/examples/local/token-classification-query-example.py @@ -1,6 +1,3 @@ -import os -import grpc - import mii # roberta diff --git a/mii/__init__.py b/mii/__init__.py index cbdabc1d..af2173f5 100644 --- a/mii/__init__.py +++ b/mii/__init__.py @@ -1,13 +1,8 @@ -import enum import grpc from .server_client import MIIServerClient, mii_query_handle from .deployment import deploy from .terminate import terminate -from .config import MIIConfig from .constants import DeploymentType, Tasks -from .utils import get_model_path, import_score_file, set_model_path -from .utils import setup_task, get_task, get_task_name, check_if_task_and_model_is_supported, check_if_task_and_model_is_valid +from .config import MIIConfig from .grpc_related.proto import modelresponse_pb2_grpc -from .grpc_related.proto import modelresponse_pb2 -from .models.load_models import load_models diff --git a/mii/config.py b/mii/config.py index d9d8b10a..3e2f6b47 100644 --- a/mii/config.py +++ b/mii/config.py @@ -1,5 +1,6 @@ import torch -from pydantic import BaseModel, validator, ValidationError +from typing import Union +from pydantic import BaseModel, validator class MIIConfig(BaseModel): @@ -7,6 +8,7 @@ class MIIConfig(BaseModel): port_number: int = 50050 dtype: str = "float" enable_cuda_graph: bool = False + checkpoint_dict: Union[dict, None] = None @validator('dtype') def dtype_valid(cls, value): @@ -14,6 +16,19 @@ def dtype_valid(cls, value): MIIConfig._torch_dtype(value) return value.lower() + @validator('checkpoint_dict') + def checkpoint_dict_valid(cls, value): + if value is None: + return value + if value.get('base_dir', ''): + raise ValueError( + "please unset 'base_dir' it will be set w.r.t. 
the deployment 'model_path'" + ) + for k in ['checkpoints', 'parallelization', 'version', 'type']: + if not value.get(k, ''): + raise ValueError(f"Missing key={k} in checkpoint_dict") + return value + @staticmethod def _torch_dtype(value): value = value.lower() diff --git a/mii/constants.py b/mii/constants.py index 4f2f05cc..d523be41 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -62,12 +62,28 @@ class ModelProvider(enum.Enum): CONVERSATIONAL_NAME ] +REQUIRED_KEYS_PER_TASK = { + TEXT_GENERATION_NAME: ["query"], + TEXT_CLASSIFICATION_NAME: ["query"], + QUESTION_ANSWERING_NAME: ["context", + "question"], + FILL_MASK_NAME: ["query"], + TOKEN_CLASSIFICATION_NAME: ["query"], + CONVERSATIONAL_NAME: + ['text', + 'conversation_id', + 'past_user_inputs', + 'generated_responses'] +} + MODEL_NAME_KEY = 'model_name' TASK_NAME_KEY = 'task_name' +MODEL_PATH_KEY = 'model_path' ENABLE_DEEPSPEED_KEY = 'ds_optimize' ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero' DEEPSPEED_CONFIG_KEY = 'ds_config' +CHECKPOINT_KEY = "checkpoint" MII_CACHE_PATH = "MII_CACHE_PATH" MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache" diff --git a/mii/deployment.py b/mii/deployment.py index 4aa2d87f..b505a5e7 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -1,56 +1,20 @@ ''' Copyright 2022 The Microsoft DeepSpeed Team ''' -import os import torch import mii -from mii import utils -from mii.constants import DeploymentType -from mii.utils import logger, log_levels -from mii.models.utils import download_model_and_get_path -import pprint -import inspect - - -def create_score_file(deployment_name, - task, - model_name, - ds_optimize, - ds_zero, - ds_config, - mii_configs): - config_dict = {} - config_dict[mii.constants.TASK_NAME_KEY] = mii.get_task_name(task) - config_dict[mii.constants.MODEL_NAME_KEY] = model_name - config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize - config_dict[mii.constants.MII_CONFIGS_KEY] = mii_configs.dict() - config_dict[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY] = ds_zero - config_dict[mii.constants.DEEPSPEED_CONFIG_KEY] = ds_config - - if len(mii.__path__) > 1: - logger.warning( - f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior" - ) - - with open(os.path.join(mii.__path__[0], "models/generic_model/score.py"), "r") as fd: - score_src = fd.read() - - # update score file w. global config dict - source_with_config = utils.debug_score_preamble() - source_with_config += f"{score_src}\n" - source_with_config += f"configs = {pprint.pformat(config_dict,indent=4)}" - - with open(utils.generated_score_path(deployment_name), "w") as fd: - fd.write(source_with_config) - fd.write("\n") + +from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT +from .utils import logger +from .models.score import create_score_file, generated_score_path def deploy(task, model, deployment_name, deployment_type=DeploymentType.LOCAL, - local_model_path=None, + model_path=None, enable_deepspeed=True, enable_zero=False, ds_config=None, @@ -74,8 +38,8 @@ def deploy(task, *``LOCAL`` uses a grpc server to create a local deployment, and query the model must be done by creating a query handle using `mii.mii_query_handle` and posting queries using ``mii_request_handle.query`` API, - local_model_path: Optional: Local folder where the model checkpoints are available. - This should be provided if you want to use your own checkpoint instead of the default open-source checkpoints for the supported models for `LOCAL` deployment. 
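# ---------------------------------------------------------------------------
# Editor's aside (not part of the diff): a minimal sketch of how the new
# MIIConfig.checkpoint_dict validator added in mii/config.py above behaves.
# It only uses fields shown in this PR; the checkpoint file names are
# placeholders.
from mii.config import MIIConfig

# All four required keys present and 'base_dir' unset -> validates cleanly.
valid = MIIConfig(dtype="fp16",
                  checkpoint_dict={
                      "checkpoints": [f"bloom-mp_0{i}.pt" for i in range(8)],
                      "parallelization": "tp",
                      "version": 1.0,
                      "type": "BLOOM",
                  })

# Setting 'base_dir' (or dropping a required key) raises a pydantic
# ValidationError, since base_dir is filled in from model_path at deploy time.
try:
    MIIConfig(checkpoint_dict={"base_dir": "/data/bloom-mp",
                               "checkpoints": ["bloom-mp_00.pt"],
                               "parallelization": "tp",
                               "version": 1.0,
                               "type": "BLOOM"})
except ValueError as err:
    print(err)
# ---------------------------------------------------------------------------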
+ model_path: Optional: In LOCAL deployments this is the local path where model checkpoints are available. In AML deployments this + is an optional relative path with AZURE_MODEL_DIR for the deployment. enable_deepspeed: Optional: Defaults to True. Use this flag to enable or disable DeepSpeed-Inference optimizations @@ -84,7 +48,7 @@ def deploy(task, ds_config: Optional: Defaults to None. Use this to specify the DeepSpeed configuration when enabling DeepSpeed-ZeRO inference force_register_model: Optional: Defaults to False. For AML deployments, set it to True if you want to re-register your model - with the same ``aml_model_tags`` using checkpoints from ``local_model_path``. + with the same ``aml_model_tags`` using checkpoints from ``model_path``. mii_config: Optional: Dictionary specifying optimization and deployment configurations that should override defaults in ``mii.config.MIIConfig``. mii_config is future looking to support extensions in optimization strategies supported by DeepSpeed Inference as we extend mii. @@ -102,29 +66,35 @@ def deploy(task, assert (mii_config.torch_dtype() == torch.float), "MII Config Error: MII dtype and ZeRO dtype must match" assert not (enable_deepspeed and enable_zero), "MII Config Error: DeepSpeed and ZeRO cannot both be enabled, select only one" - task = mii.get_task(task) - mii.check_if_task_and_model_is_valid(task, model) + task = mii.utils.get_task(task) + mii.utils.check_if_task_and_model_is_valid(task, model) if enable_deepspeed: - mii.check_if_task_and_model_is_supported(task, model) + mii.utils.check_if_task_and_model_is_supported(task, model) logger.info(f"*************DeepSpeed Optimizations: {enable_deepspeed}*************") - create_score_file(deployment_name, - task, - model, - enable_deepspeed, - enable_zero, - ds_config, - mii_config) + # In local deployments use default path if no model path set + if model_path is None and deployment_type == DeploymentType.LOCAL: + model_path = MII_MODEL_PATH_DEFAULT + elif model_path is None and deployment_type == DeploymentType.AML: + model_path = None + + create_score_file(deployment_name=deployment_name, + task=task, + model_name=model, + ds_optimize=enable_deepspeed, + ds_zero=enable_zero, + ds_config=ds_config, + mii_config=mii_config, + model_path=model_path) if deployment_type == DeploymentType.AML: - print(f"Score file created at {utils.generated_score_path(deployment_name)}") + print(f"Score file created at {generated_score_path(deployment_name)}") elif deployment_type == DeploymentType.LOCAL: - return _deploy_local(deployment_name, local_model_path=local_model_path) + return _deploy_local(deployment_name, model_path=model_path) else: raise Exception(f"Unknown deployment type: {deployment_type}") -def _deploy_local(deployment_name, local_model_path=None): - mii.set_model_path(local_model_path) - mii.import_score_file(deployment_name).init() +def _deploy_local(deployment_name, model_path): + mii.utils.import_score_file(deployment_name).init() diff --git a/mii/models/__init__.py b/mii/models/__init__.py index e69de29b..1fe8d9ad 100644 --- a/mii/models/__init__.py +++ b/mii/models/__init__.py @@ -0,0 +1,2 @@ +from .score import create_score_file +from .load_models import load_models diff --git a/mii/models/gpt2/score.py b/mii/models/gpt2/score.py deleted file mode 100644 index 4780a5e4..00000000 --- a/mii/models/gpt2/score.py +++ /dev/null @@ -1,27 +0,0 @@ -''' -Copyright 2022 The Microsoft DeepSpeed Team -''' -import os -import mii - -generator = None - - -def init(): - - #TODO set the parallelism 
degree somehow. On the azure kubernetes server we can set the gpu core - #but how do we do this in local deployment? - os.environ['CUDA_VISIBLE_DEVICES'] = "0,1" - model_path, use_grpc_server, initialize_grpc_client = mii.setup_generation_task() - - generator = mii.MIIGenerationServerClient( - 'gpt2', - model_path, - use_grpc_server=use_grpc_server, - initialize_grpc_client=initialize_grpc_client) - - -def run(request): - global generator - text = json.loads(request) - return generator.query(text["query"]) diff --git a/mii/models/load_models.py b/mii/models/load_models.py index d872a7b3..7eae90aa 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -3,163 +3,11 @@ ''' import os import mii -from pathlib import Path +import json import torch import deepspeed -from deepspeed.runtime.zero.constants import * - - -def check_zero_ds_config(config): - config_zero = config.get(ZERO_OPTIMIZATION, {}) - stage = config_zero.get(ZERO_OPTIMIZATION_STAGE, None) - if stage != ZERO_OPTIMIZATION_WEIGHTS: - assert False, "DeepSpeed ZeRO inference is only supported for ZeRO 3 optimization stage" - - -def hf_provider(model_path, model_name, task_name, mii_config): - local_rank = int(os.getenv('LOCAL_RANK', '0')) - from transformers import pipeline - inference_pipeline = pipeline(task_name, - model=model_name, - device=local_rank, - framework="pt") - if mii_config.torch_dtype() == torch.half: - inference_pipeline.model.half() - return inference_pipeline - - -def eleutherai_provider(model_path, model_name, task_name, mii_config): - world_size = int(os.getenv('WORLD_SIZE', '1')) - from megatron.neox_pipeline import NeoXPipeline - config = { - "load": model_path, - "vocab_file": os.path.join(model_path, - "20B_tokenizer.json"), - "model_parallel_size": world_size - } - return NeoXPipeline(config) - - -''' -TODO: The following class and functions are non-optimal (i.e., hacky) solutions -to getting the Bloom models working and will be refactored in a future PR -''' - - -class BloomPipeline(object): - def __init__(self, model, tokenizer): - self.model = model - self.tokenizer = tokenizer - - def __call__(self, inputs, **kwargs): - local_rank = int(os.getenv('LOCAL_RANK', '0')) - torch.cuda.set_device(local_rank) - from deepspeed.inference.engine import InferenceEngine - if isinstance(self.model, InferenceEngine): - self.model = self.model.module - - # expand proto list into py-list - inputs = [i for i in inputs] - tokens = self.tokenizer.batch_encode_plus(inputs, - return_tensors="pt", - padding=True) - for t in tokens: - if torch.is_tensor(tokens[t]): - tokens[t] = tokens[t].to(f'cuda:{local_rank}') - greedy_output = self.model.generate(**tokens, **kwargs) - outputs = self.tokenizer.batch_decode(greedy_output, skip_special_tokens=True) - - # construct output to align w. 
HF pipeline - output_dicts = [] - for output in outputs: - output_dicts.append([{'generated_text': output}]) - - return output_dicts - - -def get_checkpoint_files(pretrained_model_name_or_path): - from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, cached_path, hf_bucket_url, is_offline_mode - from transformers.utils.hub import EntryNotFoundError - from transformers.modeling_utils import get_checkpoint_shard_files - - cache_dir = None - is_sharded = False - revision = None - local_files_only = False - - filename = WEIGHTS_NAME - archive_file = hf_bucket_url(pretrained_model_name_or_path, - filename=filename, - revision=revision) - - try: - resolved_archive_file = cached_path( - archive_file, - cache_dir=cache_dir, - local_files_only=local_files_only, - ) - return [resolved_archive_file] - - except (EntryNotFoundError, FileNotFoundError): - if filename == WEIGHTS_NAME: - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - archive_file = hf_bucket_url( - pretrained_model_name_or_path, - filename=WEIGHTS_INDEX_NAME, - revision=revision, - ) - resolved_archive_file = cached_path( - archive_file, - cache_dir=cache_dir, - local_files_only=local_files_only, - ) - is_sharded = True - - if is_sharded: - # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, - resolved_archive_file, - cache_dir=cache_dir, - revision=revision - ) - - return resolved_archive_file - - -def write_checkponts_json(model_name): - import io - import json - checkpoints_json = "checkpoints.json" - with io.open(checkpoints_json, 'w', encoding='utf-8') as f: - - checkpoint_files = get_checkpoint_files(model_name) - - data = {"type": "BLOOM-176B", "checkpoints": checkpoint_files, "version": 1.0} - json.dump(data, f) - - -# TODO: This function is a hack for the Bloom models and will be replaced with a LargeModel provider code path -def load_hf_llm(model_path, model_name, task_name, mii_config): - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig - from transformers import pipeline - import torch.distributed as dist - from deepspeed import OnDevice - - deepspeed.init_distributed('nccl') - local_rank = int(os.getenv('LOCAL_RANK', '0')) - world_size = int(os.getenv('WORLD_SIZE', '1')) - - tokenizer = AutoTokenizer.from_pretrained(model_name) - config = AutoConfig.from_pretrained(model_name) - with OnDevice(dtype=torch.float16, enabled=True): - model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) - model = model.eval() - if local_rank == 0: - write_checkponts_json(model_name) - dist.barrier() - inference_pipeline = BloomPipeline(model=model, tokenizer=tokenizer) - return inference_pipeline +from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.runtime.zero.config import ZeroStageEnum def load_models(task_name, @@ -180,12 +28,13 @@ def load_models(task_name, args = None training_mp_size = 1 if provider == mii.constants.ModelProvider.HUGGING_FACE: + from mii.models.providers.huggingface import hf_provider inference_pipeline = hf_provider(model_path, model_name, task_name, mii_config) elif provider == mii.constants.ModelProvider.ELEUTHER_AI: + from mii.models.providers.eleutherai import eleutherai_provider assert mii_config.torch_dtype() == torch.half, "gpt-neox only support fp16" assert mii_config.enable_cuda_graph == False, "Provider EleutherAI not supported with Cuda Graphs" 
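# ---------------------------------------------------------------------------
# Editor's aside (not part of the diff): with the providers factored out into
# mii/models/providers/, each one can be exercised on its own. A minimal
# sketch for the Hugging Face provider added in this PR, assuming a single
# local GPU (LOCAL_RANK defaults to 0) and using "gpt2" purely as an example.
from mii.config import MIIConfig
from mii.models.providers.huggingface import hf_provider

pipe = hf_provider(model_path=None,
                   model_name="gpt2",
                   task_name="text-generation",
                   mii_config=MIIConfig(dtype="fp16"))
print(pipe("DeepSpeed is", max_new_tokens=20))
# ---------------------------------------------------------------------------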
from megatron import mpu - from argparse import Namespace inference_pipeline = eleutherai_provider(model_path, model_name, task_name, @@ -193,10 +42,11 @@ def load_models(task_name, training_mp_size = 2 args = inference_pipeline.neox_args elif provider == mii.constants.ModelProvider.HUGGING_FACE_LLM: + from mii.models.providers.llm import load_hf_llm assert mii_config.torch_dtype() == torch.half, "Bloom models only support fp16" assert mii_config.enable_cuda_graph == False, "Bloom models do no support Cuda Graphs" inference_pipeline = load_hf_llm(model_path, model_name, task_name, mii_config) - checkpoint = "checkpoints.json" + checkpoint = inference_pipeline.checkpoint_dict else: raise ValueError(f"Unknown model provider {provider}") @@ -213,14 +63,14 @@ def load_models(task_name, enable_cuda_graph=mii_config.enable_cuda_graph, args=args) elif ds_zero: - assert os.path.exists(ds_config_path), '{ds_config_path} does not exist' - import json - ds_config = json.load(open(ds_config_path, "r")) - check_zero_ds_config(ds_config) + ds_config = DeepSpeedConfig(ds_config_path) + #TODO: don't read ds-config from disk, we should pass this around as a dict instead + ds_config_dict = json.load(open(ds_config_path, 'r')) + assert ds_config.zero_optimization_stage == ZeroStageEnum.weights, "DeepSpeed ZeRO inference is only supported for ZeRO-3" # initialise Deepspeed ZeRO and store only the engine object ds_engine = deepspeed.initialize(model=inference_pipeline.model, - config_params=ds_config)[0] + config_params=ds_config_dict)[0] ds_engine.module.eval() # inference inference_pipeline.model = ds_engine.module return inference_pipeline diff --git a/mii/models/generic_model/__init__.py b/mii/models/providers/__init__.py similarity index 100% rename from mii/models/generic_model/__init__.py rename to mii/models/providers/__init__.py diff --git a/mii/models/gpt2/__init__.py b/mii/models/providers/eleutherai.py similarity index 100% rename from mii/models/gpt2/__init__.py rename to mii/models/providers/eleutherai.py diff --git a/mii/models/providers/huggingface.py b/mii/models/providers/huggingface.py new file mode 100644 index 00000000..0b1baf45 --- /dev/null +++ b/mii/models/providers/huggingface.py @@ -0,0 +1,14 @@ +import os +import torch +from transformers import pipeline + + +def hf_provider(model_path, model_name, task_name, mii_config): + local_rank = int(os.getenv('LOCAL_RANK', '0')) + inference_pipeline = pipeline(task_name, + model=model_name, + device=local_rank, + framework="pt") + if mii_config.torch_dtype() == torch.half: + inference_pipeline.model.half() + return inference_pipeline diff --git a/mii/models/providers/llm.py b/mii/models/providers/llm.py new file mode 100644 index 00000000..e26b9aef --- /dev/null +++ b/mii/models/providers/llm.py @@ -0,0 +1,120 @@ +import os +import torch +import deepspeed +from deepspeed.inference.engine import InferenceEngine +from deepspeed import OnDevice + +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig +from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, cached_path, hf_bucket_url +from transformers.utils.hub import EntryNotFoundError +from transformers.modeling_utils import get_checkpoint_shard_files +''' +TODO: The following class and functions are non-optimal (i.e., hacky) solutions +to getting the Bloom models working and will be refactored in a future PR +''' + + +class BloomPipeline(object): + def __init__(self, model, tokenizer, checkpoint_dict): + self.model = model + self.tokenizer = tokenizer + 
self.checkpoint_dict = checkpoint_dict + + def __call__(self, inputs, **kwargs): + local_rank = int(os.getenv('LOCAL_RANK', '0')) + torch.cuda.set_device(local_rank) + if isinstance(self.model, InferenceEngine): + self.model = self.model.module + + # expand proto list into py-list + inputs = [i for i in inputs] + tokens = self.tokenizer.batch_encode_plus(inputs, + return_tensors="pt", + padding=True) + for t in tokens: + if torch.is_tensor(tokens[t]): + tokens[t] = tokens[t].to(f'cuda:{local_rank}') + greedy_output = self.model.generate(**tokens, **kwargs) + outputs = self.tokenizer.batch_decode(greedy_output, skip_special_tokens=True) + + # construct output to align w. HF pipeline + output_dicts = [] + for output in outputs: + output_dicts.append([{'generated_text': output}]) + + return output_dicts + + +def get_checkpoint_files(pretrained_model_name_or_path): + cache_dir = None + is_sharded = False + revision = None + local_files_only = False + + filename = WEIGHTS_NAME + archive_file = hf_bucket_url(pretrained_model_name_or_path, + filename=filename, + revision=revision) + + try: + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + return [resolved_archive_file] + + except (EntryNotFoundError, FileNotFoundError): + if filename == WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_INDEX_NAME, + revision=revision, + ) + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + local_files_only=local_files_only, + ) + is_sharded = True + + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. 
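# ---------------------------------------------------------------------------
# Editor's aside (not part of the diff): for a sharded checkpoint, the
# resolved file list and the dict built by create_checkpoint_dict() below
# look roughly like this hypothetical example (paths and shard count are
# placeholders).
example_resolved_files = [
    "~/.cache/huggingface/hub/.../pytorch_model-00001-of-00072.bin",
    "~/.cache/huggingface/hub/.../pytorch_model-00002-of-00072.bin",
    # ... one entry per shard listed in the weights index file
]
example_checkpoint_dict = {
    "type": "BLOOM",
    "checkpoints": example_resolved_files,
    "version": 1.0,
}
# ---------------------------------------------------------------------------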
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + revision=revision + ) + + return resolved_archive_file + + +def create_checkpoint_dict(model_name, model_path, mii_config): + if mii_config.checkpoint_dict: + mii_config.checkpoint_dict['base_dir'] = model_path + return mii_config.checkpoint_dict + else: + checkpoint_files = get_checkpoint_files(model_name) + data = {"type": "BLOOM", "checkpoints": checkpoint_files, "version": 1.0} + return data + + +# TODO: This function is a hack for the Bloom models and will be replaced with a LargeModel provider code path +def load_hf_llm(model_path, model_name, task_name, mii_config): + deepspeed.init_distributed('nccl') + local_rank = int(os.getenv('LOCAL_RANK', '0')) + world_size = int(os.getenv('WORLD_SIZE', '1')) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) + with OnDevice(dtype=torch.float16, enabled=True): + model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) + model = model.eval() + checkpoint_dict = create_checkpoint_dict(model_name, model_path, mii_config) + torch.distributed.barrier() + inference_pipeline = BloomPipeline(model=model, + tokenizer=tokenizer, + checkpoint_dict=checkpoint_dict) + return inference_pipeline diff --git a/mii/models/score/__init__.py b/mii/models/score/__init__.py new file mode 100644 index 00000000..224b2522 --- /dev/null +++ b/mii/models/score/__init__.py @@ -0,0 +1 @@ +from .generate import create_score_file, generated_score_path diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py new file mode 100644 index 00000000..70f36823 --- /dev/null +++ b/mii/models/score/generate.py @@ -0,0 +1,50 @@ +''' +Copyright 2022 The Microsoft DeepSpeed Team +''' +import os +import mii +import pprint +from mii.utils import logger + + +def create_score_file(deployment_name, + task, + model_name, + ds_optimize, + ds_zero, + ds_config, + mii_config, + model_path): + config_dict = {} + config_dict[mii.constants.TASK_NAME_KEY] = mii.utils.get_task_name(task) + config_dict[mii.constants.MODEL_NAME_KEY] = model_name + config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize + config_dict[mii.constants.MII_CONFIGS_KEY] = mii_config.dict() + config_dict[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY] = ds_zero + config_dict[mii.constants.DEEPSPEED_CONFIG_KEY] = ds_config + config_dict[mii.constants.MODEL_PATH_KEY] = model_path + + if len(mii.__path__) > 1: + logger.warning( + f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior" + ) + + with open(os.path.join(mii.__path__[0], + "models/score/score_template.py"), + "r") as fd: + score_src = fd.read() + + # update score file w. 
global config dict + source_with_config = f"{score_src}\n" + source_with_config += f"configs = {pprint.pformat(config_dict, indent=4)}" + + with open(generated_score_path(deployment_name), "w") as fd: + fd.write(source_with_config) + fd.write("\n") + + +def generated_score_path(deployment_name): + score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name) + if not os.path.isdir(score_path): + os.makedirs(score_path) + return os.path.join(score_path, "score.py") diff --git a/mii/models/generic_model/score.py b/mii/models/score/score_template.py similarity index 74% rename from mii/models/generic_model/score.py rename to mii/models/score/score_template.py index 24baa56f..18bbf381 100644 --- a/mii/models/generic_model/score.py +++ b/mii/models/score/score_template.py @@ -1,3 +1,4 @@ +# flake8: noqa ''' Copyright 2022 The Microsoft DeepSpeed Team ''' @@ -10,7 +11,13 @@ def init(): - model_path, use_grpc_server, initialize_grpc_client = mii.setup_task() + # In AML deployments both the GRPC client and server are used in the same process + initialize_grpc_client = mii.utils.is_aml() + + # XXX: Always run grpc server, originally was "not is_aml()" + use_grpc_server = True + + model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) model_name = configs[mii.constants.MODEL_NAME_KEY] task = configs[mii.constants.TASK_NAME_KEY] @@ -34,9 +41,8 @@ def run(request): global model request_dict = json.loads(request) - query_dict = request_dict.pop('query', None) - if query_dict is None: - return "Missing 'query' key in request" + query_dict = mii.utils.extract_query_dict(configs[mii.constants.TASK_NAME_KEY], + request_dict) response = model.query(query_dict, **request_dict) @@ -44,4 +50,5 @@ def run(request): response = [r for r in response.response] return json.dumps({'responses': response}) + ### Auto-generated config will be appended below at run-time diff --git a/mii/models/utils.py b/mii/models/utils.py index 37ad134f..29eaba7b 100644 --- a/mii/models/utils.py +++ b/mii/models/utils.py @@ -1,6 +1,4 @@ -from ctypes import wstring_at import os -import mii from mii.utils import mii_cache_path diff --git a/mii/server_client.py b/mii/server_client.py index 8559b287..c7a9402a 100644 --- a/mii/server_client.py +++ b/mii/server_client.py @@ -2,7 +2,6 @@ Copyright 2022 The Microsoft DeepSpeed Team ''' import asyncio -from readline import write_history_file import torch import sys import subprocess @@ -13,8 +12,8 @@ from pathlib import Path import mii import base64 -import json from mii.utils import logger, kwarg_dict_to_proto +from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc def mii_query_handle(deployment_name): @@ -63,7 +62,7 @@ def __init__(self, mii_configs = mii.config.MIIConfig(**mii_configs) - self.task = mii.get_task(task_name) + self.task = mii.utils.get_task(task_name) self.num_gpus = self._get_num_gpus(mii_configs) assert self.num_gpus > 0, "GPU count must be greater than 0" @@ -145,13 +144,14 @@ def _initialize_service(self, mii_configs): process = None if not self.use_grpc_server: - self.model = mii.load_models(task_name=mii.get_task_name(self.task), - model_name=model_name, - model_path=model_path, - ds_optimize=ds_optimize, - ds_zero=ds_zero, - ds_config_path=ds_config, - mii_config=mii_configs) + self.model = mii.models.load_models(task_name=mii.utils.get_task_name( + self.task), + model_name=model_name, + model_path=model_path, + ds_optimize=ds_optimize, + ds_zero=ds_zero, + ds_config_path=ds_config, + mii_config=mii_configs) else: 
if self._is_socket_open(self.port_number): raise RuntimeError( @@ -168,7 +168,7 @@ def _initialize_service(self, ds_launch_str = f"deepspeed --num_gpus {self.num_gpus} --no_local_rank --no_python" launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server" - server_args_str = f"--task-name {mii.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {self.port_number}" + server_args_str = f"--task-name {mii.utils.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {self.port_number}" server_args_str += " --ds-optimize" if ds_optimize else "" #XXX: fetch model provider based on model name in a more general way @@ -213,8 +213,7 @@ def _initialize_grpc_client(self): channels = [] for i in range(self.num_gpus): channel = grpc.aio.insecure_channel(f'localhost:{self.port_number + i}') - stub = mii.grpc_related.proto.modelresponse_pb2_grpc.ModelResponseStub( - channel) + stub = modelresponse_pb2_grpc.ModelResponseStub(channel) channels.append(channel) self.stubs.append(stub) @@ -238,33 +237,33 @@ async def _request_async_response(self, stub_id, request_dict, query_kwargs): # convert to batch of queries if they are not already if not isinstance(request_dict['query'], list): request_dict['query'] = [request_dict['query']] - req = mii.modelresponse_pb2.MultiStringRequest(request=request_dict['query'], - query_kwargs=proto_kwargs) + req = modelresponse_pb2.MultiStringRequest(request=request_dict['query'], + query_kwargs=proto_kwargs) response = await self.stubs[stub_id].GeneratorReply(req) elif self.task == mii.Tasks.TEXT_CLASSIFICATION: response = await self.stubs[stub_id].ClassificationReply( - mii.modelresponse_pb2.SingleStringRequest(request=request_dict['query'], - query_kwargs=proto_kwargs)) + modelresponse_pb2.SingleStringRequest(request=request_dict['query'], + query_kwargs=proto_kwargs)) elif self.task == mii.Tasks.QUESTION_ANSWERING: response = await self.stubs[stub_id].QuestionAndAnswerReply( - mii.modelresponse_pb2.QARequest(question=request_dict['question'], - context=request_dict['context'], - query_kwargs=proto_kwargs)) + modelresponse_pb2.QARequest(question=request_dict['question'], + context=request_dict['context'], + query_kwargs=proto_kwargs)) elif self.task == mii.Tasks.FILL_MASK: response = await self.stubs[stub_id].FillMaskReply( - mii.modelresponse_pb2.SingleStringRequest(request=request_dict['query'], - query_kwargs=proto_kwargs)) + modelresponse_pb2.SingleStringRequest(request=request_dict['query'], + query_kwargs=proto_kwargs)) elif self.task == mii.Tasks.TOKEN_CLASSIFICATION: response = await self.stubs[stub_id].TokenClassificationReply( - mii.modelresponse_pb2.SingleStringRequest(request=request_dict['query'], - query_kwargs=proto_kwargs)) + modelresponse_pb2.SingleStringRequest(request=request_dict['query'], + query_kwargs=proto_kwargs)) elif self.task == mii.Tasks.CONVERSATIONAL: response = await self.stubs[stub_id].ConversationalReply( - mii.modelresponse_pb2.ConversationRequest( + modelresponse_pb2.ConversationRequest( text=request_dict['text'], conversation_id=request_dict['conversation_id'] if 'conversation_id' in request_dict else None, @@ -299,7 +298,7 @@ def _request_response(self, request_dict, query_kwargs): response = self.model(["", request_dict['query']], **query_kwargs) else: - raise NotSupportedError(f"task is not supported: {self.task}") + raise NotImplementedError(f"task is not supported: {self.task}") end = time.time() return f"{response}" + f"\n Model Execution Time: {end-start} seconds" diff --git 
a/mii/terminate.py b/mii/terminate.py index 7f3e1f2f..9779cec6 100644 --- a/mii/terminate.py +++ b/mii/terminate.py @@ -8,7 +8,7 @@ def terminate(deployment_name): mii.utils.logger.info(f"Terminating server for {deployment_name}") generator = mii.mii_query_handle(deployment_name) try: - generator.query({'query': None}) + generator.query({'query': ''}) except grpc.aio._call.AioRpcError as error: if error._code == grpc.StatusCode.UNAVAILABLE: mii.utils.logger.warn(f"Server for {deployment_name} not found") diff --git a/mii/utils.py b/mii/utils.py index 0ee2870a..7a937f58 100644 --- a/mii/utils.py +++ b/mii/utils.py @@ -1,19 +1,27 @@ -''' +""" Copyright 2022 The Microsoft DeepSpeed Team -''' +""" import sys import os import logging import importlib import mii -from pathlib import Path from huggingface_hub import HfApi -from mii.constants import CONVERSATIONAL_NAME, FILL_MASK_NAME, MII_CACHE_PATH, MII_CACHE_PATH_DEFAULT, MII_DEBUG_MODE, \ - MII_DEBUG_MODE_DEFAULT, MII_DEBUG_DEPLOY_KEY, MII_DEBUG_BRANCH, MII_DEBUG_BRANCH_DEFAULT, \ - TEXT_GENERATION_NAME, TEXT_CLASSIFICATION_NAME, QUESTION_ANSWERING_NAME, TOKEN_CLASSIFICATION_NAME, SUPPORTED_MODEL_TYPES, \ - ModelProvider, MII_MODEL_PATH_DEFAULT +from mii.constants import ( + CONVERSATIONAL_NAME, + FILL_MASK_NAME, + MII_CACHE_PATH, + MII_CACHE_PATH_DEFAULT, + TEXT_GENERATION_NAME, + TEXT_CLASSIFICATION_NAME, + QUESTION_ANSWERING_NAME, + TOKEN_CLASSIFICATION_NAME, + SUPPORTED_MODEL_TYPES, + ModelProvider, + REQUIRED_KEYS_PER_TASK, +) from mii.constants import Tasks @@ -65,11 +73,11 @@ def get_task(task_name): def _get_hf_models_by_type(model_type, task=None): api = HfApi() models = api.list_models(filter=model_type) - return [m.modelId for m in models - ] if task is None else [m.modelId for m in models if m.pipeline_tag == task] + return ([m.modelId for m in models] + if task is None else [m.modelId for m in models if m.pipeline_tag == task]) -#TODO read this from a file containing list of files supported for each task +# TODO read this from a file containing list of files supported for each task def _get_supported_models_name(task): supported_models = [] task_name = get_task_name(task) @@ -97,31 +105,29 @@ def check_if_task_and_model_is_supported(task, model_name): def check_if_task_and_model_is_valid(task, model_name): task_name = get_task_name(task) valid_task_models = _get_hf_models_by_type(None, task_name) - assert model_name in valid_task_models, f"{task_name} only supports {valid_task_models}" - - -def get_model_path(): - aml_model_dir = os.getenv('AZUREML_MODEL_DIR') - if aml_model_dir is not None: - return aml_model_dir - - mii_model_dir = os.getenv('MII_MODEL_DIR') - - if mii_model_dir is not None: - return mii_model_dir - - assert False, "MII_MODEL_DIR must be set. Current value is None" + assert ( + model_name in valid_task_models + ), f"{task_name} only supports {valid_task_models}" + + +def full_model_path(model_path): + aml_model_dir = os.environ.get('AZUREML_MODEL_DIR', None) + if aml_model_dir: + # (potentially) append relative model_path w. aml path + assert os.path.isabs(aml_model_dir), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path" + if model_path: + assert not os.path.isabs(model_path), f"model_path={model_path} must be relative to append w. 
AML path" + return os.path.join(aml_model_dir, model_path) + else: + return aml_model_dir + elif model_path: + return model_path + else: + return mii.constants.MII_MODEL_PATH_DEFAULT def is_aml(): - return os.getenv('AZUREML_MODEL_DIR') is not None - - -def set_model_path(model_path): - if model_path is None: - model_path = MII_MODEL_PATH_DEFAULT - os.makedirs(model_path, exist_ok=True) - os.environ['MII_MODEL_DIR'] = str(Path(model_path).resolve()) + return os.getenv("AZUREML_MODEL_DIR") is not None def mii_cache_path(): @@ -133,7 +139,7 @@ def mii_cache_path(): def import_score_file(deployment_name): spec = importlib.util.spec_from_file_location( - 'score', + "score", os.path.join(mii_cache_path(), deployment_name, "score.py")) @@ -142,45 +148,6 @@ def import_score_file(deployment_name): return score -def generated_score_path(deployment_name): - score_path = os.path.join(mii_cache_path(), deployment_name) - if not os.path.isdir(score_path): - os.makedirs(score_path) - return os.path.join(score_path, "score.py") - - -def debug_score_preamble(): - preamble = "" - debug_mode_enabled = int(os.environ.get(MII_DEBUG_MODE, MII_DEBUG_MODE_DEFAULT)) - if not debug_mode_enabled: - return preamble - - deploy_key = os.environ.get(MII_DEBUG_DEPLOY_KEY) - debug_branch = os.environ.get(MII_DEBUG_BRANCH, MII_DEBUG_BRANCH_DEFAULT) - key_path = "/tmp/mii_deploy_key" - - preamble = f""" -import subprocess, os, sys -deploy_key = '''{deploy_key}''' -with open('{key_path}', 'w') as fd: - fd.write(deploy_key) - fd.write("\\n") -subprocess.run(['chmod', '600', '{key_path}']) -env = os.environ.copy() -env["GIT_SSH_COMMAND"]="ssh -i {key_path} -o StrictHostKeyChecking=no" -install_cmd = "-m pip install git+ssh://git@github.com/microsoft/DeepSpeed-MII.git@{debug_branch}" -subprocess.run([sys.executable] + install_cmd.split(" "), env=env) - -""" - return preamble - - -def setup_task(): - # The second value returned here is hard-coded for now as we want to always - # run grpc server on AML, originally was "not is_aml()" - return get_model_path(), True, is_aml() - - dtype_proto_field = { str: "svalue", int: "ivalue", @@ -191,13 +158,24 @@ def setup_task(): def kwarg_dict_to_proto(kwarg_dict): def get_proto_value(value): - proto_value = mii.modelresponse_pb2.Value() + proto_value = mii.grpc_related.proto.modelresponse_pb2.Value() setattr(proto_value, dtype_proto_field[type(value)], value) return proto_value return {k: get_proto_value(v) for k, v in kwarg_dict.items()} +def extract_query_dict(task, request_dict): + required_keys = REQUIRED_KEYS_PER_TASK[task] + query_dict = {} + for key in required_keys: + value = request_dict.pop(key, None) + if value is None: + raise ValueError("Request for task: {task} is missing required key: {key}.") + query_dict[key] = value + return query_dict + + log_levels = { "debug": logging.DEBUG, "info": logging.INFO, diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py index 9654768e..96cc3c24 100644 --- a/tests/test_local_deployment.py +++ b/tests/test_local_deployment.py @@ -1,5 +1,4 @@ import pytest -import functools from types import SimpleNamespace import mii @@ -69,7 +68,7 @@ def deployment_config(task_name: str, model=model_name, deployment_type=mii.DeploymentType.LOCAL, deployment_name=model_name + "_deployment", - local_model_path=".cache/models/" + model_name, + model_path=".cache/models/" + model_name, mii_config=mii_configs, enable_deepspeed=enable_deepspeed, enable_zero=enable_zero,
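# ---------------------------------------------------------------------------
# Editor's aside (not part of the diff): an end-to-end sketch of a LOCAL
# deployment using the renamed model_path argument, mirroring the examples
# touched above. The model, path, and prompts are placeholders.
import mii

mii.deploy(task="text-generation",
           model="bigscience/bloom-350m",
           deployment_name="bloom350m_deployment",
           deployment_type=mii.constants.DeploymentType.LOCAL,
           model_path="/tmp/mii_models",  # was local_model_path before this PR
           mii_config={"tensor_parallel": 1, "dtype": "fp16"})

generator = mii.mii_query_handle("bloom350m_deployment")
print(generator.query({"query": ["DeepSpeed is", "Seattle is"]}))

mii.terminate("bloom350m_deployment")
# ---------------------------------------------------------------------------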