diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py index 3519098f3c..e272af7f41 100644 --- a/vllm_gaudi/extension/features.py +++ b/vllm_gaudi/extension/features.py @@ -108,5 +108,8 @@ def get_features(): env_var_type=boolean), Value('moe_chunk', "", env_var='VLLM_MOE_CHUNK', env_var_type=list_of(int)), Value('moe_token_boundary', "", env_var='VLLM_MOE_TOKEN_BOUNDARY', env_var_type=list_of(int)), + Value('use_dispatch_fn', + All(VersionRange(">=1.24.0.460"), MinPackageVersion("neural_compressor_pt", "3.6")), + env_var_type=boolean), ] return split_values_and_flags(features) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 51b42d6b55..a279dbd45e 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -492,7 +492,12 @@ def forward(self, state, expert_id, w): class VllmMixtureOfExpertsOpBase(torch.nn.Module): - def __init__(self, global_num_experts: int, num_total_experts, experts_min: int = 0, experts_max: int = 8): + def __init__(self, + global_num_experts: int, + num_total_experts: int, + experts_min: int = 0, + experts_max: int = 8, + dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): super().__init__() self.experts_min = experts_min self.experts_max = experts_max @@ -516,6 +521,11 @@ def __init__(self, global_num_experts: int, num_total_experts, experts_min: int assert len(self.chunk_size_list) == len( self.token_boundary_list), (f"chunk_size_list({len(self.chunk_size_list)}) and " f"token_boundary_list({len(self.token_boundary_list)}) must be the same length") + """ + dispatch_func is used to dispatch quantized tokens under data parallel + scenario. 
+ """ + self.dispatch_func = dispatch_fn def _get_extra_kwargs(self, tokens_num: int): if self.chunk_size_list: @@ -532,11 +542,20 @@ def _get_extra_kwargs(self, tokens_num: int): kwargs = {} return kwargs + def _get_dispatch_func(self): + fn = self.dispatch_func + return fn + class VllmMixtureOfExpertsOp(VllmMixtureOfExpertsOpBase): - def __init__(self, global_num_experts: int, num_total_experts, experts_min: int = 0, experts_max: int = 8): - super().__init__(global_num_experts, num_total_experts, experts_min, experts_max) + def __init__(self, + global_num_experts: int, + num_total_experts: int, + experts_min: int = 0, + experts_max: int = 8, + dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): + super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) @@ -979,8 +998,13 @@ def get_dequant_weights_func(self, ) -> Optional[Callable[[torch.nn.Module], tor class VllmMixtureOfExpertsOpFP8(VllmMixtureOfExpertsOpBase): - def __init__(self, global_num_experts: int, num_experts: int, experts_min: int = 0, experts_max: int = 8): - super().__init__(global_num_experts, num_experts, experts_min, experts_max) + def __init__(self, + global_num_experts: int, + num_experts: int, + experts_min: int = 0, + experts_max: int = 8, + dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): + super().__init__(global_num_experts, num_experts, experts_min, experts_max, dispatch_fn) self.w13_list = torch.nn.ModuleList( [MoeFP8Matmul(quant_method=FusedMoeWeightScaleSupported.BLOCK.value) for _ in range(num_experts)]) self.w2_list = torch.nn.ModuleList( @@ -1039,8 +1063,13 @@ def forward( class VllmMixtureOfExpertsOpFP8PerChannel(VllmMixtureOfExpertsOpBase): - def __init__(self, global_num_experts: int, num_experts: int, experts_min: int = 0, experts_max: int = 8): - 
super().__init__(global_num_experts, num_experts, experts_min, experts_max) + def __init__(self, + global_num_experts: int, + num_experts: int, + experts_min: int = 0, + experts_max: int = 8, + dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): + super().__init__(global_num_experts, num_experts, experts_min, experts_max, dispatch_fn) self.w13_list = torch.nn.ModuleList([MoeFP8Matmul() for _ in range(num_experts)]) self.w2_list = torch.nn.ModuleList([MoeFP8Matmul() for _ in range(num_experts)]) self.w13_input_scale = None diff --git a/vllm_gaudi/ops/hpu_fp8.py b/vllm_gaudi/ops/hpu_fp8.py index 973ff49e11..30294e79ec 100644 --- a/vllm_gaudi/ops/hpu_fp8.py +++ b/vllm_gaudi/ops/hpu_fp8.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Optional import torch @@ -10,7 +11,9 @@ Fp8Config) import vllm_gaudi.extension.ops as hpu_ops from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOpFP8PerChannel, VllmMixtureOfExpertsOpFP8) -from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata +from vllm_gaudi.extension.runtime import get_config +from vllm_gaudi.utils import has_quant_config +from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_hidden_states, dispatch_tensor, get_hpu_dp_metadata class Fp8LinearMethod(OrigFp8LinearMethod): @@ -103,6 +106,8 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): # disable DeepGemm support. 
self.allow_deep_gemm = False + self.use_dispatch_fn = get_config().use_dispatch_fn + def create_weights(self, *args, **kwargs) -> None: if hpu_ops.is_hpu_gaudi2: kwargs['weight_loader'] = hpu_ops.gaudi_weight_wrapper(kwargs.get('weight_loader')) @@ -114,12 +119,18 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ep_shift = layer.ep_rank * num_experts experts_min, experts_max = ep_shift, num_experts + ep_shift - 1 + if layer.dp_size > 1 and self.use_dispatch_fn: + dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.is_sequence_parallel) + else: + dispatch_fn = None + if self.block_quant and not envs.VLLM_HPU_FORCE_CHANNEL_FP8: layer.moe_op = VllmMixtureOfExpertsOpFP8( layer.global_num_experts, num_experts, experts_min, experts_max, + dispatch_fn, ) else: layer.moe_op = VllmMixtureOfExpertsOpFP8PerChannel( @@ -127,6 +138,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts, experts_min, experts_max, + dispatch_fn, ) if self.block_quant: layer = hpu_ops.fp8_block_moe_prepare_weights(layer, envs.VLLM_HPU_FORCE_CHANNEL_FP8) @@ -152,12 +164,14 @@ def apply( topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(x.dtype) - topk_ids = topk_ids.to(torch.int64) - topk_weights = topk_weights.to(x.dtype) + if not layer.use_grouped_topk: + topk_ids = topk_ids.to(torch.int64) + topk_weights = topk_weights.to(x.dtype) if layer.dp_size > 1: - hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp - x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) + if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn): + hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp + x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) topk_ids_across_dp = get_hpu_dp_metadata().topk_ids_across_dp topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel) 
@@ -165,11 +179,8 @@ def apply( topk_weights_across_dp = get_hpu_dp_metadata().topk_weights_across_dp topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel) - topk_ids = topk_ids.view(*x.shape[:-1], -1) - topk_weights = topk_weights.view(*x.shape[:-1], -1) - if not layer.use_grouped_topk: - topk_ids = topk_ids.to(torch.int64) - topk_weights = topk_weights.to(x.dtype) + topk_ids = topk_ids.view(-1, topk_ids.shape[-1]) + topk_weights = topk_weights.view(-1, topk_weights.shape[-1]) output = layer.moe_op( x, @@ -178,7 +189,7 @@ def apply( permuted_weights=True, activation=layer.activation, ) - return output.view(*(x.size(0), *input_shape[1:])) + return output.view(*(output.size(0), *input_shape[1:])) fp8.Fp8LinearMethod = Fp8LinearMethod diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index ba0b8d96be..42cdad1526 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Union import torch @@ -5,7 +6,9 @@ from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) from vllm_gaudi.extension.ops import (VllmMixtureOfExpertsOp) -from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_tensor, get_hpu_dp_metadata +from vllm_gaudi.extension.runtime import get_config +from vllm_gaudi.utils import has_quant_config +from vllm_gaudi.v1.worker.hpu_dp_utils import dispatch_hidden_states, dispatch_tensor, get_hpu_dp_metadata @UnquantizedFusedMoEMethod.register_oot @@ -14,6 +17,7 @@ class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.use_dispatch_fn = get_config().use_dispatch_fn torch.hpu.synchronize() def process_weights_after_loading(self, layer: torch.nn.Module) -> None: @@ -23,11 +27,18 @@ def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: ep_shift = layer.ep_rank * num_experts experts_min, experts_max = ep_shift, num_experts + ep_shift - 1 + + if layer.dp_size > 1 and self.use_dispatch_fn: + dispatch_fn = partial(dispatch_hidden_states, is_sequence_parallel=layer.is_sequence_parallel) + else: + dispatch_fn = None + layer.moe_op = VllmMixtureOfExpertsOp( layer.global_num_experts, num_experts, experts_min, experts_max, + dispatch_fn, ) for expert_id in range(layer.local_num_experts): @@ -53,12 +64,14 @@ def forward_oot( topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(x.dtype) - topk_ids = topk_ids.to(torch.int64) - topk_weights = topk_weights.to(x.dtype) + if not layer.use_grouped_topk: + topk_ids = topk_ids.to(torch.int64) + topk_weights = topk_weights.to(x.dtype) if layer.dp_size > 1: - hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp - x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) + if not (has_quant_config(layer.vllm_config.model_config) and self.use_dispatch_fn): + hidden_states_across_dp = get_hpu_dp_metadata().hidden_states_across_dp + x = dispatch_tensor(x, hidden_states_across_dp, layer.is_sequence_parallel) topk_ids_across_dp = get_hpu_dp_metadata().topk_ids_across_dp topk_ids = dispatch_tensor(topk_ids, topk_ids_across_dp, layer.is_sequence_parallel) @@ -66,19 +79,17 @@ def forward_oot( topk_weights_across_dp = get_hpu_dp_metadata().topk_weights_across_dp topk_weights = dispatch_tensor(topk_weights, topk_weights_across_dp, layer.is_sequence_parallel) - topk_ids = topk_ids.view(*x.shape[:-1], -1) - topk_weights = topk_weights.view(*x.shape[:-1], -1) - if not layer.use_grouped_topk: - topk_ids = topk_ids.to(torch.int64) - topk_weights = topk_weights.to(x.dtype) + topk_ids = topk_ids.view(-1, topk_ids.shape[-1]) + topk_weights = topk_weights.view(-1, topk_weights.shape[-1]) - return layer.moe_op( + output = layer.moe_op( x, topk_ids, 
topk_weights, permuted_weights=True, activation=layer.activation, - ).view(*(x.size(0), *input_shape[1:])) + ) + return output.view(*(output.size(0), *input_shape[1:])) def reduce_output(self, states: torch.Tensor) -> torch.Tensor: diff --git a/vllm_gaudi/utils.py b/vllm_gaudi/utils.py index 82647e5a0a..783dd2e830 100644 --- a/vllm_gaudi/utils.py +++ b/vllm_gaudi/utils.py @@ -1,5 +1,6 @@ from functools import cache import os +from vllm.config import ModelConfig from vllm.utils.torch_utils import make_tensor_with_pad, TORCH_DTYPE_TO_NUMPY_DTYPE from vllm_gaudi.extension.runtime import get_config from typing import (Any, Optional, TypeVar, Union) @@ -29,6 +30,10 @@ def hpu_backend_string(): return backend_string +def has_quant_config(model_config: ModelConfig) -> bool: + return model_config.quantization == "inc" or os.getenv("QUANT_CONFIG", None) is not None + + def async_h2d_copy(source, dest_tensor=None, dtype=None, device='hpu'): """ Asynchronously transfer data from host to device. diff --git a/vllm_gaudi/v1/worker/hpu_dp_utils.py b/vllm_gaudi/v1/worker/hpu_dp_utils.py index 3fec00862b..b71ad490e3 100644 --- a/vllm_gaudi/v1/worker/hpu_dp_utils.py +++ b/vllm_gaudi/v1/worker/hpu_dp_utils.py @@ -5,6 +5,8 @@ from typing import Optional from vllm.distributed import get_dp_group, get_ep_group from vllm.platforms import current_platform +from vllm_gaudi.extension.runtime import get_config +from vllm_gaudi.utils import has_quant_config import habana_frameworks.torch as htorch @@ -33,9 +35,11 @@ def make( assert num_experts_per_tok > 0, ( "num_experts_per_tok must be greater than 0 in model config. 
Please check the model config.") + is_quant_with_inc = has_quant_config(vllm_config.model_config) and get_config().use_dispatch_fn + hidden_states_dtype = (torch.float8_e4m3fn if is_quant_with_inc else dtype) hidden_states_across_dp = torch.empty( (num_tokens_across_dp, hidden_size), - dtype=dtype, + dtype=hidden_states_dtype, device=device, ) topk_ids_across_dp = torch.empty( @@ -115,3 +119,9 @@ def dispatch_tensor(input, output: torch.Tensor | None = None, is_sequence_paral output, input, group=get_ep_group().device_group if is_sequence_parallel else get_dp_group().device_group) return output + + +def dispatch_hidden_states(input, is_sequence_parallel): + dp_metadata = get_hpu_dp_metadata() + hidden_states_across_dp = dp_metadata.hidden_states_across_dp if dp_metadata is not None else None + return dispatch_tensor(input, hidden_states_across_dp, is_sequence_parallel) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 1dbc3dd23d..102bcd4568 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -3769,6 +3769,8 @@ def load_model(self) -> None: disable_mark_scales_as_const = os.getenv("VLLM_DISABLE_MARK_SCALES_AS_CONST", "false") in ("1", "true") self._inc_preprocess() if config.measure: + assert self.parallel_config.data_parallel_size == 1, \ + "Data parallelism is not supported during the calibration stage." self.model = prepare(self.model, config) elif config.quantize: self.model = convert(self.model, config)