Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/formatting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:

- name: Install MII
run: |
pip install git+https://github.com/microsoft/DeepSpeed.git@loadams/try-bump-pydantic # TODO: revert to the default branch once the pydantic-2 bump lands upstream
pip install .[dev]

- name: Formatting checks
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nv-torch-latest-v100.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:

- name: Install dependencies
run: |
pip install git+https://github.com/microsoft/DeepSpeed.git
pip install git+https://github.com/microsoft/DeepSpeed.git@loadams/try-bump-pydantic # TODO: revert to the default branch once the pydantic-2 bump lands upstream
pip install git+https://github.com/huggingface/transformers.git
pip install -U accelerate
ds_report
Expand Down
87 changes: 40 additions & 47 deletions mii/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

# DeepSpeed Team
import torch
from typing import Union, List
from typing import Any, Union, List, Optional
from enum import Enum
from pydantic import BaseModel, validator, root_validator
from pydantic import field_validator, model_validator, ConfigDict, BaseModel, FieldValidationInfo

from deepspeed.launcher.runner import DLTS_HOSTFILE

Expand Down Expand Up @@ -44,10 +44,10 @@ class MIIConfig(BaseModel):
# MIIConfig deployment fields (class header is above this hunk).
# Pre-migration duplicates (`Union[dict, None]`, `hf_auth_token: str = None`)
# removed; pydantic-v2 style `Optional[...]` declarations kept.
meta_tensor: bool = False
load_with_sys_mem: bool = False
enable_cuda_graph: bool = False
checkpoint_dict: Optional[dict] = None
deploy_rank: Union[int, List[int]] = -1
torch_dist_port: int = 29500
hf_auth_token: Optional[str] = None
replace_with_kernel_inject: bool = True
profile_model_time: bool = False
skip_model_check: bool = False
hostfile: str = DLTS_HOSTFILE
trust_remote_code: bool = False

@field_validator('checkpoint_dict')
@classmethod
def checkpoint_dict_valid(cls, v: Optional[dict], info: FieldValidationInfo):
    """Validate a user-supplied checkpoint_dict.

    Returns the dict unchanged when valid; ``None`` is allowed (no
    checkpoint supplied).

    Raises:
        ValueError: if 'base_dir' is set (it is derived from the
            deployment's 'model_path' and must not be preset), or if any
            required key is missing or empty.
    """
    if v is None:
        return v
    if v.get('base_dir', ''):
        raise ValueError(
            "please unset 'base_dir' it will be set w.r.t. the deployment 'model_path'"
        )
    for k in ['checkpoints', 'parallelization', 'version', 'type']:
        if not v.get(k, ''):
            raise ValueError(f"Missing key={k} in checkpoint_dict")
    return v

@model_validator(mode="after")
@classmethod
def deploy_rank_valid(cls, data: Any) -> Any:
    """Normalize 'deploy_rank' to a list whose length equals 'tensor_parallel'."""
    # if deploy rank is not given, default to align with TP value
    if data.deploy_rank == -1:
        data.deploy_rank = list(range(data.tensor_parallel))

    # ensure deploy rank type is always list for easier consumption later
    if not isinstance(data.deploy_rank, list):
        data.deploy_rank = [data.deploy_rank]

    # number of ranks provided must be equal to TP size, DP is handled outside MII currently.
    # raise (not assert) so the check survives `python -O`; pydantic converts
    # ValueError raised in validators into a ValidationError for callers.
    if data.tensor_parallel != len(data.deploy_rank):
        raise ValueError(
            f"{len(data.deploy_rank)} rank(s) provided in 'deploy_rank' does not align "
            f"with tensor_parallel size of {data.tensor_parallel}")
    return data

@model_validator(mode="after")
@classmethod
def meta_tensor_or_sys_mem(cls, data: Any) -> Any:
    """Reject the mutually-exclusive loading options.

    Raises:
        ValueError: if both 'meta_tensor' and 'load_with_sys_mem' are enabled.
    """
    if data.meta_tensor and data.load_with_sys_mem:
        raise ValueError(
            "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time."
        )
    return data

# pydantic-v2 configuration (replaces the removed v1 `class Config`).
# NOTE(review): `json_encoders` is deprecated in pydantic v2 (see the v2
# migration guide); it still functions but should eventually be replaced by
# a custom field/model serializer for torch.dtype.
model_config = ConfigDict(validate_default=True,
                          validate_assignment=True,
                          use_enum_values=True,
                          extra='forbid',
                          json_encoders={torch.dtype: lambda x: str(x)})


class ReplicaConfig(BaseModel):
    """Placement of a single model replica: host, gRPC/dist ports and GPUs."""
    hostname: str = ""
    tensor_parallel_ports: List[int] = []
    # Optional is required here: pydantic v2 (with validate_default=True)
    # rejects None as the default of a plain `int` field.
    torch_dist_port: Optional[int] = None
    gpu_indices: List[int] = []

    model_config = ConfigDict(validate_default=True, validate_assignment=True)


class LoadBalancerConfig(BaseModel):
    """Load balancer settings: its port and the replicas it fans out to."""
    # Optional is required here: pydantic v2 (with validate_default=True)
    # rejects None as the default of a plain `int` field.
    port: Optional[int] = None
    replica_configs: List[ReplicaConfig] = []

    model_config = ConfigDict(validate_default=True, validate_assignment=True)
2 changes: 1 addition & 1 deletion mii/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

def config_to_b64_str(config):
    """Serialize a pydantic config model to a URL-safe base64 JSON string.

    Uses the pydantic-v2 ``model_dump_json`` API (v1's ``.json()`` was the
    pre-migration call and is removed here).

    Args:
        config: any object exposing ``model_dump_json() -> str``.

    Returns:
        str: URL-safe base64 encoding of the model's JSON representation.
    """
    # convert json str -> bytes
    json_bytes = config.model_dump_json().encode()
    # base64 encoded bytes
    b64_config_bytes = base64.urlsafe_b64encode(json_bytes)
    # bytes -> str
    return b64_config_bytes.decode()
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ deepspeed>=0.7.6
Flask-RESTful
grpcio
grpcio-tools
pydantic
pydantic>=2.0.0
torch
transformers
Werkzeug