diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml
index 6aa26187..4869ed03 100644
--- a/.github/workflows/formatting.yml
+++ b/.github/workflows/formatting.yml
@@ -29,6 +29,7 @@ jobs:
 
       - name: Install MII
         run: |
+          pip install git+https://github.com/microsoft/DeepSpeed.git@loadams/try-bump-pydantic
           pip install .[dev]
 
       - name: Formatting checks
diff --git a/.github/workflows/nv-torch-latest-v100.yaml b/.github/workflows/nv-torch-latest-v100.yaml
index 2c10447a..9c5f6357 100644
--- a/.github/workflows/nv-torch-latest-v100.yaml
+++ b/.github/workflows/nv-torch-latest-v100.yaml
@@ -33,7 +33,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install git+https://github.com/microsoft/DeepSpeed.git
+          pip install git+https://github.com/microsoft/DeepSpeed.git@loadams/try-bump-pydantic
           pip install git+https://github.com/huggingface/transformers.git
           pip install -U accelerate
           ds_report
diff --git a/mii/config.py b/mii/config.py
index 2714cb40..d9f1e70d 100644
--- a/mii/config.py
+++ b/mii/config.py
@@ -3,9 +3,9 @@
 
 # DeepSpeed Team
 import torch
-from typing import Union, List
+from typing import Any, Union, List, Optional
 from enum import Enum
-from pydantic import BaseModel, validator, root_validator
+from pydantic import field_validator, model_validator, ConfigDict, BaseModel, FieldValidationInfo
 
 from deepspeed.launcher.runner import DLTS_HOSTFILE
 
@@ -44,10 +44,10 @@ class MIIConfig(BaseModel):
     meta_tensor: bool = False
     load_with_sys_mem: bool = False
     enable_cuda_graph: bool = False
-    checkpoint_dict: Union[dict, None] = None
+    checkpoint_dict: Optional[dict] = None
     deploy_rank: Union[int, List[int]] = -1
     torch_dist_port: int = 29500
-    hf_auth_token: str = None
+    hf_auth_token: Optional[str] = None
     replace_with_kernel_inject: bool = True
     profile_model_time: bool = False
     skip_model_check: bool = False
@@ -58,53 +58,52 @@ class MIIConfig(BaseModel):
     hostfile: str = DLTS_HOSTFILE
     trust_remote_code: bool = False
 
-    @validator("deploy_rank")
-    def deploy_valid(cls, field_value, values):
-        if "tensor_parallel" not in values:
+    @field_validator('checkpoint_dict')
+    @classmethod
+    def checkpoint_dict_valid(cls, v: Union[None, dict], info: FieldValidationInfo):
+        if v is None:
+            return v
+        if v.get('base_dir', ''):
             raise ValueError(
-                "'tensor_parallel' must be defined in the pydantic model before 'deploy_rank'"
+                "please unset 'base_dir' it will be set w.r.t. the deployment 'model_path'"
             )
+        for k in ['checkpoints', 'parallelization', 'version', 'type']:
+            if not v.get(k, ''):
+                raise ValueError(f"Missing key={k} in checkpoint_dict")
+        return v
 
+    @model_validator(mode="after")
+    @classmethod
+    def deploy_rank_valid(cls, data: Any) -> Any:
         # if deploy rank is not given, default to align with TP value
-        if field_value == -1:
-            field_value = list(range(values["tensor_parallel"]))
+        if data.deploy_rank == -1:
+            data.deploy_rank = list(range(data.tensor_parallel))
 
         # ensure deploy rank type is always list for easier consumption later
-        if not isinstance(field_value, list):
-            field_value = [field_value]
+        if not isinstance(data.deploy_rank, list):
+            data.deploy_rank = [data.deploy_rank]
 
         # number of ranks provided must be equal to TP size, DP is handled outside MII currently
-        assert values["tensor_parallel"] == len(field_value), \
-            f"{len(field_value)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {values['tensor_parallel']}"
-        return field_value
-
-    @validator('checkpoint_dict')
-    def checkpoint_dict_valid(cls, value):
-        if value is None:
-            return value
-        if value.get('base_dir', ''):
-            raise ValueError(
-                "please unset 'base_dir' it will be set w.r.t. the deployment 'model_path'"
-            )
-        for k in ['checkpoints', 'parallelization', 'version', 'type']:
-            if not value.get(k, ''):
-                raise ValueError(f"Missing key={k} in checkpoint_dict")
-        return value
-
-    @root_validator
-    def meta_tensor_or_sys_mem(cls, values):
-        if values.get("meta_tensor") and values.get("load_with_sys_mem"):
+        assert data.tensor_parallel == len(data.deploy_rank), \
+            f"{len(data.deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {data.tensor_parallel}"
+        return data
+
+    @model_validator(mode="after")
+    @classmethod
+    def meta_tensor_or_sys_mem(cls, data: Any) -> Any:
+        if data.meta_tensor and data.load_with_sys_mem:
             raise ValueError(
                 "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time."
             )
-        return values
+        return data
 
-    class Config:
-        validate_all = True
-        validate_assignment = True
-        use_enum_values = True
-        extra = 'forbid'
-        json_encoders = {torch.dtype: lambda x: str(x)}
+    # TODO[pydantic]: The following keys were removed: `json_encoders`.
+    # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
+    model_config = ConfigDict(validate_default=True,
+                              validate_assignment=True,
+                              use_enum_values=True,
+                              extra='forbid',
+                              json_encoders={torch.dtype: lambda x: str(x)})
 
 
 class ReplicaConfig(BaseModel):
@@ -112,16 +111,10 @@ class ReplicaConfig(BaseModel):
     tensor_parallel_ports: List[int] = []
     torch_dist_port: int = None
     gpu_indices: List[int] = []
-
-    class Config:
-        validate_all = True
-        validate_assignment = True
+    model_config = ConfigDict(validate_default=True, validate_assignment=True)
 
 
 class LoadBalancerConfig(BaseModel):
     port: int = None
     replica_configs: List[ReplicaConfig] = []
-
-    class Config:
-        validate_all = True
-        validate_assignment = True
+    model_config = ConfigDict(validate_default=True, validate_assignment=True)
diff --git a/mii/server.py b/mii/server.py
index 0825e060..d25f7540 100644
--- a/mii/server.py
+++ b/mii/server.py
@@ -19,7 +19,7 @@
 
 def config_to_b64_str(config):
     # convert json str -> bytes
-    json_bytes = config.json().encode()
+    json_bytes = config.model_dump_json().encode()
     # base64 encoded bytes
     b64_config_bytes = base64.urlsafe_b64encode(json_bytes)
     # bytes -> str
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 243d5aee..5da93d1c 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -3,7 +3,7 @@ deepspeed>=0.7.6
 Flask-RESTful
 grpcio
 grpcio-tools
-pydantic
+pydantic>=2.0.0
 torch
 transformers
 Werkzeug
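
Note: the snippet below is a minimal, self-contained sketch (not part of the patch) of the pydantic v2 idioms this change adopts: field_validator replacing validator, model_validator(mode="after") replacing root_validator, model_config = ConfigDict(...) replacing the inner class Config, and model_dump_json() replacing .json(). The ExampleConfig model and its fields are hypothetical and only mirror the shape of the MII config.

# sketch_pydantic_v2.py -- hypothetical example, not MII code
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict, field_validator, model_validator


class ExampleConfig(BaseModel):
    tensor_parallel: int = 1
    deploy_rank: Union[int, List[int]] = -1
    hf_auth_token: Optional[str] = None

    # v2 replacement for @validator: validates a single field.
    @field_validator("hf_auth_token")
    @classmethod
    def token_not_empty(cls, v: Optional[str]) -> Optional[str]:
        if v == "":
            raise ValueError("hf_auth_token must be None or a non-empty string")
        return v

    # v2 replacement for @root_validator: runs after all fields are validated,
    # so tensor_parallel is guaranteed to be populated here.
    @model_validator(mode="after")
    def align_deploy_rank(self) -> "ExampleConfig":
        if self.deploy_rank == -1:
            self.deploy_rank = list(range(self.tensor_parallel))
        if not isinstance(self.deploy_rank, list):
            self.deploy_rank = [self.deploy_rank]
        return self

    # v2 replacement for the inner class Config; validate_default supersedes
    # the old validate_all option.
    model_config = ConfigDict(validate_default=True,
                              validate_assignment=True,
                              extra="forbid")


if __name__ == "__main__":
    cfg = ExampleConfig(tensor_parallel=2)
    # model_dump_json() is the v2 replacement for .json()
    print(cfg.model_dump_json())
    # -> {"tensor_parallel":2,"deploy_rank":[0,1],"hf_auth_token":null}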