From bd6520e8ba9c2216defb75caf60192e42d562c36 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Wed, 31 Jul 2024 10:41:39 +0300 Subject: [PATCH 1/2] Revert "Tensor parallel distributed strategy without using deepspeed (#280)" This reverts commit c6e5f9cc1164b47bb14a54f5d5f9216b8d12590f. --- examples/text-generation/run_generation.py | 7 - examples/text-generation/utils.py | 99 +--- optimum/habana/distributed/__init__.py | 26 - optimum/habana/distributed/serialization.py | 489 ------------------ optimum/habana/distributed/strategy.py | 108 ---- optimum/habana/distributed/tp.py | 84 --- optimum/habana/distributed/tp_wrapping.py | 33 -- .../models/llama/modeling_llama.py | 154 +----- 8 files changed, 4 insertions(+), 996 deletions(-) delete mode 100644 optimum/habana/distributed/serialization.py delete mode 100644 optimum/habana/distributed/strategy.py delete mode 100644 optimum/habana/distributed/tp.py delete mode 100644 optimum/habana/distributed/tp_wrapping.py diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ab1f9615b3..6503cb7003 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -328,13 +328,6 @@ def __call__(self, parser, namespace, values, option_string=None): action="store_true", help="Run the inference with dataset for specified --n_iterations(default:5)", ) - parser.add_argument( - "--distributed_strategy", - type=str, - choices=["tp", "none"], # Add other strategies as needed - default="none", - help="Run multi card with the specified distributed strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.", - ) parser.add_argument( "--load_cp", action="store_true", diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 0a860f9bc9..f8e916e0a1 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -293,102 +293,6 @@ def setup_model(args, model_dtype, model_kwargs, logger): # assistant_model = get_torch_compiled_model(assistant_model) return model, assistant_model -def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): - - from optimum.habana.distributed import serialization - from typing import Any, MutableMapping - - from optimum.habana.distributed import tp_wrapping - from optimum.habana.distributed.strategy import DistributedStrategy - from torch import nn - - class TensorParallelStrategy(DistributedStrategy): - def __init__(self, group=None, from_meta=False): - super().__init__(from_meta) - assert torch.distributed.is_initialized(), "must initialize a process group" - self.group = group if group is not None else torch.distributed.GroupMember.WORLD - - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: - return tp_wrapping.apply_tp(module, self.group) - - def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: - return tp_wrapping.apply_tp(block, layer, self.group) - - def __getstate__(self): - state = self.__dict__.copy() - state['group'] = None # Remove ProcessGroup from state - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.group = None # Restore to default state or reinitialize - - logger.info("Multi-device run.") - - assert args.assistant_model is None, "Assistant model must be None" - - from torch import distributed as dist - if args.device == 'hpu': - import habana_frameworks.torch.distributed.hccl - dist.init_process_group(backend='hccl') - else: - dist.init_process_group() - - torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) - logger.info("Creating Model") - config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs) - model_kwargs={} - model_kwargs["distributed_strategy"] = TensorParallelStrategy() - model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) - - initial_device = torch.device("cpu") - source="hf" - checkpoint_sharding=None - lazy_sd: MutableMapping[str, Any] = {} - logger.info("Loading Checkpoints") - lazy_sd = serialization.load_state_dict( - args.model_name_or_path, - source=source, - distributed_strategy=args.distributed_strategy, - checkpoint_sharding=None, - initial_device=initial_device, - rank=args.global_rank, - world_size=args.world_size, - ) - architecture="llama" - if len(lazy_sd): - serialization.load_state_dict_into_model( - model, - lazy_sd, - architecture, - source, - args.distributed_strategy, - checkpoint_sharding, - initial_device, - args.local_rank, - args.world_size, - ) - - if args.quant_config: - model = setup_quantization(model, args) - - model = model.eval().to(args.device) - - if args.use_hpu_graphs: - from habana_frameworks.torch.hpu import wrap_in_hpu_graph - - if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon": - model = wrap_in_hpu_graph(model, hash_with_views=False) - else: - model = wrap_in_hpu_graph(model) - - if args.torch_compile and model.config.model_type == "llama": - model = get_torch_compiled_model(model) - - return model, args.assistant_model - def setup_distributed_model(args, model_dtype, model_kwargs, logger): import deepspeed @@ -655,8 +559,7 @@ def initialize_model(args, logger): model, assistant_model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed - else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp" - else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) + else setup_distributed_model(args, model_dtype, model_kwargs, logger) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) generation_config = setup_generation_config(args, model, assistant_model, tokenizer) diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index 12edd6620e..2dedd7333d 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,28 +1,2 @@ from .distributed_runner import DistributedRunner from .fast_ddp import all_reduce_gradients -import os -import torch - -def rank_and_world(group=None): - """ - Returns (rank, world_size) from the optionally-specified group, otherwise - from the default group, or if non-distributed just returns (0, 1) - """ - if torch.distributed.is_initialized() and group is None: - group = torch.distributed.GroupMember.WORLD - - if group is None: - world_size = 1 - rank = 0 - else: - world_size = group.size() - rank = group.rank() - - return rank, world_size - - -_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) - - -def local_rank(): - return _LOCAL_RANK diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py deleted file mode 100644 index c543ab20bd..0000000000 --- a/optimum/habana/distributed/serialization.py +++ /dev/null @@ -1,489 +0,0 @@ -import collections -import os -from collections import ChainMap -from collections.abc import Iterable -from pathlib import Path -from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Union - -import torch - -from optimum.habana.distributed.tp import TPModule - - -__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} - - -def register_adapter( - architecture: str, - source: str, - adapter: Callable[[Mapping], Mapping], -): - """ - Registers a state dict adapter to be available to the (de) serialization - API. - - Args: - architecture: The name of the model architecture, e.g. 'llama' - source: A label representing the format of the weights to be converted. - E.g. 'hf' - adapter: the class of the adapter. The class must accept one constructor - parameter, which will be a state dict (`OrderedDict`) - """ - sources: MutableMapping[str, Callable[[Mapping], Mapping]] = {} - if architecture in __adapters: - sources = __adapters[architecture] - - if source in sources: - raise KeyError( - f"Variant {source} already registered for architecture {architecture}" - ) - - sources[source] = adapter - __adapters[architecture] = sources - - -def list_sources(architecture: str): - """ - Lists available sources (attribute formats) of a model architecture. - E.g. `models.list_variants('llama')` -> ['meta', 'fms', 'hf'] - Args: - architecture: one of the registered architectures returned by - `models.list_models()`. - """ - if architecture not in __adapters: - return [] - return list(__adapters[architecture].keys()) - - -def _get_adapter( - architecture: str, source: Optional[str] -) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: - if ( - source is None - or architecture not in __adapters - or source not in __adapters[architecture] - ): - # if no adapter is registered, assume the attributes are already in - # fms format. - # should we raise an error here instead? - return lambda x: x - else: - return __adapters[architecture][source] - - -def get_adapted( - architecture: str, source: Optional[str], state_dict: Mapping[str, Any] -) -> Mapping[str, Any]: - """ - Convert a state dict to FMS format, using an adapter specified by name. - - Args: - architecture: one of the architectures from `models.list_models()`. - E.g. llama. - source: A reference to an attribute format - state_dict: the model.state_dict() to be converted/adapted. - """ - # sometimes we only load onto rank 0 so may not have a state_dict here. - if not len(state_dict): - return state_dict - adapter = _get_adapter(architecture, source) - adapted = adapter(state_dict) - return adapted - - -# `models` imports each model class, causing models and adapters to be registered. -# down here to avoid circular dependencies. -# from fms import models - - -def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: - from safetensors import safe_open # type: ignore[import-untyped] - - with torch.no_grad(): - with safe_open( - file, framework="pt", device=str(device) - ) as model_weights: # type: ignore[attr-defined] - return model_weights.get_tensor(key) - - -class LazySafetensorsDict(collections.UserDict): - def set_lazy_tensor(self, key, file, device): - super().__setitem__(key, lambda: _get_safetensors_item(key, file, device)) - - def __getitem__(self, key): - lazy_tensor = super().__getitem__(key) - if callable(lazy_tensor): - lazy_tensor = lazy_tensor() - super().__setitem__(key, lazy_tensor) - return lazy_tensor - - -def load_state_dict( - model_path: Union[str, Path], - *, - source: Optional[str] = None, - distributed_strategy: Optional[str] = None, - checkpoint_sharding: Optional[str] = None, - initial_device: torch.device = torch.device("cpu"), - rank: int = 0, - world_size: int = 1, -) -> MutableMapping[str, Any]: - """ - Validates that the file(s) found at a checkpoint path are compatible with - the intended (possibly distributed) use-case, and returns a lazy loading - state dict if possible (some formats may not support that). - - If model_path is a directory, it'll try to load models based on the source - (e.g. .bin for HF, .pth for Meta), and, if no source is specified or hasn't - been registered, it'll try .safetensors, .pth, and .bin. - - Args: - model_path: the path to find the weights. If not set, return None. - source: If the weights in the state dict didn't come from an FMS model, - `source` specifies which conversion function might be needed. - See `serialization.list_sources(architecture)` - distributed_strategy: the kind of possibly-distributed model in which we - intend to load these weights. E.g. tp, fsdp, None. Used for - validation. - checkpoint_sharding: the sharding format of the checkpoint. - E.g. layer, tp, fsdp. - initial_device: where the state dict will be loaded if not lazy. - If meta, return empty dict. - """ - if model_path is None or initial_device.type == "meta": - return {} - if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: - raise ValueError(f"FSDP checkpoints can only be loaded into an FSDP model") - if checkpoint_sharding == "tp" and distributed_strategy != "tp": - raise ValueError("TP checkpoints can only be loaded into a TP model") - - # Before creating the Path object, check if model_path has a glob pattern - if isinstance(model_path, str): - model_path, sep, glob_pattern = model_path.partition("*") - else: - sep = "" - glob_pattern = "" - glob_pattern = sep + glob_pattern - - model_path = Path(os.path.expanduser(model_path)) - - checkpoints = [] - - if model_path.is_dir(): - if glob_pattern != "": - glob_pattern_list = [glob_pattern] - elif source == "meta": - glob_pattern_list = ["*.pth", "*.safetensors"] - elif source == "hf": - glob_pattern_list = ["*.bin", "*.safetensors"] - else: - glob_pattern_list = ["*.safetensors", "*.pth", "*.bin"] - for glob_pattern_possibility in glob_pattern_list: - file_list = list(model_path.glob(glob_pattern_possibility)) - if len(file_list) > 0: - checkpoints = sorted(file_list) - break - - if model_path.is_file(): - checkpoints = [model_path] - - # Check if we found some files - assert ( - len(checkpoints) > 0 - ), f"Can't find the requested checkpoint data at {model_path}" - - if checkpoint_sharding is not None and checkpoint_sharding != "layer": - assert world_size == len( - checkpoints - ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" - - checkpoints = [checkpoints[rank]] - - # if there's only one checkpoint for fsdp/hsdp, load it only into rank zero - # and it will be distributed by the FSDP `sync_module_states` parameter - if checkpoint_sharding is None and distributed_strategy in {"hsdp", "fsdp"}: - if rank == 0: - checkpoints = [checkpoints[0]] - else: - return {} - - checkpoint_sds = [] - if checkpoints[0].suffix == ".safetensors": - for ckp in checkpoints: - checkpoint_sds.append( - _load_safetensors_state_dict( - ckp, - initial_device, - ) - ) - else: - with torch.no_grad(): - checkpoint_sds = [ - torch.load(str(ckpt_path), map_location=initial_device, mmap=True) for ckpt_path in checkpoints - ] - return ChainMap(*checkpoint_sds) - - -def _load_safetensors_state_dict( - checkpoint: Path, - device: torch.device, -): - sd = LazySafetensorsDict() - - from safetensors import safe_open - - with safe_open(checkpoint, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] - sd_keys = list(model_weights.keys()) - for key in sd_keys: - sd.set_lazy_tensor(key, checkpoint, device) - return sd - - -class FusableWeightsMissingError(Exception): - missing_weights: List[str] = [] - - def __init__(self, missing_weights): - self.missing_weights = missing_weights - super().__init__() - - -def load_state_dict_into_model( - model: torch.nn.Module, - state_dict: MutableMapping[str, Any], - architecture: str, - source: str, - distributed_strategy: Optional[str] = None, - checkpoint_sharding: Optional[str] = None, - initial_device: torch.device = torch.device("cpu"), - rank: int = 0, - world_size: int = 0, -) -> None: - """ - This function loads state_dict into model in the most efficient way possible, - and it removes all weights that have been used in model from state_dict - in order to conserve memory. - - Args: - model: The model where the weights are being loaded. - state_dict: The dictionary with all the weights. If it has been mmaped - (for torch.load) or it is an instance of LazySafetensorsDict, - the weights are loaded lazily from disk. - architecture: the model architecture, e.g. llama. See `models.list_models()`. - source: If the weights in the state dict didn't come from an FMS model, - `source` specifies which conversion function might be needed. - See `serialization.list_sources(architecture)` - distributed_strategy: the kind of possibly-distributed model in which we - intend to load these weights. E.g. tp, fsdp, None. Used for weight - sharding. - checkpoint_sharding: the sharding format of the checkpoint. - E.g. layer, tp, fsdp. Used for weight sharding. - initial_device: where the weights will be loaded from disk. - """ - - # 1. Get the adapter from checkpoint sd to fms sd - adapter = _get_adapter(architecture, source) - - # 2. Decide if model needs sharding and how (for now only TP) - needs_tp_sharding = checkpoint_sharding != "tp" and distributed_strategy == "tp" - - # 3. Iterate over the weights and load them into the model - used_keys = set() - sd_keys = list(state_dict.keys()) - with torch.no_grad(): - for key in sd_keys: - if key in used_keys: - continue - used_keys.add(key) - try: - partial_sd = {key: state_dict[key]} - if partial_sd[key].device != initial_device: - partial_sd[key] = partial_sd[key].to(device=initial_device) - fms_partial_sd = adapter(partial_sd) - except FusableWeightsMissingError as e: - for weight in e.missing_weights: - used_keys.add(weight) - partial_sd[weight] = state_dict[weight] - if partial_sd[weight].device != initial_device: - partial_sd[weight] = partial_sd[weight].to( - device=initial_device - ) - fms_partial_sd = adapter(partial_sd) - _load_partial_state_dict( - model, fms_partial_sd, needs_tp_sharding, rank, world_size - ) - for p_key in partial_sd.keys(): - if isinstance(state_dict, ChainMap): - for child_sd in state_dict.maps: - child_sd.pop(p_key, None) - else: - state_dict.pop(p_key) - del partial_sd - del fms_partial_sd - - -def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): - """ - This function copies the correct shard of the weights for a colwise-TP'd module - according to the rank of the process and the world_size. - - Args - ==== - param: torch.nn.Parameter - Parameter that has had TP applied - tensor_value: torch.Tensor - tensor that needs sharding - rank: int - Rank of the current process - world_size: int - Total number of TP processes - """ - # Divide the weight matrix along the first dimension. - output_size_per_partition = param.shape[0] - if not is_bias: - tensor = tensor_value[ - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ), - :, - ] - else: - tensor = tensor_value[ - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ) - ] - param.copy_(tensor, non_blocking=True) - - -def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): - """ - This function copies the correct shard of the weights for a rowwise-TP'd module - according to the rank of the process and the world_size. - - Args - ==== - param: torch.nn.Parameter - Parameter that has had TP applied - tensor_value: torch.Tensor - tensor that needs sharding - rank: int - Rank of the current process - world_size: int - Total number of TP processes - """ - # Divide the weight matrix along the last dimension. - if not is_bias: - output_size_per_partition = param.shape[1] - tensor = tensor_value[ - :, - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ), - ] - param.copy_(tensor, non_blocking=True) - else: - if rank == 0: - _copy_if_present(param, tensor_value) - else: - param.zero_() - - -def _copy_embedding(param: torch.nn.Parameter, tensor_value, rank, world_size): - """ - This function copies the correct shard of the weights for a TP'd embedding module - according to the rank of the process and the world_size. - - Args - ==== - param: torch.nn.Parameter - Parameter that has had TP applied - tensor_value: torch.Tensor - tensor that needs sharding - rank: int - Rank of the current process - world_size: int - Total number of TP processes - """ - # Divide the weight matrix along the last dimension. - output_size_per_partition = param.shape[1] - tensor = tensor_value[ - :, - (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), - ] - param.copy_(tensor, non_blocking=True) - - -def _copy_if_present(parameter, tensor_value): - parameter.copy_(tensor_value, non_blocking=True) - - -def _load_partial_state_dict( - model: torch.nn.Module, - state_dict, - needs_tp_sharding: bool, - rank=0, - world_size=1, -): - unused_params = [] - for key, tensor_value in state_dict.items(): - target_module = model - # Find where to put the weight and decide whether it needs TP'ing - key_steps = key.split(".") - prefix = "" - key_step = 0 - tp_module = None - # Navigate the model tree to find the module where the parameter is - # located and whether there is a TPModule in the way in case the - # parameter requires sharding - while key_step < len(key_steps) - 1: - try: - target_module = getattr(target_module, key_steps[key_step]) - if key_step > 0: - prefix += "." - prefix += key_steps[key_step] - key_step += 1 - if isinstance(target_module, Iterable): - target_module = target_module[int(key_steps[key_step])] # type: ignore[index] - prefix += "." + key_steps[key_step] - key_step += 1 - if isinstance(target_module, TPModule): - tp_module = target_module - except AttributeError: - unused_params.append(key) - break - - # Check if target_module has the Parameter/buffer - try: - param = getattr(target_module, key_steps[-1]) - - # If TP sharding is not needed, copy the parameter - # into the model - if not needs_tp_sharding or tp_module is None: - _copy_if_present(param, tensor_value) - elif tp_module is not None: - # Handle TP sharding - if key_steps[-2] in tp_module.colwise_param_names(): - _copy_colwise( - param, - tensor_value, - key_steps[-1] == "bias", - rank, - world_size, - ) - if key_steps[-2] in tp_module.rowwise_param_names(): - _copy_rowwise( - param, - tensor_value, - key_steps[-1] == "bias", - rank, - world_size, - ) - if key_steps[-2] in tp_module.embedding_param_names(): - _copy_embedding( - param, - tensor_value, - rank, - world_size, - ) - except AttributeError: - unused_params.append(key) diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py deleted file mode 100644 index 4bc68bbca7..0000000000 --- a/optimum/habana/distributed/strategy.py +++ /dev/null @@ -1,108 +0,0 @@ -from abc import abstractmethod -from typing import Any, List, Mapping - -import torch -import torch.distributed -from torch import nn - -class DistributedStrategy: - def __init__(self, from_meta=False): - self.from_meta = from_meta - - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: - """ - Optionally a distributed strategy may distribute modules that are not - numbered layers - """ - return module - - @abstractmethod - def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: - """ - Distribute each layer as-appropriate - """ - pass - - -class NotDistributed(DistributedStrategy): - def __init__(self, from_meta=False): - super().__init__(from_meta) - - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: - return module - - def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: - return block - - -NoOpStrategy = NotDistributed() - - -class DeviceMover(nn.Module): - def __init__(self, module: nn.Module, device): - super().__init__() - self.device = device - # make this wrapper module behave as if it was the wrapped module. - attr = module.__dict__ - attr["module"] = module.to(device) - attr["device"] = device - self.__dict__ = attr - - def forward(self, *args, **kwargs): - device = self.device - args = [ - arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args - ] - kwargs = { - k: ( - kwargs[k].to(device) - if isinstance(kwargs[k], torch.Tensor) - else kwargs[k] - ) - for k in kwargs - } - return self.module(*args, **kwargs) - - -class UniformModelParallelStrategy(DistributedStrategy): - def __init__(self, devices: List[int], num_layers: int, from_meta=False): - super().__init__(from_meta) - num_dev = len(devices) - layers_per_dev = num_layers // num_dev - remainder = num_layers - (layers_per_dev * num_dev) - self.layer_to_device = [0] * num_layers - layer_id = 0 - for dev_idx in range(len(devices)): - for i in range(layers_per_dev): - self.layer_to_device[layer_id] = devices[dev_idx] - layer_id = layer_id + 1 - if remainder > 0: - self.layer_to_device[layer_id] = devices[dev_idx] - layer_id = layer_id + 1 - remainder -= 1 - - def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: - device = self.layer_to_device[layer] - if self.from_meta: - block.to_empty(device=device) # type: ignore[arg-type] - wrapped = DeviceMover(block, device) - return wrapped - - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: - if final_layers: - device = self.layer_to_device[len(self.layer_to_device) - 1] - else: - device = self.layer_to_device[0] - if self.from_meta: - return module.to_empty(device=device) # type: ignore[arg-type] - wrapped = DeviceMover(module, device) - return wrapped - - - diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py deleted file mode 100644 index 31f33a79cc..0000000000 --- a/optimum/habana/distributed/tp.py +++ /dev/null @@ -1,84 +0,0 @@ -import itertools -from abc import ABCMeta, abstractmethod -from typing import List, Type - -import torch -import torch.nn as nn -from torch.distributed.distributed_c10d import ProcessGroup - -from optimum.habana.distributed.tensorparallel import ( - apply_colwise_tp, - apply_embedding_tp, - apply_rowwise_tp, -) - - -class TPModule(nn.Module, metaclass=ABCMeta): - """ - This is an abstract class that any nn.Module can implement to enable - Tensor Parallel. On top of inheriting from this class, the TP module - will have to implement list_colwise_weights, list_rowwise_weights, - list_embedding_weights, and import_module for their relevant weights. - Finally, the module must call setup_tp at the end of their __init__ - function. See examples in attention.py, feedforward.py and embedding.py - - """ - - rank: int - world_size: int - - def setup_tp(self, rank: int, world_size: int) -> None: - self.rank = rank - self.world_size = world_size - - def colwise_param_names(self) -> List[str]: - return [] - - def rowwise_param_names(self) -> List[str]: - return [] - - def embedding_param_names(self) -> List[str]: - return [] - - @staticmethod - @abstractmethod - def import_module(module, group: ProcessGroup): - pass - - def import_weights(self, module: nn.Module): - for weight in self.colwise_param_names(): - apply_colwise_tp( - getattr(self, weight), - getattr(module, weight), - self.world_size, - self.rank, - ) - for weight in self.rowwise_param_names(): - apply_rowwise_tp( - getattr(self, weight), - getattr(module, weight), - self.world_size, - self.rank, - ) - for weight in self.embedding_param_names(): - apply_embedding_tp( - getattr(self, weight), - getattr(module, weight), - self.world_size, - self.rank, - ) - tp_sharded_modules = list( - itertools.chain( - self.colwise_param_names(), - self.rowwise_param_names(), - self.embedding_param_names(), - ) - ) - with torch.no_grad(): - for mod_name, module in self.named_children(): - if not mod_name in tp_sharded_modules: - for param_name, param in module.named_parameters(recurse=False): - param.copy_( - getattr(getattr(module, mod_name), param_name), - non_blocking=True, - ) diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py deleted file mode 100644 index 402accb342..0000000000 --- a/optimum/habana/distributed/tp_wrapping.py +++ /dev/null @@ -1,33 +0,0 @@ -import os -from torch import nn -from torch.distributed.distributed_c10d import ProcessGroup - -from optimum.habana.transformers.models.llama.modeling_llama import ( - GaudiLlamaMLP, - TPGaudiLlamaMLP, - GaudiLlamaAttention, - TPGaudiLlamaAttention -) - -# this probably belongs somewhere else but can't go in fms.distribtued b/c -# circular dependency. -def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): - if hasattr(module, "to_tp"): - return module.to_tp(group) - elif isinstance(module, GaudiLlamaAttention): - return TPGaudiLlamaAttention.import_module(module,layer, group) - elif isinstance(module, GaudiLlamaMLP): - return TPGaudiLlamaMLP.import_module(module, group) - else: - return module - - -def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): - wrapped = _tp_wrapped(model, layer_idx, group) - if wrapped is not model: - return wrapped - - for name, layer in model.named_children(): - tp_layer = apply_tp(layer, layer_idx, group) - setattr(model, name, tp_layer) - return model diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 36ef637eed..a3561033fe 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -20,20 +20,10 @@ logger, ) -import copy -from torch.distributed.distributed_c10d import ProcessGroup from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -from optimum.habana.distributed.tp import TPModule -from optimum.habana import distributed -from optimum.habana.distributed.tensorparallel import ( - reduce_from_tensor_model_parallel_region, -) - -from optimum.habana.distributed.strategy import DistributedStrategy -from optimum.habana.distributed.strategy import NoOpStrategy try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -198,46 +188,6 @@ def post_mlp_forward(self, x): return self.down_proj.post_all_reduce(x) return x -class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): - def __init__( - self, - config, - group: Optional[ProcessGroup] = None, - ): - assert torch.distributed.is_initialized() - rank, world_size = distributed.rank_and_world(group) - hidden_dim = int(config.hidden_grow_factor * config.hidden_size) - assert ( - hidden_dim % world_size == 0 - ), "Hidden dim must be divisible by world size" - - self.config = copy.deepcopy(config) - self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) - GaudiLlamaMLP.__init__( - self, - self.config - ) - self.setup_tp(rank, world_size) - - def colwise_param_names(self) -> List[str]: - return ["up_proj", "gate_proj"] - - def rowwise_param_names(self) -> List[str]: - return ["down_proj"] - - @staticmethod - def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": - config = copy.deepcopy(glu.config) - config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size - tp_glu = TPGaudiLlamaMLP( - config = config, - group=group - ) - return tp_glu - - def pre_mlp_forward(self, x): - out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) - return reduce_from_tensor_model_parallel_region(out_par) def gaudi_llama_repeat_kv( query_states: torch.Tensor, @@ -606,94 +556,6 @@ def post_attn_forward(self, attn_output): return attn_output -class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, group: Optional[ProcessGroup] = None,): - super().__init__(config, layer_idx) - - assert torch.distributed.is_initialized() - rank, world_size = distributed.rank_and_world(group) - assert ( - config.num_attention_heads % world_size == 0 - ), "The number of heads must be divisible by world size" - self.config = copy.deepcopy(config) - - self.pre_tp_kvheads = config.num_key_value_heads - GaudiLlamaAttention.__init__(self, self.config , layer_idx) - self.config.num_attention_heads = self.config.num_attention_heads // world_size - self.config.num_key_value_heads = ( self.config.num_key_value_heads // world_size) if self.config.num_key_value_heads > 1 else self.config.num_key_value_heads - self.head_dim = config.hidden_size // config.num_attention_heads - self.hidden_size = self.config.hidden_size // world_size - self.num_heads = self.config.num_attention_heads - - self.q_proj = torch.nn.Linear(config.hidden_size, self.config.num_attention_heads * self.head_dim , bias=config.attention_bias) - self.k_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) - self.v_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) - self.o_proj = torch.nn.Linear(self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) - self.norm_factor = 1.0 / math.sqrt(self.head_dim) - self.setup_tp(rank, world_size) - - def colwise_param_names(self) -> List[str]: - colwise_weights = ["q_proj"] - if self.pre_tp_kvheads != 1: - colwise_weights.append("k_proj") - colwise_weights.append("v_proj") - return colwise_weights - - def rowwise_param_names(self) -> List[str]: - return ["o_proj"] - - @staticmethod - def import_module( - mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup - ) -> "TPGaudiLlamaAttention": - tp_mha = TPGaudiLlamaAttention( - config = mha.config, - layer_idx=layer_idx, - group=group - ) - return tp_mha - - def pre_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - token_idx: Optional[torch.Tensor] = None, - attn_softmax_bf16: Optional[bool] = False, - reuse_cache: Optional[bool] = False, - use_flash_attention: Optional[bool] = False, - flash_attention_recompute: Optional[bool] = False, - flash_attention_causal_mask: Optional[bool] = False, - flash_attention_fast_softmax: Optional[bool] = False, - cache_idx: int = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward(self, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - cache_position, - token_idx, - attn_softmax_bf16, - reuse_cache, - use_flash_attention, - flash_attention_recompute, - flash_attention_causal_mask, - flash_attention_fast_softmax, - cache_idx, - **kwargs - ) - - hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) - return hidden_states, attn_weights, present_key_value class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -870,16 +732,10 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.distributed_strategy = config.distributed_strategy - config.distributed_strategy = None self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - layers = [] - for layer_idx in range(config.num_hidden_layers): - layer = GaudiLlamaDecoderLayer(config, layer_idx) - layer = self.distributed_strategy.distribute_layer(layer, layer_idx) - layers.append(layer) - self.layers = torch.nn.ModuleList(layers) - + self.layers = torch.nn.ModuleList( + [GaudiLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1117,10 +973,6 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ - def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): - config.distributed_strategy = distributed_strategy - super().__init__(config) - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) From aec501831d37725c19b200c898a75902e5b94f5b Mon Sep 17 00:00:00 2001 From: Kalyan Kumar Date: Tue, 30 Jul 2024 21:47:34 +0530 Subject: [PATCH 2/2] Tensor parallel distributed strategy without using deepspeed (#1121) Co-authored-by: Kalyan --- examples/text-generation/README.md | 31 ++ examples/text-generation/run_generation.py | 7 + examples/text-generation/utils.py | 70 ++- optimum/habana/distributed/__init__.py | 29 ++ optimum/habana/distributed/serialization.py | 475 ++++++++++++++++++ optimum/habana/distributed/strategy.py | 134 +++++ optimum/habana/distributed/tensorparallel.py | 19 +- optimum/habana/distributed/tp.py | 101 ++++ optimum/habana/distributed/tp_wrapping.py | 48 ++ .../models/llama/configuration_llama.py | 2 + .../models/llama/modeling_llama.py | 161 +++++- tests/test_text_generation_example.py | 35 +- 12 files changed, 1104 insertions(+), 8 deletions(-) create mode 100644 optimum/habana/distributed/serialization.py create mode 100644 optimum/habana/distributed/strategy.py create mode 100644 optimum/habana/distributed/tp.py create mode 100644 optimum/habana/distributed/tp_wrapping.py diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 451a91e082..6597484f06 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -264,6 +264,37 @@ set the following environment variables before running the command: `PT_ENABLE_I You will also need to add `--torch_compile` in your command. +### Running with tensor-parallel strategy + +> [!NOTE] +> This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details. + +> [!WARNING] +> torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. + +To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the +command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed. + +You will also need to add `--torch_compile` and `--parallel_strategy="tp"` in your command. + +Here is an example: +```bash +PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--trim_logits \ +--use_kv_cache \ +--attn_softmax_bf16 \ +--bf16 \ +--bucket_internal \ +--bucket_size=128 \ +--use_flash_attention \ +--flash_attention_recompute \ +--batch_size 246 \ +--max_input_tokens 2048 \ +--max_new_tokens 2048 \ +--torch_compile \ +--parallel_strategy="tp" +``` ### Running with FP8 diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 6503cb7003..e438273926 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -333,6 +333,13 @@ def __call__(self, parser, namespace, values, option_string=None): action="store_true", help="Whether to load model from hugging face checkpoint.", ) + parser.add_argument( + "--parallel_strategy", + type=str, + choices=["tp", "none"], # Add other strategies as needed + default="none", + help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.", + ) args = parser.parse_args() diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index f8e916e0a1..a5bba9826f 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -294,6 +294,72 @@ def setup_model(args, model_dtype, model_kwargs, logger): return model, assistant_model +def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir): + from typing import Any, MutableMapping + + from optimum.habana.distributed import serialization + from optimum.habana.distributed.strategy import TensorParallelStrategy + + logger.info("Multi-device run.") + + assert args.quant_config == "", "Fp8 is not enabled, unset QUANT_CONFIG" + assert args.assistant_model is None, "Assistant model must be None" + + from torch import distributed as dist + + if args.device == "hpu": + dist.init_process_group(backend="hccl") + else: + assert False, "Supports TP only on HPU" + + torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) + logger.info("Creating Model") + config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + model_kwargs = {} + model_kwargs["parallel_strategy"] = TensorParallelStrategy() + model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) + + initial_device = torch.device("cpu") + source = "hf" + checkpoint_sharding = None + lazy_sd: MutableMapping[str, Any] = {} + logger.info("Loading Checkpoints") + lazy_sd = serialization.load_state_dict( + cache_dir, + source=source, + distributed_strategy=args.parallel_strategy, + checkpoint_sharding=None, + initial_device=initial_device, + rank=args.global_rank, + world_size=args.world_size, + ) + architecture = "llama" + if len(lazy_sd): + serialization.load_state_dict_into_model( + model, + lazy_sd, + architecture, + source, + args.parallel_strategy, + checkpoint_sharding, + initial_device, + args.local_rank, + args.world_size, + ) + + model = model.eval().to(args.device) + + if args.use_hpu_graphs: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + model = wrap_in_hpu_graph(model) + + if args.torch_compile and model.config.model_type == "llama": + model = get_torch_compiled_model(model) + + return model, args.assistant_model + + def setup_distributed_model(args, model_dtype, model_kwargs, logger): import deepspeed @@ -535,7 +601,7 @@ def initialize_model(args, logger): setup_env(args) setup_device(args) set_seed(args.seed) - get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token) + cache_dir = get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token) if args.assistant_model is not None: get_repo_root(args.assistant_model, local_rank=args.local_rank, token=args.token) use_deepspeed = args.world_size > 0 @@ -560,6 +626,8 @@ def initialize_model(args, logger): setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed else setup_distributed_model(args, model_dtype, model_kwargs, logger) + if not args.parallel_strategy == "tp" + else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) generation_config = setup_generation_config(args, model, assistant_model, tokenizer) diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index 2dedd7333d..af269ee68c 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,2 +1,31 @@ +import os + +import torch + from .distributed_runner import DistributedRunner from .fast_ddp import all_reduce_gradients + + +def rank_and_world(group=None): + """ + Returns (rank, world_size) from the optionally-specified group, otherwise + from the default group, or if non-distributed just returns (0, 1) + """ + if torch.distributed.is_initialized() and group is None: + group = torch.distributed.GroupMember.WORLD + + if group is None: + world_size = 1 + rank = 0 + else: + world_size = group.size() + rank = group.rank() + + return rank, world_size + + +_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) + + +def local_rank(): + return _LOCAL_RANK diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py new file mode 100644 index 0000000000..bf59fb2445 --- /dev/null +++ b/optimum/habana/distributed/serialization.py @@ -0,0 +1,475 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +import collections +import os +from collections import ChainMap +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Union + +import torch + +from .tp import TPModule + + +__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} + + +def register_adapter( + architecture: str, + source: str, + adapter: Callable[[Mapping], Mapping], +): + """ + Registers a state dict adapter to be available to the (de) serialization + API. + + Args: + architecture: The name of the model architecture, e.g. 'llama' + source: A label representing the format of the weights to be converted. + E.g. 'hf' + adapter: the class of the adapter. The class must accept one constructor + parameter, which will be a state dict (`OrderedDict`) + """ + sources: MutableMapping[str, Callable[[Mapping], Mapping]] = {} + if architecture in __adapters: + sources = __adapters[architecture] + + if source in sources: + raise KeyError(f"Variant {source} already registered for architecture {architecture}") + + sources[source] = adapter + __adapters[architecture] = sources + + +def list_sources(architecture: str): + """ + Lists available sources (attribute formats) of a model architecture. + E.g. `models.list_variants('llama')` -> ['meta', 'fms', 'hf'] + Args: + architecture: one of the registered architectures returned by + `models.list_models()`. + """ + if architecture not in __adapters: + return [] + return list(__adapters[architecture].keys()) + + +def _get_adapter(architecture: str, source: Optional[str]) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + if source is None or architecture not in __adapters or source not in __adapters[architecture]: + # if no adapter is registered, assume the attributes are already in + # fms format. + # should we raise an error here instead? + return lambda x: x + else: + return __adapters[architecture][source] + + +def get_adapted(architecture: str, source: Optional[str], state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Convert a state dict to FMS format, using an adapter specified by name. + + Args: + architecture: one of the architectures from `models.list_models()`. + E.g. llama. + source: A reference to an attribute format + state_dict: the model.state_dict() to be converted/adapted. + """ + # sometimes we only load onto rank 0 so may not have a state_dict here. + if not len(state_dict): + return state_dict + adapter = _get_adapter(architecture, source) + adapted = adapter(state_dict) + return adapted + + +def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: + from safetensors import safe_open # type: ignore[import-untyped] + + with torch.no_grad(): + with safe_open(file, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] + return model_weights.get_tensor(key) + + +class LazySafetensorsDict(collections.UserDict): + def set_lazy_tensor(self, key, file, device): + super().__setitem__(key, lambda: _get_safetensors_item(key, file, device)) + + def __getitem__(self, key): + lazy_tensor = super().__getitem__(key) + if callable(lazy_tensor): + lazy_tensor = lazy_tensor() + super().__setitem__(key, lazy_tensor) + return lazy_tensor + + +def load_state_dict( + model_path: Union[str, Path], + *, + source: Optional[str] = None, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 1, +) -> MutableMapping[str, Any]: + """ + Validates that the file(s) found at a checkpoint path are compatible with + the intended (possibly distributed) use-case, and returns a lazy loading + state dict if possible (some formats may not support that). + + If model_path is a directory, it'll try to load models based on the source + (e.g. .bin for HF, .pth for Meta), and, if no source is specified or hasn't + been registered, it'll try .safetensors, .pth, and .bin. + + Args: + model_path: the path to find the weights. If not set, return None. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for + validation. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. + initial_device: where the state dict will be loaded if not lazy. + If meta, return empty dict. + """ + if model_path is None or initial_device.type == "meta": + return {} + if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: + raise ValueError("FSDP checkpoints can only be loaded into an FSDP model") + if checkpoint_sharding == "tp" and distributed_strategy != "tp": + raise ValueError("TP checkpoints can only be loaded into a TP model") + + # Before creating the Path object, check if model_path has a glob pattern + if isinstance(model_path, str): + model_path, sep, glob_pattern = model_path.partition("*") + else: + sep = "" + glob_pattern = "" + glob_pattern = sep + glob_pattern + + model_path = Path(os.path.expanduser(model_path)) + + checkpoints = [] + + if model_path.is_dir(): + if glob_pattern != "": + glob_pattern_list = [glob_pattern] + elif source == "meta": + glob_pattern_list = ["*.pth", "*.safetensors"] + elif source == "hf": + glob_pattern_list = ["*.bin", "*.safetensors"] + else: + glob_pattern_list = ["*.safetensors", "*.pth", "*.bin"] + for glob_pattern_possibility in glob_pattern_list: + file_list = list(model_path.glob(glob_pattern_possibility)) + if len(file_list) > 0: + checkpoints = sorted(file_list) + break + + if model_path.is_file(): + checkpoints = [model_path] + + # Check if we found some files + assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}" + + if checkpoint_sharding is not None and checkpoint_sharding != "layer": + assert ( + world_size == len(checkpoints) + ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" + + checkpoints = [checkpoints[rank]] + + # if there's only one checkpoint for fsdp/hsdp, load it only into rank zero + # and it will be distributed by the FSDP `sync_module_states` parameter + if checkpoint_sharding is None and distributed_strategy in {"hsdp", "fsdp"}: + if rank == 0: + checkpoints = [checkpoints[0]] + else: + return {} + + checkpoint_sds = [] + if checkpoints[0].suffix == ".safetensors": + for ckp in checkpoints: + checkpoint_sds.append( + _load_safetensors_state_dict( + ckp, + initial_device, + ) + ) + else: + with torch.no_grad(): + checkpoint_sds = [ + torch.load(str(ckpt_path), map_location=initial_device, mmap=True) for ckpt_path in checkpoints + ] + return ChainMap(*checkpoint_sds) + + +def _load_safetensors_state_dict( + checkpoint: Path, + device: torch.device, +): + sd = LazySafetensorsDict() + + from safetensors import safe_open + + with safe_open(checkpoint, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] + sd_keys = list(model_weights.keys()) + for key in sd_keys: + sd.set_lazy_tensor(key, checkpoint, device) + return sd + + +class FusableWeightsMissingError(Exception): + missing_weights: List[str] = [] + + def __init__(self, missing_weights): + self.missing_weights = missing_weights + super().__init__() + + +def load_state_dict_into_model( + model: torch.nn.Module, + state_dict: MutableMapping[str, Any], + architecture: str, + source: str, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 0, +) -> None: + """ + This function loads state_dict into model in the most efficient way possible, + and it removes all weights that have been used in model from state_dict + in order to conserve memory. + + Args: + model: The model where the weights are being loaded. + state_dict: The dictionary with all the weights. If it has been mmaped + (for torch.load) or it is an instance of LazySafetensorsDict, + the weights are loaded lazily from disk. + architecture: the model architecture, e.g. llama. See `models.list_models()`. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for weight + sharding. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. Used for weight sharding. + initial_device: where the weights will be loaded from disk. + """ + + # 1. Get the adapter from checkpoint sd to fms sd + adapter = _get_adapter(architecture, source) + + # 2. Decide if model needs sharding and how (for now only TP) + needs_tp_sharding = checkpoint_sharding != "tp" and distributed_strategy == "tp" + + # 3. Iterate over the weights and load them into the model + used_keys = set() + sd_keys = list(state_dict.keys()) + with torch.no_grad(): + for key in sd_keys: + if key in used_keys: + continue + used_keys.add(key) + try: + partial_sd = {key: state_dict[key]} + if partial_sd[key].device != initial_device: + partial_sd[key] = partial_sd[key].to(device=initial_device) + fms_partial_sd = adapter(partial_sd) + except FusableWeightsMissingError as e: + for weight in e.missing_weights: + used_keys.add(weight) + partial_sd[weight] = state_dict[weight] + if partial_sd[weight].device != initial_device: + partial_sd[weight] = partial_sd[weight].to(device=initial_device) + fms_partial_sd = adapter(partial_sd) + _load_partial_state_dict(model, fms_partial_sd, needs_tp_sharding, rank, world_size) + for p_key in partial_sd.keys(): + if isinstance(state_dict, ChainMap): + for child_sd in state_dict.maps: + child_sd.pop(p_key, None) + else: + state_dict.pop(p_key) + del partial_sd + del fms_partial_sd + + +def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a colwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the first dimension. + output_size_per_partition = param.shape[0] + if not is_bias: + tensor = tensor_value[ + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + :, + ] + else: + tensor = tensor_value[(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition)] + param.copy_(tensor, non_blocking=True) + + +def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a rowwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. + if not is_bias: + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + ] + param.copy_(tensor, non_blocking=True) + else: + if rank == 0: + _copy_if_present(param, tensor_value) + else: + param.zero_() + + +def _copy_embedding(param: torch.nn.Parameter, tensor_value, rank, world_size): + """ + This function copies the correct shard of the weights for a TP'd embedding module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + ] + param.copy_(tensor, non_blocking=True) + + +def _copy_if_present(parameter, tensor_value): + parameter.copy_(tensor_value, non_blocking=True) + + +def _load_partial_state_dict( + model: torch.nn.Module, + state_dict, + needs_tp_sharding: bool, + rank=0, + world_size=1, +): + unused_params = [] + for key, tensor_value in state_dict.items(): + target_module = model + # Find where to put the weight and decide whether it needs TP'ing + key_steps = key.split(".") + prefix = "" + key_step = 0 + tp_module = None + # Navigate the model tree to find the module where the parameter is + # located and whether there is a TPModule in the way in case the + # parameter requires sharding + while key_step < len(key_steps) - 1: + try: + target_module = getattr(target_module, key_steps[key_step]) + if key_step > 0: + prefix += "." + prefix += key_steps[key_step] + key_step += 1 + if isinstance(target_module, Iterable): + target_module = target_module[int(key_steps[key_step])] # type: ignore[index] + prefix += "." + key_steps[key_step] + key_step += 1 + if isinstance(target_module, TPModule): + tp_module = target_module + except AttributeError: + unused_params.append(key) + break + + # Check if target_module has the Parameter/buffer + try: + param = getattr(target_module, key_steps[-1]) + + # If TP sharding is not needed, copy the parameter + # into the model + if not needs_tp_sharding or tp_module is None: + _copy_if_present(param, tensor_value) + elif tp_module is not None: + # Handle TP sharding + if key_steps[-2] in tp_module.colwise_param_names(): + _copy_colwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.rowwise_param_names(): + _copy_rowwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.embedding_param_names(): + _copy_embedding( + param, + tensor_value, + rank, + world_size, + ) + except AttributeError: + unused_params.append(key) diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py new file mode 100644 index 0000000000..91b3f00232 --- /dev/null +++ b/optimum/habana/distributed/strategy.py @@ -0,0 +1,134 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +from abc import abstractmethod +from typing import List + +import torch +import torch.distributed +from torch import nn + + +class DistributedStrategy: + def __init__(self, from_meta=False): + self.from_meta = from_meta + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + """ + Optionally a distributed strategy may distribute modules that are not + numbered layers + """ + return module + + @abstractmethod + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + """ + Distribute each layer as-appropriate + """ + pass + + +class NotDistributed(DistributedStrategy): + def __init__(self, from_meta=False): + super().__init__(from_meta) + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + return module + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + return block + + +NoOpStrategy = NotDistributed() + + +class DeviceMover(nn.Module): + def __init__(self, module: nn.Module, device): + super().__init__() + self.device = device + # make this wrapper module behave as if it was the wrapped module. + attr = module.__dict__ + attr["module"] = module.to(device) + attr["device"] = device + self.__dict__ = attr + + def forward(self, *args, **kwargs): + device = self.device + args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: (kwargs[k].to(device) if isinstance(kwargs[k], torch.Tensor) else kwargs[k]) for k in kwargs} + return self.module(*args, **kwargs) + + +class UniformModelParallelStrategy(DistributedStrategy): + def __init__(self, devices: List[int], num_layers: int, from_meta=False): + super().__init__(from_meta) + num_dev = len(devices) + layers_per_dev = num_layers // num_dev + remainder = num_layers - (layers_per_dev * num_dev) + self.layer_to_device = [0] * num_layers + layer_id = 0 + for dev_idx in range(len(devices)): + for i in range(layers_per_dev): + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + if remainder > 0: + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + remainder -= 1 + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + device = self.layer_to_device[layer] + if self.from_meta: + block.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(block, device) + return wrapped + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + if final_layers: + device = self.layer_to_device[len(self.layer_to_device) - 1] + else: + device = self.layer_to_device[0] + if self.from_meta: + return module.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(module, device) + return wrapped + + +class TensorParallelStrategy(DistributedStrategy): + def __init__(self, group=None, from_meta=False): + super().__init__(from_meta) + assert torch.distributed.is_initialized(), "must initialize a process group" + self.group = group if group is not None else torch.distributed.GroupMember.WORLD + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(module, self.group) + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(block, layer, self.group) + + def __getstate__(self): + state = self.__dict__.copy() + state["group"] = None # Remove ProcessGroup from state + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.group = None # Restore to default state or reinitialize diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py index 61ef99ca02..ba423175b4 100644 --- a/optimum/habana/distributed/tensorparallel.py +++ b/optimum/habana/distributed/tensorparallel.py @@ -1,10 +1,25 @@ -# mypy: disable-error-code="method-assign,misc" +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + import torch import torch._inductor.ir as ir import torch._inductor.lowering as lowering import torch.distributed as dist -import torch.distributed._functional_collectives as funcol from torch import nn #This needs to be fixed Issues can be tracked at - SW-192548 diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py new file mode 100644 index 0000000000..c4f156fa61 --- /dev/null +++ b/optimum/habana/distributed/tp.py @@ -0,0 +1,101 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +import itertools +from abc import ABCMeta, abstractmethod +from typing import List + +import torch +import torch.nn as nn +from torch.distributed.distributed_c10d import ProcessGroup + +from .tensorparallel import ( + apply_colwise_tp, + apply_embedding_tp, + apply_rowwise_tp, +) + + +class TPModule(nn.Module, metaclass=ABCMeta): + """ + This is an abstract class that any nn.Module can implement to enable + Tensor Parallel. On top of inheriting from this class, the TP module + will have to implement list_colwise_weights, list_rowwise_weights, + list_embedding_weights, and import_module for their relevant weights. + Finally, the module must call setup_tp at the end of their __init__ + function. See examples in attention.py, feedforward.py and embedding.py + + """ + + rank: int + world_size: int + + def setup_tp(self, rank: int, world_size: int) -> None: + self.rank = rank + self.world_size = world_size + + def colwise_param_names(self) -> List[str]: + return [] + + def rowwise_param_names(self) -> List[str]: + return [] + + def embedding_param_names(self) -> List[str]: + return [] + + @staticmethod + @abstractmethod + def import_module(module, group: ProcessGroup): + pass + + def import_weights(self, module: nn.Module): + for weight in self.colwise_param_names(): + apply_colwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.rowwise_param_names(): + apply_rowwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.embedding_param_names(): + apply_embedding_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + tp_sharded_modules = list( + itertools.chain( + self.colwise_param_names(), + self.rowwise_param_names(), + self.embedding_param_names(), + ) + ) + with torch.no_grad(): + for mod_name, module in self.named_children(): + if mod_name not in tp_sharded_modules: + for param_name, param in module.named_parameters(recurse=False): + param.copy_( + getattr(getattr(module, mod_name), param_name), + non_blocking=True, + ) diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py new file mode 100644 index 0000000000..761fa7bff4 --- /dev/null +++ b/optimum/habana/distributed/tp_wrapping.py @@ -0,0 +1,48 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup + +from ..transformers.models.llama.modeling_llama import ( + GaudiLlamaAttention, + GaudiLlamaMLP, + TPGaudiLlamaAttention, + TPGaudiLlamaMLP, +) + + +def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): + if hasattr(module, "to_tp"): + return module.to_tp(group) + elif isinstance(module, GaudiLlamaAttention): + return TPGaudiLlamaAttention.import_module(module, layer, group) + elif isinstance(module, GaudiLlamaMLP): + return TPGaudiLlamaMLP.import_module(module, group) + else: + return module + + +def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): + wrapped = _tp_wrapped(model, layer_idx, group) + if wrapped is not model: + return wrapped + + for name, layer in model.named_children(): + tp_layer = apply_tp(layer, layer_idx, group) + setattr(model, name, tp_layer) + return model diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index dcba1c0738..12ad78e29a 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -27,6 +27,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, fused_qkv=False, + parallel_strategy=None, **kwargs, ): super().__init__( @@ -55,3 +56,4 @@ def __init__( self.mlp_bias = mlp_bias self.fused_qkv = fused_qkv + self.parallel_strategy = parallel_strategy diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index a3561033fe..9366a7bee1 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import copy import math import os import warnings @@ -5,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.distributed.distributed_c10d import ProcessGroup from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -20,6 +22,12 @@ logger, ) +from .... import distributed +from ....distributed.strategy import DistributedStrategy, NoOpStrategy +from ....distributed.tensorparallel import ( + reduce_from_tensor_model_parallel_region, +) +from ....distributed.tp import TPModule from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) @@ -189,6 +197,40 @@ def post_mlp_forward(self, x): return x +class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): + def __init__( + self, + config, + group: Optional[ProcessGroup] = None, + ): + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + hidden_dim = int(config.hidden_grow_factor * config.hidden_size) + assert hidden_dim % world_size == 0, "Hidden dim must be divisible by world size" + + self.config = copy.deepcopy(config) + self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) + GaudiLlamaMLP.__init__(self, self.config) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + return ["up_proj", "gate_proj"] + + def rowwise_param_names(self) -> List[str]: + return ["down_proj"] + + @staticmethod + def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": + config = copy.deepcopy(glu.config) + config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size + tp_glu = TPGaudiLlamaMLP(config=config, group=group) + return tp_glu + + def pre_mlp_forward(self, x): + out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) + return reduce_from_tensor_model_parallel_region(out_par) + + def gaudi_llama_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -556,6 +598,107 @@ def post_attn_forward(self, attn_output): return attn_output +class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + group: Optional[ProcessGroup] = None, + ): + super().__init__(config, layer_idx) + + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + assert config.num_attention_heads % world_size == 0, "The number of heads must be divisible by world size" + self.config = copy.deepcopy(config) + + self.pre_tp_kvheads = config.num_key_value_heads + GaudiLlamaAttention.__init__(self, self.config, layer_idx) + self.config.num_attention_heads = self.config.num_attention_heads // world_size + self.config.num_key_value_heads = ( + (self.config.num_key_value_heads // world_size) + if self.config.num_key_value_heads > 1 + else self.config.num_key_value_heads + ) + self.head_dim = config.hidden_size // config.num_attention_heads + self.hidden_size = self.config.hidden_size // world_size + self.num_heads = self.config.num_attention_heads + + self.q_proj = torch.nn.Linear( + config.hidden_size, self.config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = torch.nn.Linear( + self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + colwise_weights = ["q_proj"] + if self.pre_tp_kvheads != 1: + colwise_weights.append("k_proj") + colwise_weights.append("v_proj") + return colwise_weights + + def rowwise_param_names(self) -> List[str]: + return ["o_proj"] + + @staticmethod + def import_module(mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup) -> "TPGaudiLlamaAttention": + tp_mha = TPGaudiLlamaAttention(config=mha.config, layer_idx=layer_idx, group=group) + return tp_mha + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + valid_sequence_lengths, + cache_idx, + **kwargs, + ) + + hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) + return hidden_states, attn_weights, present_key_value + + class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -731,11 +874,17 @@ def __init__(self, config: LlamaConfig): super(LlamaModel, self).__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = torch.nn.ModuleList( - [GaudiLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + layers = [] + for layer_idx in range(config.num_hidden_layers): + layer = GaudiLlamaDecoderLayer(config, layer_idx) + if config.parallel_strategy is not None: + layer = config.parallel_strategy.distribute_layer(layer, layer_idx) + layers.append(layer) + self.layers = torch.nn.ModuleList(layers) + # parallel_strategy is not JSON serializable + config.parallel_strategy = None + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -973,6 +1122,10 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ + def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy): + config.parallel_strategy = parallel_strategy + super().__init__(config) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index af4cf716d2..76b6736a93 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -71,6 +71,9 @@ "torch_compile_distributed": [ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], + "distributed_tp": [ + ("meta-llama/Llama-2-7b-hf", 1345.2369318328463), + ], } else: # Gaudi1 CI baselines @@ -101,6 +104,7 @@ ], "torch_compile": [], "torch_compile_distributed": [], + "distributed_tp": [], } @@ -117,6 +121,7 @@ def _test_text_generation( gptq: bool = False, max_input_tokens: int = 0, max_output_tokens: int = 100, + parallel_strategy: str = None, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -128,6 +133,11 @@ def _test_text_generation( "--use_deepspeed", f"--world_size {world_size}", ] + elif parallel_strategy == "tp": + command += [ + f"{path_to_example_dir / 'gaudi_spawn.py'}", + f"--world_size {world_size}", + ] command += [ f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}", @@ -143,11 +153,14 @@ def _test_text_generation( if "falcon" in model_name.lower(): command += ["--use_flash_attention", "--flash_attention_causal_mask"] - if reuse_cache or torch_compile: + if (reuse_cache or torch_compile) and not parallel_strategy == "tp": command += ["--reuse_cache"] if torch_compile: command += ["--torch_compile"] + if parallel_strategy == "tp": + command += ["--use_flash_attention"] + command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" env_variables["PT_HPU_LAZY_MODE"] = "0" else: @@ -194,9 +207,15 @@ def _test_text_generation( f"--max_input_tokens {max_input_tokens}", "--limit_hpu_graphs", ] + if gptq: command += ["--gptq"] + if parallel_strategy is not None: + command += [ + f"--parallel_strategy={parallel_strategy}", + ] + with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") command.append(f"--token {token.value}") @@ -324,3 +343,17 @@ def test_text_generation_torch_compile(model_name: str, baseline: float, token: def test_text_generation_torch_compile_distributed(model_name: str, baseline: float, token: str): world_size = 8 _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) + +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["distributed_tp"]) +def test_text_generation_distributed_tp(model_name: str, baseline: float, token: str): + world_size = 8 + _test_text_generation( + model_name, + baseline, + token, + batch_size=64, + max_input_tokens=128, + world_size=world_size, + torch_compile=True, + parallel_strategy="tp", + )