6 changes: 6 additions & 0 deletions lmdeploy/messages.py
@@ -289,6 +289,9 @@ class PytorchEngineConfig:
session_len (int): Max session length. Default None.
max_batch_size (int): Max batch size. If it is not specified,
the engine will automatically set it according to the device
attn_tp_size (int): Tensor parallel size for attention. Only takes effect when dp > 1.
mlp_tp_size (int): Tensor parallel size for MLP. Only takes effect when dp > 1.
moe_tp_size (int): Tensor parallel size for MoE. Only takes effect when dp > 1.
cache_max_entry_count (float): the percentage of gpu memory occupied
by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
it defaults to 0.8, signifying the percentage of FREE GPU memory
@@ -350,6 +353,9 @@ class PytorchEngineConfig:
ep: int = 1
session_len: int = None
max_batch_size: int = None
attn_tp_size: int = None
mlp_tp_size: int = None
moe_tp_size: int = None
cache_max_entry_count: float = 0.8
prefill_interval: int = 16
block_size: int = 64
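
As a quick illustration of how the new fields compose with data parallelism, a hedged sketch of an engine configuration follows; the model name and values are placeholders, the dp field is assumed from the existing upstream config, and the exact interaction between dp and the per-module TP sizes is defined by the engine, not by this snippet.

from lmdeploy import PytorchEngineConfig, pipeline

# Illustrative only: per-module TP sizes take effect when dp > 1.
backend_config = PytorchEngineConfig(
    dp=2,                # data-parallel size (assumed existing field)
    attn_tp_size=2,      # TP size used for attention layers
    mlp_tp_size=2,       # TP size used for MLP layers
    moe_tp_size=4,       # TP size used for MoE experts
)
pipe = pipeline('Qwen/Qwen2.5-7B-Instruct', backend_config=backend_config)
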
7 changes: 6 additions & 1 deletion lmdeploy/pytorch/backends/awq_modules.py
@@ -17,7 +17,12 @@ def update_weights(self,
return qweight, scales, qzeros, bias

@abstractmethod
def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
raise NotImplementedError

2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -3,6 +3,7 @@
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearBlockedF8Impl(ABC):
@@ -19,6 +20,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/attention.py
@@ -90,7 +90,7 @@ def __init__(
self.flash_attention_fwd = flash_attention_fwd

# for alibi attention
world_size, rank = get_tp_world_rank()
world_size, rank = get_tp_world_rank('attn')
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size
self.block_sparse_size = block_sparse_size
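
The only change here is that get_tp_world_rank is now queried with a layer-type key ('attn'), matching the per-module attn/mlp/moe TP sizes added in messages.py. The actual lookup lives elsewhere in lmdeploy; the following is a purely hypothetical sketch of the idea, keyed on per-module process groups:

import torch.distributed as dist

# Hypothetical registry mapping layer type ('attn' / 'mlp' / 'moe') to its TP group.
_TP_GROUPS: dict = {}

def get_tp_world_rank(layer_type: str):
    group = _TP_GROUPS.get(layer_type)  # None falls back to the default world group
    return dist.get_world_size(group=group), dist.get_rank(group=group)
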
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/awq_modules.py
@@ -55,12 +55,13 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out_features = scales.size(1)
out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


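
Passing group through to dist.all_reduce follows the standard torch.distributed pattern: with group=None the reduction spans the default world group, while an explicit ProcessGroup restricts it to one TP subgroup (for example one attention-TP group under dp > 1). A minimal standalone sketch, independent of lmdeploy's own dist wrappers:

from typing import Optional
import torch
import torch.distributed as dist

def all_reduce_partial_out(out: torch.Tensor,
                           group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # Sums the partial GEMM outputs in place across the ranks of `group`;
    # group=None means the default (world) process group.
    dist.all_reduce(out, group=group)
    return out
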
20 changes: 6 additions & 14 deletions lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -13,15 +13,6 @@
logger = get_logger('lmdeploy')


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""Reduce scatter."""
outs = out.split(tp_sizes, -2)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
return out


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
"""Triton linear blocked f8 implementation."""

@@ -37,6 +28,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -52,7 +44,7 @@ def forward(self,

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)
return out
@@ -117,6 +109,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -128,12 +121,11 @@ def forward(self,
out = out[:x.size(0)]
if bias is not None:
out += bias
out = out.unflatten(0, x_shape[:-1])

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)

out = out.unflatten(0, x_shape[:-1])
dist.all_reduce(out, group=group)
return out
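
dist.reduce_scatter_by_tp_sizes itself is not shown in this diff; judging from the removed _reduce_scatter_input helper above, it presumably wraps the same split-then-reduce-scatter pattern with an explicit process group. A sketch of that assumed behavior, written directly against torch.distributed:

from typing import List, Optional
import torch
import torch.distributed as dist

def reduce_scatter_by_tp_sizes(out: torch.Tensor,
                               rank: int,
                               tp_sizes: List[int],
                               group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    # Split the token dimension into per-rank (possibly uneven) chunks,
    # keep this rank's chunk, and reduce-scatter the full list across the group.
    outs = list(out.split(tp_sizes, -2))
    out = outs[rank]
    dist.reduce_scatter(out, outs, group=group)
    return out
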
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -258,7 +258,7 @@ def update_inputs(self, inputs):
meta = self.get_meta()
padding_batch_size = meta.padding_batch_size
tp_size = self._get_capture_tokens(padding_batch_size)
dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
dp_meta.sync_tp_size(tp_size)
return inputs

def get_capture_batch_sizes(self) -> List[int]:
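
dp_meta.sync_tp_size is likewise introduced outside this diff; from the assignment it replaces, it appears to propagate the padded capture size to every entry of tp_sizes (and possibly synchronize related state). A hypothetical reading of the replaced line:

# Hypothetical helper equivalent to the replaced assignment:
#     dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
class DPMeta:
    def __init__(self, tp_sizes: list):
        self.tp_sizes = tp_sizes

    def sync_tp_size(self, tp_size: int) -> None:
        # Broadcast one padded size to every dp rank's slot.
        self.tp_sizes = [tp_size] * len(self.tp_sizes)
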