Add distributed context in pytorch engine to support torchrun (#2615)
grimoire authored Oct 23, 2024
1 parent a50555b commit cca7d36
Showing 13 changed files with 119 additions and 132 deletions.
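
At a high level, the commit adds an lmdeploy.pytorch.distributed module that keeps rank and world size in a thread-local DistContext, and the engine and model code now query that context instead of calling torch.distributed directly — which appears to be what the torchrun support in the title relies on. The recurring per-file change, sketched from the deleted and added lines below (illustrative only, not part of the diff):

# Before: each model file carried its own copy of this helper (removed below).
import torch.distributed as dist

world_size = 1
rank = 0
if dist.is_initialized():
    world_size = dist.get_world_size()
    rank = dist.get_rank()

# After: one shared helper reads the thread-local DistContext instead.
from lmdeploy.pytorch.distributed import get_world_rank

world_size, rank = get_world_rank()
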
15 changes: 2 additions & 13 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -1,20 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-import torch.distributed as dist
+
+from lmdeploy.pytorch.distributed import get_world_rank

 from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata


-def get_world_rank():
-    """get current world size and rank."""
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank
-
-
 class TritonAttentionMetadata(AttentionMetadata):
66 changes: 66 additions & 0 deletions lmdeploy/pytorch/distributed.py
@@ -0,0 +1,66 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import threading
+from contextlib import contextmanager
+from dataclasses import dataclass
+
+from torch import distributed as dist
+
+
+@dataclass
+class DistContext:
+    rank: int = 0
+    world_size: int = 1
+    dist_group: dist.ProcessGroup = None
+
+
+DefaultContext = DistContext()
+
+
+class DistManager:
+    """distributed context manager."""
+
+    def __init__(self):
+        self.t_local = threading.local()
+        self.t_local.device_context = DefaultContext
+
+    def current_context(self) -> DistContext:
+        """get current context."""
+        return getattr(self.t_local, 'device_context', DefaultContext)
+
+    def set_context(self, context: DistContext):
+        """set current context."""
+        self.t_local.device_context = context
+
+    @contextmanager
+    def context(self, context: DistContext):
+        """context manager."""
+        origin_context = self.current_context()
+        self.set_context(context)
+        yield self
+        self.set_context(origin_context)
+
+
+_DIST_MANAGER: DistManager = None
+
+
+def get_dist_manager():
+    """get device manager."""
+    global _DIST_MANAGER
+    if _DIST_MANAGER is None:
+        _DIST_MANAGER = DistManager()
+    return _DIST_MANAGER
+
+
+def get_world_rank():
+    """get distributed world size and rank."""
+    ctx = get_dist_manager().current_context()
+    world_size = ctx.world_size
+    rank = ctx.rank
+
+    return world_size, rank
+
+
+def get_process_group():
+    """get process group."""
+    ctx = get_dist_manager().current_context()
+    return ctx.dist_group
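
The new module is small enough that its intended use can be read off directly. A minimal usage sketch (the worker function below is hypothetical; only DistContext, get_dist_manager and get_world_rank come from this commit):

from lmdeploy.pytorch.distributed import (DistContext, get_dist_manager,
                                          get_world_rank)


def run_worker(rank: int, world_size: int):
    """Hypothetical worker; the real callers are in engine/model_agent.py below."""
    dist_ctx = DistContext(rank=rank, world_size=world_size)
    # Activate the context for this thread only.
    with get_dist_manager().context(dist_ctx):
        # Library code inside the block sees the configured rank/world size.
        world_size, rank = get_world_rank()
        print(f'rank {rank} of {world_size}')
    # Outside the block the thread falls back to DefaultContext (rank 0, size 1).

Because the context lives in a threading.local, each thread carries its own rank and world size without touching global torch.distributed state, so single-process runs and multi-process (spawned or torchrun-launched) runs can share the same query path.
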
64 changes: 35 additions & 29 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -15,6 +1,7 @@
 from ..backends import get_backend
 from ..config import BackendConfig, CacheConfig, ModelConfig
 from ..devices import DeviceContext, get_device_manager
+from ..distributed import DistContext, get_dist_manager, get_world_rank
 from ..model_inputs import ModelInputs
 from ..models.patch import (add_adapters, build_patched_model,
                             update_custom_module_map)
@@ -81,9 +82,7 @@ def __adjust_block_size():
         # TODO: support kernel with both large head dim and large block size.
         if model_config.k_head_dim >= 512 and cache_config.block_size > 32:
             cache_config.block_size = 32
-            rank = 0
-            if dist.is_initialized():
-                rank = dist.get_rank()
+            _, rank = get_world_rank()
             if rank == 0:
                 logger.warning(
                     f'Update `block_size={cache_config.block_size}`'
@@ -482,9 +481,11 @@ def _start_tp_process(proc_id: int,
                                 rank=rank,
                                 world_size=world_size,
                                 timeout=timedelta(days=35600))
+        dist_ctx = DistContext(rank=rank, world_size=world_size)
         torch.cuda.set_device(rank)
-        with get_device_manager().context(
-                device_context), torch.inference_mode():
+        with (get_dist_manager().context(dist_ctx),
+              get_device_manager().context(device_context),
+              torch.inference_mode()):
             args = args or tuple()
             kwargs = kwargs or dict()
             func(rank, *args, **kwargs)
@@ -565,6 +566,7 @@ def __signal_term_handler(sig, frame):
         self.world_size = world_size
         self.backend_config = backend_config

+        self._dist_ctx = None
         self.mp_bar = self.mp_ctx.Barrier(world_size)
         self._start_sub_process(model_path,
                                 model_config=model_config,
@@ -645,6 +647,8 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
                                         rank=rank,
                                         world_size=world_size,
                                         timeout=timedelta(days=35600))
+            dist_ctx = DistContext(rank=rank, world_size=world_size)
+            self._dist_ctx = dist_ctx
         except Exception as e:
             from traceback import print_exc
             logger.error(f'Rank[{rank}] failed.')
@@ -665,16 +669,17 @@ def _build_model(
         world_size: int,
     ):
         """build model."""
-        rank = 0
-        model, cache_engine, cache_config = _tp_build_model(
-            rank,
-            model_path=model_path,
-            model_config=model_config,
-            cache_config=cache_config,
-            backend_config=backend_config,
-            adapters=adapters,
-            world_size=world_size,
-        )
+        with get_dist_manager().context(self._dist_ctx):
+            rank = 0
+            model, cache_engine, cache_config = _tp_build_model(
+                rank,
+                model_path=model_path,
+                model_config=model_config,
+                cache_config=cache_config,
+                backend_config=backend_config,
+                adapters=adapters,
+                world_size=world_size,
+            )

         return model, cache_engine, cache_config

@@ -686,20 +691,21 @@ def get_block_numel(self):
     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
                       swap_out_map: SwapMap):
         """forward impl."""
-        self.mp_bar.wait()
-        rank = 0
-        _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
-                          self.stream)
-        cache_swapping(self.cache_engine,
-                       swap_in_map=swap_in_map,
-                       swap_out_map=swap_out_map)
-        output = model_forward(
-            self.patched_model,
-            inputs,
-            self.cache_engine,
-            world_size=1,
-            stream=self.stream,
-        )
+        with get_dist_manager().context(self._dist_ctx):
+            self.mp_bar.wait()
+            rank = 0
+            _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
+                              self.stream)
+            cache_swapping(self.cache_engine,
+                           swap_in_map=swap_in_map,
+                           swap_out_map=swap_out_map)
+            output = model_forward(
+                self.patched_model,
+                inputs,
+                self.cache_engine,
+                world_size=1,
+                stream=self.stream,
+            )
         return output

     def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
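
For a torchrun launch, the same DistContext can be filled from the process group that torchrun's environment initializes. A hedged sketch (the script name and worker body are placeholders; this commit only adds the context plumbing shown above, and the actual lmdeploy entry point is not part of this diff):

# hypothetical launch: torchrun --nproc_per_node=2 tp_worker.py
import torch
from torch import distributed as dist

from lmdeploy.pytorch.distributed import DistContext, get_dist_manager


def main():
    # torchrun exports RANK, WORLD_SIZE, MASTER_ADDR, ... so the default
    # env:// init method needs no explicit arguments.
    dist.init_process_group(backend='nccl')
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    # Mirror what _start_tp_process does above: expose rank/world size
    # to lmdeploy.pytorch.distributed for the duration of the run.
    dist_ctx = DistContext(rank=rank, world_size=world_size)
    with get_dist_manager().context(dist_ctx):
        pass  # build the model and run the engine here


if __name__ == '__main__':
    main()
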
14 changes: 1 addition & 13 deletions lmdeploy/pytorch/models/cogvlm.py
@@ -6,6 +6,7 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -16,19 +17,6 @@
 from .utils.cudagraph import CudaGraphMixin


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank
-
-
 class VisionExpertAttention(nn.Module):
     """Rewrite module of VisionExpertAttention."""

14 changes: 1 addition & 13 deletions lmdeploy/pytorch/models/deepseek.py
@@ -6,6 +6,7 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -17,19 +18,6 @@
 from .utils.cudagraph import CudaGraphMixin


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank
-
-
 class DeepseekAttention(nn.Module):
     """Rewrite module of MistralAttention."""

18 changes: 2 additions & 16 deletions lmdeploy/pytorch/models/deepseek_v2.py
@@ -6,6 +6,7 @@
 import torch.distributed as dist
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -25,19 +26,6 @@ def yarn_get_mscale(scale=1, mscale=1):
     return 0.1 * mscale * math.log(scale) + 1.0


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank
-
-
 class DeepseekV2BMM(nn.Module):
     """wrapped bmm."""

@@ -240,9 +228,7 @@ def forward(
         attn_metadata: Any = None,
     ):
         """Rewrite of LlamaAttention.forward."""
-        world_size = 1
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
+        world_size, _ = get_world_rank()
         num_heads = self.num_heads // world_size
         nope_size = self.kv_lora_rank
         q_len = hidden_states.size(1)
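
The forward-path edits above are a pure lookup swap: each rank still handles num_heads // world_size attention heads, only the source of world_size moves to the shared context. A small illustrative calculation (the head count is made up):

from lmdeploy.pytorch.distributed import get_world_rank

TOTAL_NUM_HEADS = 32  # hypothetical model configuration

world_size, rank = get_world_rank()  # e.g. (4, 1) inside a DistContext
num_heads = TOTAL_NUM_HEADS // world_size
print(f'rank {rank} computes {num_heads} of {TOTAL_NUM_HEADS} heads')
# Outside any DistContext this falls back to (1, 0), i.e. all 32 heads.
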
6 changes: 2 additions & 4 deletions lmdeploy/pytorch/models/minicpm3.py
@@ -3,10 +3,10 @@
 from typing import Any, Iterable, List, Optional, Tuple

 import torch
-import torch.distributed as dist
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (Attention, RMSNorm, RopeType, SiluAndMul,
                                  build_rotary_embedding)
@@ -118,9 +118,7 @@ def forward(
         attn_metadata: Any = None,
     ):
         """Rewrite of LlamaAttention.forward."""
-        world_size = 1
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
+        world_size, _ = get_world_rank()
         num_heads = self.num_heads // world_size
        bsz, q_len, _ = hidden_states.size()

14 changes: 1 addition & 13 deletions lmdeploy/pytorch/models/qwen2_moe.py
@@ -7,6 +7,7 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -18,19 +19,6 @@
 from .utils.cudagraph import CudaGraphMixin


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank
-
-
 class Qwen2MoeAttention(nn.Module):
     """Rewrite module of Qwen2MoeAttention."""

4 changes: 3 additions & 1 deletion lmdeploy/pytorch/nn/attention.py
@@ -2,9 +2,11 @@
 import torch
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
+
 from ..backends import OpType, get_backend
 from ..backends.attention import AttentionMetadata
-from .utils import get_distribute_size, get_world_rank
+from .utils import get_distribute_size


 class Attention(nn.Module):
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/nn/linear.py
@@ -5,13 +5,14 @@
 import torch.distributed as dist
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.weight_loader.model_weight_loader import \
     default_weight_loader
 from lmdeploy.utils import get_logger

 from ..backends import OpType, get_backend
 from ..backends.lora import AdapterInfo
-from .utils import div_up, get_distribute_size, get_world_rank
+from .utils import div_up, get_distribute_size

 logger = get_logger('lmdeploy')

3 changes: 2 additions & 1 deletion lmdeploy/pytorch/nn/moe.py
@@ -5,8 +5,9 @@
 import torch.distributed as dist
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
+
 from ..backends import OpType, get_backend
-from .utils import get_world_rank


 class SoftmaxTopK(nn.Module):
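
The nn layers (attention, linear, moe above) need the world size for the same reason: under tensor parallelism each rank stores and computes only its shard of a layer. A generic sketch of that partitioning — not lmdeploy's actual implementation, which lives in these files and uses get_distribute_size:

import torch
from torch import nn

from lmdeploy.pytorch.distributed import get_world_rank


class ColumnParallelLinearSketch(nn.Module):
    """Illustrative column-parallel linear: each rank owns out_features // world_size rows."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        world_size, _ = get_world_rank()
        assert out_features % world_size == 0, 'sketch assumes an even split'
        self.out_per_rank = out_features // world_size
        # Allocate only this rank's slice of the full weight matrix.
        self.weight = nn.Parameter(torch.empty(self.out_per_rank, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Produces the local output shard; a later all-gather or a
        # row-parallel layer recombines shards across ranks.
        return torch.nn.functional.linear(x, self.weight)
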