Add distributed context in pytorch engine to support torchrun #2615

Merged (2 commits, Oct 23, 2024)
15 changes: 2 additions & 13 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -1,20 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-import torch.distributed as dist

-from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata


-def get_world_rank():
-    """get current world size and rank."""
-    world_size = 1
-    rank = 0
+from lmdeploy.pytorch.distributed import get_world_rank
Collaborator:
Can you provide a short script showing how to use torchrun with lmdeploy to test this PR?

Collaborator (Author):

import os
import time
import argparse

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from tqdm import tqdm
from lmdeploy import pipeline, PytorchEngineConfig, GenerationConfig, ChatTemplateConfig, VisionConfig
from lmdeploy.vl import load_image as load_image

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    num_gpus = torch.cuda.device_count()
    # torchrun sets LOCAL_RANK for every worker process
    if torch.__version__ > '1.10':
        local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank % num_gpus)

    # rendezvous info (MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE) is read from
    # the environment prepared by torchrun
    dist.init_process_group(
        backend=backend,
    )
    rank = dist.get_rank()
    num_gpus = dist.get_world_size()
    return num_gpus, rank

def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--model_path', type=str, default=None, help='checkpoint to start from')
    parser.add_argument('--tcp_port', type=int, default=18888, help='tcp port for distributed training')
    parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_config()
    num_gpus, rank = init_dist_pytorch(args.tcp_port, args.local_rank)

    pipe = pipeline(model_path=args.model_path, 
                    backend_config=PytorchEngineConfig(dtype='bfloat16', cache_max_entry_count=0.1,
                                                       max_batch_size=1),
                    vision_config=VisionConfig(max_batch_size=1), log_level='INFO',
                    # chat_template_config=ChatTemplateConfig(model_name='internvl2-internlm2')
                    )

    generation_config = GenerationConfig(max_new_tokens=4096, do_sample=False, temperature=0.0)

    iteration_num = 20

    input_text = "Explain the concept of artificial intelligence in simple terms."
    if num_gpus == 1:
        start_time = time.time()

        for _ in tqdm(range(iteration_num), ncols=140, desc=f"Single GPU"):
            output = pipe([input_text], gen_config=generation_config)

        print(f"Single GPU average inference time: {time.time()-start_time:.1f} seconds")

    else:
        dist.barrier()

        start_time = time.time()
        for _ in tqdm(range(iteration_num//num_gpus), ncols=140, desc=f"Multi GPU", disable=rank!=0):
            output = pipe([input_text], gen_config=generation_config)

        dist.barrier()
        if rank == 0:
            print(f"Multi-GPU average inference time: {time.time()-start_time:.1f} seconds")
Launch the script with torchrun, e.g.:

torchrun --nproc_per_node=2 test.py \
    --model_path InternVL2-1B
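(For a single-GPU baseline, the num_gpus == 1 branch above can presumably be exercised by launching a single process, e.g. torchrun --nproc_per_node=1 test.py --model_path InternVL2-1B.)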


-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank
+from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata


class TritonAttentionMetadata(AttentionMetadata):
66 changes: 66 additions & 0 deletions lmdeploy/pytorch/distributed.py
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from contextlib import contextmanager
from dataclasses import dataclass

from torch import distributed as dist


@dataclass
class DistContext:
    rank: int = 0
    world_size: int = 1
    dist_group: dist.ProcessGroup = None


DefaultContext = DistContext()


class DistManager:
    """distributed context manager."""

    def __init__(self):
        self.t_local = threading.local()
        self.t_local.device_context = DefaultContext

    def current_context(self) -> DistContext:
        """get current context."""
        return getattr(self.t_local, 'device_context', DefaultContext)

    def set_context(self, context: DistContext):
        """set current context."""
        self.t_local.device_context = context

    @contextmanager
    def context(self, context: DistContext):
        """context manager."""
        origin_context = self.current_context()
        self.set_context(context)
        yield self
        self.set_context(origin_context)


_DIST_MANAGER: DistManager = None


def get_dist_manager():
    """get distributed context manager."""
    global _DIST_MANAGER
    if _DIST_MANAGER is None:
        _DIST_MANAGER = DistManager()
    return _DIST_MANAGER


def get_world_rank():
    """get distributed world size and rank."""
    ctx = get_dist_manager().current_context()
    world_size = ctx.world_size
    rank = ctx.rank

    return world_size, rank


def get_process_group():
    """get process group."""
    ctx = get_dist_manager().current_context()
    return ctx.dist_group
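For illustration only, not part of the PR diff: a minimal sketch of how this new context API might be exercised in a process launched by torchrun. The RANK/WORLD_SIZE environment variables are set by torchrun; the explicit process-group initialization and the nccl backend are assumptions of this sketch (the PR itself enters the context inside model_agent.py).

import os

import torch.distributed as dist

from lmdeploy.pytorch.distributed import (DistContext, get_dist_manager,
                                          get_world_rank)

# torchrun exports RANK and WORLD_SIZE for every worker process
rank = int(os.environ.get('RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))

# assumption for this sketch: set up the default process group when more
# than one worker is launched (torchrun also provides MASTER_ADDR/PORT)
if world_size > 1 and not dist.is_initialized():
    dist.init_process_group(backend='nccl')

# enter the thread-local distributed context; inside it, get_world_rank()
# reads the values stored in DistContext instead of querying torch.distributed
with get_dist_manager().context(DistContext(rank=rank, world_size=world_size)):
    ws, r = get_world_rank()
    print(f'world_size={ws}, rank={r}')

Outside the with block, get_world_rank() falls back to DefaultContext (world_size=1, rank=0), so single-process code paths keep their current behaviour.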
64 changes: 35 additions & 29 deletions lmdeploy/pytorch/engine/model_agent.py
@@ -15,6 +15,7 @@
 from ..backends import get_backend
 from ..config import BackendConfig, CacheConfig, ModelConfig
 from ..devices import DeviceContext, get_device_manager
+from ..distributed import DistContext, get_dist_manager, get_world_rank
 from ..model_inputs import ModelInputs
 from ..models.patch import (add_adapters, build_patched_model,
                             update_custom_module_map)
@@ -81,9 +82,7 @@ def __adjust_block_size():
         # TODO: support kernel with both large head dim and large block size.
         if model_config.k_head_dim >= 512 and cache_config.block_size > 32:
             cache_config.block_size = 32
-            rank = 0
-            if dist.is_initialized():
-                rank = dist.get_rank()
+            _, rank = get_world_rank()
             if rank == 0:
                 logger.warning(
                     f'Update `block_size={cache_config.block_size}`'
@@ -482,9 +481,11 @@ def _start_tp_process(proc_id: int,
                                 rank=rank,
                                 world_size=world_size,
                                 timeout=timedelta(days=35600))
+        dist_ctx = DistContext(rank=rank, world_size=world_size)
         torch.cuda.set_device(rank)
-        with get_device_manager().context(
-                device_context), torch.inference_mode():
+        with (get_dist_manager().context(dist_ctx),
+              get_device_manager().context(device_context),
+              torch.inference_mode()):
             args = args or tuple()
             kwargs = kwargs or dict()
             func(rank, *args, **kwargs)
@@ -565,6 +566,7 @@ def __signal_term_handler(sig, frame):
         self.world_size = world_size
         self.backend_config = backend_config

+        self._dist_ctx = None
         self.mp_bar = self.mp_ctx.Barrier(world_size)
         self._start_sub_process(model_path,
                                 model_config=model_config,
@@ -645,6 +647,8 @@ def _start_sub_process(self, model_path: str, model_config: ModelConfig,
                                     rank=rank,
                                     world_size=world_size,
                                     timeout=timedelta(days=35600))
+            dist_ctx = DistContext(rank=rank, world_size=world_size)
+            self._dist_ctx = dist_ctx
         except Exception as e:
             from traceback import print_exc
             logger.error(f'Rank[{rank}] failed.')
@@ -665,16 +669,17 @@ def _build_model(
         world_size: int,
     ):
         """build model."""
-        rank = 0
-        model, cache_engine, cache_config = _tp_build_model(
-            rank,
-            model_path=model_path,
-            model_config=model_config,
-            cache_config=cache_config,
-            backend_config=backend_config,
-            adapters=adapters,
-            world_size=world_size,
-        )
+        with get_dist_manager().context(self._dist_ctx):
+            rank = 0
+            model, cache_engine, cache_config = _tp_build_model(
+                rank,
+                model_path=model_path,
+                model_config=model_config,
+                cache_config=cache_config,
+                backend_config=backend_config,
+                adapters=adapters,
+                world_size=world_size,
+            )

         return model, cache_engine, cache_config

@@ -686,20 +691,21 @@ def get_block_numel(self):
     def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
                       swap_out_map: SwapMap):
         """forward impl."""
-        self.mp_bar.wait()
-        rank = 0
-        _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
-                          self.stream)
-        cache_swapping(self.cache_engine,
-                       swap_in_map=swap_in_map,
-                       swap_out_map=swap_out_map)
-        output = model_forward(
-            self.patched_model,
-            inputs,
-            self.cache_engine,
-            world_size=1,
-            stream=self.stream,
-        )
+        with get_dist_manager().context(self._dist_ctx):
+            self.mp_bar.wait()
+            rank = 0
+            _broadcast_inputs(rank, [inputs, swap_in_map, swap_out_map],
+                              self.stream)
+            cache_swapping(self.cache_engine,
+                           swap_in_map=swap_in_map,
+                           swap_out_map=swap_out_map)
+            output = model_forward(
+                self.patched_model,
+                inputs,
+                self.cache_engine,
+                world_size=1,
+                stream=self.stream,
+            )
         return output

     def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
14 changes: 1 addition & 13 deletions lmdeploy/pytorch/models/cogvlm.py
@@ -6,6 +6,7 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -16,19 +17,6 @@
from .utils.cudagraph import CudaGraphMixin


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank


class VisionExpertAttention(nn.Module):
"""Rewrite module of VisionExpertAttention."""

14 changes: 1 addition & 13 deletions lmdeploy/pytorch/models/deepseek.py
@@ -6,6 +6,7 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -17,19 +18,6 @@
from .utils.cudagraph import CudaGraphMixin


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank


class DeepseekAttention(nn.Module):
"""Rewrite module of MistralAttention."""

18 changes: 2 additions & 16 deletions lmdeploy/pytorch/models/deepseek_v2.py
@@ -6,6 +6,7 @@
 import torch.distributed as dist
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -25,19 +26,6 @@ def yarn_get_mscale(scale=1, mscale=1):
     return 0.1 * mscale * math.log(scale) + 1.0


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank


class DeepseekV2BMM(nn.Module):
"""wrapped bmm."""

@@ -240,9 +228,7 @@ def forward(
         attn_metadata: Any = None,
     ):
         """Rewrite of LlamaAttention.forward."""
-        world_size = 1
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
+        world_size, _ = get_world_rank()
         num_heads = self.num_heads // world_size
         nope_size = self.kv_lora_rank
         q_len = hidden_states.size(1)
6 changes: 2 additions & 4 deletions lmdeploy/pytorch/models/minicpm3.py
@@ -3,10 +3,10 @@
 from typing import Any, Iterable, List, Optional, Tuple

 import torch
-import torch.distributed as dist
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (Attention, RMSNorm, RopeType, SiluAndMul,
                                  build_rotary_embedding)
@@ -118,9 +118,7 @@ def forward(
         attn_metadata: Any = None,
     ):
         """Rewrite of LlamaAttention.forward."""
-        world_size = 1
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
+        world_size, _ = get_world_rank()
         num_heads = self.num_heads // world_size
         bsz, q_len, _ = hidden_states.size()

14 changes: 1 addition & 13 deletions lmdeploy/pytorch/models/qwen2_moe.py
@@ -7,6 +7,7 @@
 from torch import nn
 from transformers.configuration_utils import PretrainedConfig

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
                                  SiluAndMul, build_rotary_embedding)
@@ -18,19 +19,6 @@
from .utils.cudagraph import CudaGraphMixin


-def get_world_rank():
-    """get current world size and rank."""
-    import torch.distributed as dist
-    world_size = 1
-    rank = 0
-
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-
-    return world_size, rank


class Qwen2MoeAttention(nn.Module):
"""Rewrite module of Qwen2MoeAttention."""

4 changes: 3 additions & 1 deletion lmdeploy/pytorch/nn/attention.py
@@ -2,9 +2,11 @@
 import torch
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
+
 from ..backends import OpType, get_backend
 from ..backends.attention import AttentionMetadata
-from .utils import get_distribute_size, get_world_rank
+from .utils import get_distribute_size


class Attention(nn.Module):
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/nn/linear.py
@@ -5,13 +5,14 @@
 import torch.distributed as dist
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
 from lmdeploy.pytorch.weight_loader.model_weight_loader import \
     default_weight_loader
 from lmdeploy.utils import get_logger

 from ..backends import OpType, get_backend
 from ..backends.lora import AdapterInfo
-from .utils import div_up, get_distribute_size, get_world_rank
+from .utils import div_up, get_distribute_size

logger = get_logger('lmdeploy')

3 changes: 2 additions & 1 deletion lmdeploy/pytorch/nn/moe.py
@@ -5,8 +5,9 @@
 import torch.distributed as dist
 from torch import nn

+from lmdeploy.pytorch.distributed import get_world_rank
+
 from ..backends import OpType, get_backend
-from .utils import get_world_rank


class SoftmaxTopK(nn.Module):