From 080a6a17a099ec9ac161275bea3085d3e5dcf347 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Fri, 28 Jun 2024 09:41:53 +0300 Subject: [PATCH 01/10] TP reference - ibm foundation-model-stack --- examples/text-generation/run_generation.py | 8 + examples/text-generation/utils.py | 98 +++- optimum/habana/distributed/__init__.py | 26 + optimum/habana/distributed/serialization.py | 489 ++++++++++++++++++ optimum/habana/distributed/strategy.py | 112 ++++ optimum/habana/distributed/tensorparallel.py | 229 ++++++++ optimum/habana/distributed/tp.py | 84 +++ optimum/habana/distributed/tp_wrapping.py | 33 ++ .../models/llama/modeling_llama.py | 161 +++++- 9 files changed, 1232 insertions(+), 8 deletions(-) create mode 100644 optimum/habana/distributed/serialization.py create mode 100644 optimum/habana/distributed/strategy.py create mode 100644 optimum/habana/distributed/tensorparallel.py create mode 100644 optimum/habana/distributed/tp.py create mode 100644 optimum/habana/distributed/tp_wrapping.py diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 1a020555e5..2d1c4bcc07 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -287,6 +287,14 @@ def setup_parser(parser): action="store_true", help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.", ) + parser.add_argument( + "--distributed_strategy", + type=str, + choices=["tp", "none"], # Add other strategies as needed + default="none", + help="Run multi card with the specified distributed strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.", + ) + args = parser.parse_args() if args.torch_compile: diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index fa1946b914..494ffc11d5 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -43,7 +43,6 @@ set_seed, ) - def adjust_batch(batch, size): curr_size = batch["input_ids"].shape[1] if curr_size >= size: @@ -247,6 +246,100 @@ def setup_model(args, model_dtype, model_kwargs, logger): # assistant_model = get_torch_compiled_model(assistant_model) return model, assistant_model +def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): + + from optimum.habana.distributed import serialization + from typing import Any, MutableMapping + + from optimum.habana.distributed import tp_wrapping + from optimum.habana.distributed.strategy import DistributedStrategy + from torch import nn + + class TensorParallelStrategy(DistributedStrategy): + def __init__(self, group=None, from_meta=False): + super().__init__(from_meta) + assert torch.distributed.is_initialized(), "must initialize a process group" + self.group = group if group is not None else torch.distributed.GroupMember.WORLD + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + return tp_wrapping.apply_tp(module, self.group) + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + return tp_wrapping.apply_tp(block, layer, self.group) + + def __getstate__(self): + state = self.__dict__.copy() + state['group'] = None # Remove ProcessGroup from state + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.group = None # Restore to default state or reinitialize + + logger.info("Multi-device run.") + + assert args.assistant_model is None, "Assistant model must be None" + + from torch import distributed as dist + if args.device == 'hpu': + import habana_frameworks.torch.distributed.hccl + dist.init_process_group(backend='hccl') + else: + dist.init_process_group() + + torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) + config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs) + model_kwargs={} + model_kwargs["distributed_strategy"] = TensorParallelStrategy() + model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) + + initial_device = torch.device("cpu") + source="hf" + checkpoint_sharding=None + lazy_sd: MutableMapping[str, Any] = {} + lazy_sd = serialization.load_state_dict( + args.model_name_or_path, + source=source, + distributed_strategy=args.distributed_strategy, + checkpoint_sharding=None, + initial_device=initial_device, + rank=args.global_rank, + world_size=args.world_size, + ) + architecture="llama" + if len(lazy_sd): + serialization.load_state_dict_into_model( + model, + lazy_sd, + architecture, + source, + args.distributed_strategy, + checkpoint_sharding, + initial_device, + args.local_rank, + args.world_size, + ) + + if args.quant_config: + model = setup_quantization(model, args) + + model = model.eval().to(args.device) + + if args.use_hpu_graphs: + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon": + model = wrap_in_hpu_graph(model, hash_with_views=False) + else: + model = wrap_in_hpu_graph(model) + + if args.torch_compile and model.config.model_type == "llama": + model = get_torch_compiled_model(model) + + return model, args.assistant_model + def setup_distributed_model(args, model_dtype, model_kwargs, logger): import deepspeed @@ -521,7 +614,8 @@ def initialize_model(args, logger): model, assistant_model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed - else setup_distributed_model(args, model_dtype, model_kwargs, logger) + else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp" + else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) generation_config = setup_generation_config(args, model, assistant_model, tokenizer) diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index 2dedd7333d..12edd6620e 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,2 +1,28 @@ from .distributed_runner import DistributedRunner from .fast_ddp import all_reduce_gradients +import os +import torch + +def rank_and_world(group=None): + """ + Returns (rank, world_size) from the optionally-specified group, otherwise + from the default group, or if non-distributed just returns (0, 1) + """ + if torch.distributed.is_initialized() and group is None: + group = torch.distributed.GroupMember.WORLD + + if group is None: + world_size = 1 + rank = 0 + else: + world_size = group.size() + rank = group.rank() + + return rank, world_size + + +_LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0)) + + +def local_rank(): + return _LOCAL_RANK diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py new file mode 100644 index 0000000000..c543ab20bd --- /dev/null +++ b/optimum/habana/distributed/serialization.py @@ -0,0 +1,489 @@ +import collections +import os +from collections import ChainMap +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Callable, List, Mapping, MutableMapping, Optional, Union + +import torch + +from optimum.habana.distributed.tp import TPModule + + +__adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} + + +def register_adapter( + architecture: str, + source: str, + adapter: Callable[[Mapping], Mapping], +): + """ + Registers a state dict adapter to be available to the (de) serialization + API. + + Args: + architecture: The name of the model architecture, e.g. 'llama' + source: A label representing the format of the weights to be converted. + E.g. 'hf' + adapter: the class of the adapter. The class must accept one constructor + parameter, which will be a state dict (`OrderedDict`) + """ + sources: MutableMapping[str, Callable[[Mapping], Mapping]] = {} + if architecture in __adapters: + sources = __adapters[architecture] + + if source in sources: + raise KeyError( + f"Variant {source} already registered for architecture {architecture}" + ) + + sources[source] = adapter + __adapters[architecture] = sources + + +def list_sources(architecture: str): + """ + Lists available sources (attribute formats) of a model architecture. + E.g. `models.list_variants('llama')` -> ['meta', 'fms', 'hf'] + Args: + architecture: one of the registered architectures returned by + `models.list_models()`. + """ + if architecture not in __adapters: + return [] + return list(__adapters[architecture].keys()) + + +def _get_adapter( + architecture: str, source: Optional[str] +) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + if ( + source is None + or architecture not in __adapters + or source not in __adapters[architecture] + ): + # if no adapter is registered, assume the attributes are already in + # fms format. + # should we raise an error here instead? + return lambda x: x + else: + return __adapters[architecture][source] + + +def get_adapted( + architecture: str, source: Optional[str], state_dict: Mapping[str, Any] +) -> Mapping[str, Any]: + """ + Convert a state dict to FMS format, using an adapter specified by name. + + Args: + architecture: one of the architectures from `models.list_models()`. + E.g. llama. + source: A reference to an attribute format + state_dict: the model.state_dict() to be converted/adapted. + """ + # sometimes we only load onto rank 0 so may not have a state_dict here. + if not len(state_dict): + return state_dict + adapter = _get_adapter(architecture, source) + adapted = adapter(state_dict) + return adapted + + +# `models` imports each model class, causing models and adapters to be registered. +# down here to avoid circular dependencies. +# from fms import models + + +def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: + from safetensors import safe_open # type: ignore[import-untyped] + + with torch.no_grad(): + with safe_open( + file, framework="pt", device=str(device) + ) as model_weights: # type: ignore[attr-defined] + return model_weights.get_tensor(key) + + +class LazySafetensorsDict(collections.UserDict): + def set_lazy_tensor(self, key, file, device): + super().__setitem__(key, lambda: _get_safetensors_item(key, file, device)) + + def __getitem__(self, key): + lazy_tensor = super().__getitem__(key) + if callable(lazy_tensor): + lazy_tensor = lazy_tensor() + super().__setitem__(key, lazy_tensor) + return lazy_tensor + + +def load_state_dict( + model_path: Union[str, Path], + *, + source: Optional[str] = None, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 1, +) -> MutableMapping[str, Any]: + """ + Validates that the file(s) found at a checkpoint path are compatible with + the intended (possibly distributed) use-case, and returns a lazy loading + state dict if possible (some formats may not support that). + + If model_path is a directory, it'll try to load models based on the source + (e.g. .bin for HF, .pth for Meta), and, if no source is specified or hasn't + been registered, it'll try .safetensors, .pth, and .bin. + + Args: + model_path: the path to find the weights. If not set, return None. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for + validation. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. + initial_device: where the state dict will be loaded if not lazy. + If meta, return empty dict. + """ + if model_path is None or initial_device.type == "meta": + return {} + if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: + raise ValueError(f"FSDP checkpoints can only be loaded into an FSDP model") + if checkpoint_sharding == "tp" and distributed_strategy != "tp": + raise ValueError("TP checkpoints can only be loaded into a TP model") + + # Before creating the Path object, check if model_path has a glob pattern + if isinstance(model_path, str): + model_path, sep, glob_pattern = model_path.partition("*") + else: + sep = "" + glob_pattern = "" + glob_pattern = sep + glob_pattern + + model_path = Path(os.path.expanduser(model_path)) + + checkpoints = [] + + if model_path.is_dir(): + if glob_pattern != "": + glob_pattern_list = [glob_pattern] + elif source == "meta": + glob_pattern_list = ["*.pth", "*.safetensors"] + elif source == "hf": + glob_pattern_list = ["*.bin", "*.safetensors"] + else: + glob_pattern_list = ["*.safetensors", "*.pth", "*.bin"] + for glob_pattern_possibility in glob_pattern_list: + file_list = list(model_path.glob(glob_pattern_possibility)) + if len(file_list) > 0: + checkpoints = sorted(file_list) + break + + if model_path.is_file(): + checkpoints = [model_path] + + # Check if we found some files + assert ( + len(checkpoints) > 0 + ), f"Can't find the requested checkpoint data at {model_path}" + + if checkpoint_sharding is not None and checkpoint_sharding != "layer": + assert world_size == len( + checkpoints + ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" + + checkpoints = [checkpoints[rank]] + + # if there's only one checkpoint for fsdp/hsdp, load it only into rank zero + # and it will be distributed by the FSDP `sync_module_states` parameter + if checkpoint_sharding is None and distributed_strategy in {"hsdp", "fsdp"}: + if rank == 0: + checkpoints = [checkpoints[0]] + else: + return {} + + checkpoint_sds = [] + if checkpoints[0].suffix == ".safetensors": + for ckp in checkpoints: + checkpoint_sds.append( + _load_safetensors_state_dict( + ckp, + initial_device, + ) + ) + else: + with torch.no_grad(): + checkpoint_sds = [ + torch.load(str(ckpt_path), map_location=initial_device, mmap=True) for ckpt_path in checkpoints + ] + return ChainMap(*checkpoint_sds) + + +def _load_safetensors_state_dict( + checkpoint: Path, + device: torch.device, +): + sd = LazySafetensorsDict() + + from safetensors import safe_open + + with safe_open(checkpoint, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] + sd_keys = list(model_weights.keys()) + for key in sd_keys: + sd.set_lazy_tensor(key, checkpoint, device) + return sd + + +class FusableWeightsMissingError(Exception): + missing_weights: List[str] = [] + + def __init__(self, missing_weights): + self.missing_weights = missing_weights + super().__init__() + + +def load_state_dict_into_model( + model: torch.nn.Module, + state_dict: MutableMapping[str, Any], + architecture: str, + source: str, + distributed_strategy: Optional[str] = None, + checkpoint_sharding: Optional[str] = None, + initial_device: torch.device = torch.device("cpu"), + rank: int = 0, + world_size: int = 0, +) -> None: + """ + This function loads state_dict into model in the most efficient way possible, + and it removes all weights that have been used in model from state_dict + in order to conserve memory. + + Args: + model: The model where the weights are being loaded. + state_dict: The dictionary with all the weights. If it has been mmaped + (for torch.load) or it is an instance of LazySafetensorsDict, + the weights are loaded lazily from disk. + architecture: the model architecture, e.g. llama. See `models.list_models()`. + source: If the weights in the state dict didn't come from an FMS model, + `source` specifies which conversion function might be needed. + See `serialization.list_sources(architecture)` + distributed_strategy: the kind of possibly-distributed model in which we + intend to load these weights. E.g. tp, fsdp, None. Used for weight + sharding. + checkpoint_sharding: the sharding format of the checkpoint. + E.g. layer, tp, fsdp. Used for weight sharding. + initial_device: where the weights will be loaded from disk. + """ + + # 1. Get the adapter from checkpoint sd to fms sd + adapter = _get_adapter(architecture, source) + + # 2. Decide if model needs sharding and how (for now only TP) + needs_tp_sharding = checkpoint_sharding != "tp" and distributed_strategy == "tp" + + # 3. Iterate over the weights and load them into the model + used_keys = set() + sd_keys = list(state_dict.keys()) + with torch.no_grad(): + for key in sd_keys: + if key in used_keys: + continue + used_keys.add(key) + try: + partial_sd = {key: state_dict[key]} + if partial_sd[key].device != initial_device: + partial_sd[key] = partial_sd[key].to(device=initial_device) + fms_partial_sd = adapter(partial_sd) + except FusableWeightsMissingError as e: + for weight in e.missing_weights: + used_keys.add(weight) + partial_sd[weight] = state_dict[weight] + if partial_sd[weight].device != initial_device: + partial_sd[weight] = partial_sd[weight].to( + device=initial_device + ) + fms_partial_sd = adapter(partial_sd) + _load_partial_state_dict( + model, fms_partial_sd, needs_tp_sharding, rank, world_size + ) + for p_key in partial_sd.keys(): + if isinstance(state_dict, ChainMap): + for child_sd in state_dict.maps: + child_sd.pop(p_key, None) + else: + state_dict.pop(p_key) + del partial_sd + del fms_partial_sd + + +def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a colwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the first dimension. + output_size_per_partition = param.shape[0] + if not is_bias: + tensor = tensor_value[ + (rank * output_size_per_partition) : ( + (rank + 1) * output_size_per_partition + ), + :, + ] + else: + tensor = tensor_value[ + (rank * output_size_per_partition) : ( + (rank + 1) * output_size_per_partition + ) + ] + param.copy_(tensor, non_blocking=True) + + +def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_size): + """ + This function copies the correct shard of the weights for a rowwise-TP'd module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. + if not is_bias: + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ( + (rank + 1) * output_size_per_partition + ), + ] + param.copy_(tensor, non_blocking=True) + else: + if rank == 0: + _copy_if_present(param, tensor_value) + else: + param.zero_() + + +def _copy_embedding(param: torch.nn.Parameter, tensor_value, rank, world_size): + """ + This function copies the correct shard of the weights for a TP'd embedding module + according to the rank of the process and the world_size. + + Args + ==== + param: torch.nn.Parameter + Parameter that has had TP applied + tensor_value: torch.Tensor + tensor that needs sharding + rank: int + Rank of the current process + world_size: int + Total number of TP processes + """ + # Divide the weight matrix along the last dimension. + output_size_per_partition = param.shape[1] + tensor = tensor_value[ + :, + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), + ] + param.copy_(tensor, non_blocking=True) + + +def _copy_if_present(parameter, tensor_value): + parameter.copy_(tensor_value, non_blocking=True) + + +def _load_partial_state_dict( + model: torch.nn.Module, + state_dict, + needs_tp_sharding: bool, + rank=0, + world_size=1, +): + unused_params = [] + for key, tensor_value in state_dict.items(): + target_module = model + # Find where to put the weight and decide whether it needs TP'ing + key_steps = key.split(".") + prefix = "" + key_step = 0 + tp_module = None + # Navigate the model tree to find the module where the parameter is + # located and whether there is a TPModule in the way in case the + # parameter requires sharding + while key_step < len(key_steps) - 1: + try: + target_module = getattr(target_module, key_steps[key_step]) + if key_step > 0: + prefix += "." + prefix += key_steps[key_step] + key_step += 1 + if isinstance(target_module, Iterable): + target_module = target_module[int(key_steps[key_step])] # type: ignore[index] + prefix += "." + key_steps[key_step] + key_step += 1 + if isinstance(target_module, TPModule): + tp_module = target_module + except AttributeError: + unused_params.append(key) + break + + # Check if target_module has the Parameter/buffer + try: + param = getattr(target_module, key_steps[-1]) + + # If TP sharding is not needed, copy the parameter + # into the model + if not needs_tp_sharding or tp_module is None: + _copy_if_present(param, tensor_value) + elif tp_module is not None: + # Handle TP sharding + if key_steps[-2] in tp_module.colwise_param_names(): + _copy_colwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.rowwise_param_names(): + _copy_rowwise( + param, + tensor_value, + key_steps[-1] == "bias", + rank, + world_size, + ) + if key_steps[-2] in tp_module.embedding_param_names(): + _copy_embedding( + param, + tensor_value, + rank, + world_size, + ) + except AttributeError: + unused_params.append(key) diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py new file mode 100644 index 0000000000..3d77db0cb6 --- /dev/null +++ b/optimum/habana/distributed/strategy.py @@ -0,0 +1,112 @@ +from abc import abstractmethod +from typing import Any, List, Mapping + +import torch +import torch.distributed +from torch import nn + +#from optimum.habana.distributed import tp_wrapping + + +class DistributedStrategy: + def __init__(self, from_meta=False): + self.from_meta = from_meta + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + """ + Optionally a distributed strategy may distribute modules that are not + numbered layers + """ + return module + + @abstractmethod + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + """ + Distribute each layer as-appropriate + """ + pass + + +class NotDistributed(DistributedStrategy): + def __init__(self, from_meta=False): + super().__init__(from_meta) + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + return module + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + return block + + +NoOpStrategy = NotDistributed() + + +class DeviceMover(nn.Module): + def __init__(self, module: nn.Module, device): + super().__init__() + self.device = device + # make this wrapper module behave as if it was the wrapped module. + attr = module.__dict__ + attr["module"] = module.to(device) + attr["device"] = device + self.__dict__ = attr + + def forward(self, *args, **kwargs): + device = self.device + args = [ + arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args + ] + kwargs = { + k: ( + kwargs[k].to(device) + if isinstance(kwargs[k], torch.Tensor) + else kwargs[k] + ) + for k in kwargs + } + return self.module(*args, **kwargs) + + +class UniformModelParallelStrategy(DistributedStrategy): + def __init__(self, devices: List[int], num_layers: int, from_meta=False): + super().__init__(from_meta) + num_dev = len(devices) + layers_per_dev = num_layers // num_dev + remainder = num_layers - (layers_per_dev * num_dev) + self.layer_to_device = [0] * num_layers + layer_id = 0 + for dev_idx in range(len(devices)): + for i in range(layers_per_dev): + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + if remainder > 0: + self.layer_to_device[layer_id] = devices[dev_idx] + layer_id = layer_id + 1 + remainder -= 1 + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + device = self.layer_to_device[layer] + if self.from_meta: + # https://github.com/pytorch/pytorch/pull/113647 + block.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(block, device) + return wrapped + + def distribute_module( + self, module: nn.Module, final_layers: bool = False + ) -> nn.Module: + if final_layers: + device = self.layer_to_device[len(self.layer_to_device) - 1] + else: + device = self.layer_to_device[0] + if self.from_meta: + return module.to_empty(device=device) # type: ignore[arg-type] + wrapped = DeviceMover(module, device) + return wrapped + + + diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py new file mode 100644 index 0000000000..8951b2398c --- /dev/null +++ b/optimum/habana/distributed/tensorparallel.py @@ -0,0 +1,229 @@ +# mypy: disable-error-code="method-assign,misc" + +import torch +import torch._inductor.ir as ir +import torch._inductor.lowering as lowering +import torch.distributed as dist +import torch.distributed._functional_collectives as funcol +from torch import nn + + +def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.out_features // world_size + with torch.no_grad(): + par_mod.weight.copy_( + torch.split(mod.weight, output_size_per_partition, dim=0)[rank] + ) + if par_mod.bias is not None: + par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank]) + # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}") + + +def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.in_features // world_size + with torch.no_grad(): + par_mod.weight.copy_( + torch.split(mod.weight, output_size_per_partition, dim=1)[rank] + ) + if par_mod.bias is not None: + if rank == 0: + par_mod.bias.copy_(mod.bias) + else: + par_mod.bias.zero_() + # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight}, bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}") + + +def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank): + # Divide the weight matrix along the last dimension. + output_size_per_partition = mod.embedding_dim // world_size + with torch.no_grad(): + par_mod.weight.copy_( + torch.split(mod.weight, output_size_per_partition, dim=1)[rank] + ) + # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}") + + +## Fixes for PT 2.2 collectives until PT 2.3 is released + + +# Fix 1: https://github.com/pytorch/pytorch/issues/121311 +def get_volatile_reads_fixed(self): + inp = self.inputs[0] + if isinstance(inp, ir._CollectiveKernel): + # Out-of-place single-output + return [inp.inputs[0]] + elif isinstance(inp, ir.MultiOutput): + # Out-of-place multi-output + coll = inp.inputs[0] + if isinstance(coll, ir._CollectiveKernel): + _, idx = inp.indices[0] + return [coll.inputs[idx]] + return [] # e.g. regular FallbackKernel + else: + # In-place requires no additional deps handling for volatile + # reads since the inputs are mutated. + return [] + + +ir._WaitKernel.get_volatile_reads = get_volatile_reads_fixed + +# Fix 2: These are fixed already in nightlies and will be in 2.3 +for overload in torch.ops._c10d_functional.all_reduce.overloads(): + other_fn = getattr(torch.ops._c10d_functional.all_reduce, overload) + if other_fn in lowering.lowerings: + del lowering.lowerings[other_fn] + + +@lowering.register_lowering(torch.ops._c10d_functional.all_reduce) +def _all_reduce_fixed(inp, reduce_op, group_name): + inp = torch.clone(inp) + ir._CollectiveKernel.create_inplace( + torch.ops._c10d_functional.all_reduce_.default, + ir.ExternKernel.require_contiguous(inp), + reduce_op, + group_name, + ) + return inp + + +for overload in torch.ops._c10d_functional.all_gather_into_tensor.overloads(): + other_fn = getattr(torch.ops._c10d_functional.all_gather_into_tensor, overload) + if other_fn in lowering.lowerings: + del lowering.lowerings[other_fn] + + +@lowering.register_lowering(torch.ops._c10d_functional.all_gather_into_tensor) +def _all_gather_into_tensor(inp, group_size, group_name): + return ir.TensorBox.create( + ir._CollectiveKernel.create_out_of_place( + torch.ops._c10d_functional.all_gather_into_tensor.default, + ir.ExternKernel.require_contiguous(inp), + group_size, + group_name, + ) + ) + + +def _all_gather(input_: torch.Tensor) -> torch.Tensor: + """Gather the input tensor across model parallel group.""" + world_size = dist.get_world_size() + + if world_size == 1: + return input_ + + # The transposes here are to avoid excessive recompilation due to split() + # specializing the dimension where the all_gather is happening + last_dim = input_.dim() - 1 + # Starting PT 2.3, we can go back to funcol.all_gather_tensor + # TODO SW-180411 WA + # return ( + # torch.ops._c10d_functional.wait_tensor( + # torch.ops._c10d_functional.all_gather_into_tensor( + # input_.transpose(0, last_dim).contiguous(), world_size, "default" + # ) + # ) + # .transpose(0, last_dim) + # .contiguous() + # ) + shape = list(input_.transpose(0, last_dim).size()) + shape[0] *= world_size + output = torch.empty(shape, dtype=input_.dtype, device=input_.device) + dist.all_gather_into_tensor(output, input_.transpose(0, last_dim).contiguous()) + return output.transpose(0, last_dim).contiguous() + + +def _all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + world_size = dist.get_world_size() + + if world_size == 1: + return input_ + + # Starting PT 2.3, we can go back to funcol.all_reduce + return torch.ops._c10d_functional.wait_tensor( + torch.ops._c10d_functional.all_reduce(input_, "sum", "default") + ) + + +def _split(input_: torch.Tensor, rank, world_size) -> torch.Tensor: + """Split the tensor along its last dimension and keep the + corresponding slice.""" + + if world_size == 1: + return input_ + + # Split along last dimension. + # Get the size and dimension. + last_dim = input_.dim() - 1 + last_dim_size = input_.size()[last_dim] // world_size + # Split. + input_list = torch.split(input_, last_dim_size, dim=last_dim) + + # Note: torch.split does not create contiguous tensors by default. + output = input_list[rank].contiguous() + + return output + + +class _CopyToModelParallelRegion(torch.autograd.Function): + """Pass the input to the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_): + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output) + + +class _ReduceFromModelParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_): + return _all_reduce(input_) + + @staticmethod + def backward(ctx, grad_output): + return grad_output + + +class _AllGatherFromModelParallelRegion(torch.autograd.Function): + """Gather the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _all_gather(input_) + + @staticmethod + def forward(ctx, input_, rank, world_size): + ctx.rank = rank + ctx.world_size = world_size + return _all_gather(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.rank, ctx.world_size) + + +def copy_to_tensor_model_parallel_region(input_): + return _CopyToModelParallelRegion.apply(input_) + + +def reduce_from_tensor_model_parallel_region(input_): + return _ReduceFromModelParallelRegion.apply(input_) + + +def all_gather_from_tensor_model_parallel_region(input_, rank, world_size): + return _AllGatherFromModelParallelRegion.apply(input_, rank, world_size) diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py new file mode 100644 index 0000000000..31f33a79cc --- /dev/null +++ b/optimum/habana/distributed/tp.py @@ -0,0 +1,84 @@ +import itertools +from abc import ABCMeta, abstractmethod +from typing import List, Type + +import torch +import torch.nn as nn +from torch.distributed.distributed_c10d import ProcessGroup + +from optimum.habana.distributed.tensorparallel import ( + apply_colwise_tp, + apply_embedding_tp, + apply_rowwise_tp, +) + + +class TPModule(nn.Module, metaclass=ABCMeta): + """ + This is an abstract class that any nn.Module can implement to enable + Tensor Parallel. On top of inheriting from this class, the TP module + will have to implement list_colwise_weights, list_rowwise_weights, + list_embedding_weights, and import_module for their relevant weights. + Finally, the module must call setup_tp at the end of their __init__ + function. See examples in attention.py, feedforward.py and embedding.py + + """ + + rank: int + world_size: int + + def setup_tp(self, rank: int, world_size: int) -> None: + self.rank = rank + self.world_size = world_size + + def colwise_param_names(self) -> List[str]: + return [] + + def rowwise_param_names(self) -> List[str]: + return [] + + def embedding_param_names(self) -> List[str]: + return [] + + @staticmethod + @abstractmethod + def import_module(module, group: ProcessGroup): + pass + + def import_weights(self, module: nn.Module): + for weight in self.colwise_param_names(): + apply_colwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.rowwise_param_names(): + apply_rowwise_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + for weight in self.embedding_param_names(): + apply_embedding_tp( + getattr(self, weight), + getattr(module, weight), + self.world_size, + self.rank, + ) + tp_sharded_modules = list( + itertools.chain( + self.colwise_param_names(), + self.rowwise_param_names(), + self.embedding_param_names(), + ) + ) + with torch.no_grad(): + for mod_name, module in self.named_children(): + if not mod_name in tp_sharded_modules: + for param_name, param in module.named_parameters(recurse=False): + param.copy_( + getattr(getattr(module, mod_name), param_name), + non_blocking=True, + ) diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py new file mode 100644 index 0000000000..402accb342 --- /dev/null +++ b/optimum/habana/distributed/tp_wrapping.py @@ -0,0 +1,33 @@ +import os +from torch import nn +from torch.distributed.distributed_c10d import ProcessGroup + +from optimum.habana.transformers.models.llama.modeling_llama import ( + GaudiLlamaMLP, + TPGaudiLlamaMLP, + GaudiLlamaAttention, + TPGaudiLlamaAttention +) + +# this probably belongs somewhere else but can't go in fms.distribtued b/c +# circular dependency. +def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): + if hasattr(module, "to_tp"): + return module.to_tp(group) + elif isinstance(module, GaudiLlamaAttention): + return TPGaudiLlamaAttention.import_module(module,layer, group) + elif isinstance(module, GaudiLlamaMLP): + return TPGaudiLlamaMLP.import_module(module, group) + else: + return module + + +def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): + wrapped = _tp_wrapped(model, layer_idx, group) + if wrapped is not model: + return wrapped + + for name, layer in model.named_children(): + tp_layer = apply_tp(layer, layer_idx, group) + setattr(model, name, tp_layer) + return model diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 1cbd714df5..cd2d148e16 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -19,11 +19,22 @@ apply_rotary_pos_emb, logger, ) - +import copy +from torch.distributed.distributed_c10d import ProcessGroup from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) +from optimum.habana.distributed.tp import TPModule +from optimum.habana import distributed +from optimum.habana.distributed.tensorparallel import ( + copy_to_tensor_model_parallel_region, + reduce_from_tensor_model_parallel_region, +) + +from optimum.habana.distributed.strategy import DistributedStrategy +from optimum.habana.distributed.strategy import NotDistributed +NoOpStrategy = NotDistributed() try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -188,6 +199,46 @@ def post_mlp_forward(self, x): return self.down_proj.post_all_reduce(x) return x +class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): + def __init__( + self, + config, + group: Optional[ProcessGroup] = None, + ): + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + hidden_dim = int(config.hidden_grow_factor * config.hidden_size) + assert ( + hidden_dim % world_size == 0 + ), "Hidden dim must be divisible by world size" + + self.config = copy.deepcopy(config) + self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) + GaudiLlamaMLP.__init__( + self, + self.config + ) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + return ["up_proj", "gate_proj"] + + def rowwise_param_names(self) -> List[str]: + return ["down_proj"] + + @staticmethod + def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": + config = copy.deepcopy(glu.config) + config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size + tp_glu = TPGaudiLlamaMLP( + config = config, + group=group + ) + return tp_glu + + def pre_mlp_forward(self, x): + out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) + return reduce_from_tensor_model_parallel_region(out_par) def gaudi_llama_repeat_kv( query_states: torch.Tensor, @@ -545,6 +596,94 @@ def post_attn_forward(self, attn_output): return attn_output +class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, group: Optional[ProcessGroup] = None,): + super().__init__(config, layer_idx) + + assert torch.distributed.is_initialized() + rank, world_size = distributed.rank_and_world(group) + assert ( + config.num_attention_heads % world_size == 0 + ), "The number of heads must be divisible by world size" + self.config = copy.deepcopy(config) + + self.pre_tp_kvheads = config.num_key_value_heads + GaudiLlamaAttention.__init__(self, self.config , layer_idx) + self.config.num_attention_heads = self.config.num_attention_heads // world_size + self.config.num_key_value_heads = ( self.config.num_key_value_heads // world_size) if self.config.num_key_value_heads > 1 else self.config.num_key_value_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.hidden_size = self.config.hidden_size // world_size + self.num_heads = self.config.num_attention_heads + + self.q_proj = torch.nn.Linear(config.hidden_size, self.config.num_attention_heads * self.head_dim , bias=config.attention_bias) + self.k_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) + self.v_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) + self.o_proj = torch.nn.Linear(self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + self.setup_tp(rank, world_size) + + def colwise_param_names(self) -> List[str]: + colwise_weights = ["q_proj"] + if self.pre_tp_kvheads != 1: + colwise_weights.append("k_proj") + colwise_weights.append("v_proj") + return colwise_weights + + def rowwise_param_names(self) -> List[str]: + return ["o_proj"] + + @staticmethod + def import_module( + mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup + ) -> "TPGaudiLlamaAttention": + tp_mha = TPGaudiLlamaAttention( + config = mha.config, + layer_idx=layer_idx, + group=group + ) + return tp_mha + + def pre_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + cache_idx: int = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward(self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + token_idx, + attn_softmax_bf16, + reuse_cache, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + cache_idx, + **kwargs + ) + + hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) + return hidden_states, attn_weights, present_key_value class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -716,17 +855,23 @@ def __init__(self, config: LlamaConfig): super(LlamaModel, self).__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - + self.distributed_strategy = config.distributed_strategy + config.distributed_strategy = None self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = torch.nn.ModuleList( - [GaudiLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) + layers = [] + for i in range(config.num_hidden_layers): + layer = GaudiLlamaDecoderLayer(config, i) + layer = self.distributed_strategy.distribute_layer(layer, i) + layers.append(layer) + self.layers = torch.nn.ModuleList(layers) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -954,7 +1099,10 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args attn_softmax_bf16 - add new args reuse_cache """ - + def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): + config.distributed_strategy = distributed_strategy + super().__init__(config) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -998,6 +1146,7 @@ def forward( global has_fused_rope has_fused_rope = False + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, From 4feb0850cf5dd1f2d86039aac51e48c44b3f436f Mon Sep 17 00:00:00 2001 From: Kalyan Date: Mon, 15 Jul 2024 09:46:46 +0300 Subject: [PATCH 02/10] Code cleanup -removed unused code --- examples/text-generation/utils.py | 3 + optimum/habana/distributed/strategy.py | 4 - optimum/habana/distributed/tensorparallel.py | 126 ------------------ .../models/llama/modeling_llama.py | 14 +- 4 files changed, 10 insertions(+), 137 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 494ffc11d5..c38f47821f 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -43,6 +43,7 @@ set_seed, ) + def adjust_batch(batch, size): curr_size = batch["input_ids"].shape[1] if curr_size >= size: @@ -290,6 +291,7 @@ def __setstate__(self, state): dist.init_process_group() torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) + logger.info("Creating Model") config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs) model_kwargs={} model_kwargs["distributed_strategy"] = TensorParallelStrategy() @@ -299,6 +301,7 @@ def __setstate__(self, state): source="hf" checkpoint_sharding=None lazy_sd: MutableMapping[str, Any] = {} + logger.info("Loading Checkpoints") lazy_sd = serialization.load_state_dict( args.model_name_or_path, source=source, diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py index 3d77db0cb6..4bc68bbca7 100644 --- a/optimum/habana/distributed/strategy.py +++ b/optimum/habana/distributed/strategy.py @@ -5,9 +5,6 @@ import torch.distributed from torch import nn -#from optimum.habana.distributed import tp_wrapping - - class DistributedStrategy: def __init__(self, from_meta=False): self.from_meta = from_meta @@ -91,7 +88,6 @@ def __init__(self, devices: List[int], num_layers: int, from_meta=False): def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: device = self.layer_to_device[layer] if self.from_meta: - # https://github.com/pytorch/pytorch/pull/113647 block.to_empty(device=device) # type: ignore[arg-type] wrapped = DeviceMover(block, device) return wrapped diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py index 8951b2398c..5d484b2fc5 100644 --- a/optimum/habana/distributed/tensorparallel.py +++ b/optimum/habana/distributed/tensorparallel.py @@ -17,7 +17,6 @@ def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): ) if par_mod.bias is not None: par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank]) - # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}") def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): @@ -32,7 +31,6 @@ def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): par_mod.bias.copy_(mod.bias) else: par_mod.bias.zero_() - # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight}, bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}") def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, rank): @@ -42,7 +40,6 @@ def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, ran par_mod.weight.copy_( torch.split(mod.weight, output_size_per_partition, dim=1)[rank] ) - # print(f"For rank {rank}, we have the following weights: Base weight {mod.weight} bias {mod.bias}; Par weight {par_mod.weight}, bias {par_mod.bias}") ## Fixes for PT 2.2 collectives until PT 2.3 is released @@ -75,65 +72,6 @@ def get_volatile_reads_fixed(self): if other_fn in lowering.lowerings: del lowering.lowerings[other_fn] - -@lowering.register_lowering(torch.ops._c10d_functional.all_reduce) -def _all_reduce_fixed(inp, reduce_op, group_name): - inp = torch.clone(inp) - ir._CollectiveKernel.create_inplace( - torch.ops._c10d_functional.all_reduce_.default, - ir.ExternKernel.require_contiguous(inp), - reduce_op, - group_name, - ) - return inp - - -for overload in torch.ops._c10d_functional.all_gather_into_tensor.overloads(): - other_fn = getattr(torch.ops._c10d_functional.all_gather_into_tensor, overload) - if other_fn in lowering.lowerings: - del lowering.lowerings[other_fn] - - -@lowering.register_lowering(torch.ops._c10d_functional.all_gather_into_tensor) -def _all_gather_into_tensor(inp, group_size, group_name): - return ir.TensorBox.create( - ir._CollectiveKernel.create_out_of_place( - torch.ops._c10d_functional.all_gather_into_tensor.default, - ir.ExternKernel.require_contiguous(inp), - group_size, - group_name, - ) - ) - - -def _all_gather(input_: torch.Tensor) -> torch.Tensor: - """Gather the input tensor across model parallel group.""" - world_size = dist.get_world_size() - - if world_size == 1: - return input_ - - # The transposes here are to avoid excessive recompilation due to split() - # specializing the dimension where the all_gather is happening - last_dim = input_.dim() - 1 - # Starting PT 2.3, we can go back to funcol.all_gather_tensor - # TODO SW-180411 WA - # return ( - # torch.ops._c10d_functional.wait_tensor( - # torch.ops._c10d_functional.all_gather_into_tensor( - # input_.transpose(0, last_dim).contiguous(), world_size, "default" - # ) - # ) - # .transpose(0, last_dim) - # .contiguous() - # ) - shape = list(input_.transpose(0, last_dim).size()) - shape[0] *= world_size - output = torch.empty(shape, dtype=input_.dtype, device=input_.device) - dist.all_gather_into_tensor(output, input_.transpose(0, last_dim).contiguous()) - return output.transpose(0, last_dim).contiguous() - - def _all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" world_size = dist.get_world_size() @@ -146,43 +84,6 @@ def _all_reduce(input_: torch.Tensor) -> torch.Tensor: torch.ops._c10d_functional.all_reduce(input_, "sum", "default") ) - -def _split(input_: torch.Tensor, rank, world_size) -> torch.Tensor: - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - if world_size == 1: - return input_ - - # Split along last dimension. - # Get the size and dimension. - last_dim = input_.dim() - 1 - last_dim_size = input_.size()[last_dim] // world_size - # Split. - input_list = torch.split(input_, last_dim_size, dim=last_dim) - - # Note: torch.split does not create contiguous tensors by default. - output = input_list[rank].contiguous() - - return output - - -class _CopyToModelParallelRegion(torch.autograd.Function): - """Pass the input to the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_): - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _all_reduce(grad_output) - - class _ReduceFromModelParallelRegion(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @@ -198,32 +99,5 @@ def forward(ctx, input_): def backward(ctx, grad_output): return grad_output - -class _AllGatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from the model parallel region.""" - - @staticmethod - def symbolic(graph, input_): - return _all_gather(input_) - - @staticmethod - def forward(ctx, input_, rank, world_size): - ctx.rank = rank - ctx.world_size = world_size - return _all_gather(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.rank, ctx.world_size) - - -def copy_to_tensor_model_parallel_region(input_): - return _CopyToModelParallelRegion.apply(input_) - - def reduce_from_tensor_model_parallel_region(input_): return _ReduceFromModelParallelRegion.apply(input_) - - -def all_gather_from_tensor_model_parallel_region(input_, rank, world_size): - return _AllGatherFromModelParallelRegion.apply(input_, rank, world_size) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index cd2d148e16..d7244c632b 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -19,6 +19,7 @@ apply_rotary_pos_emb, logger, ) + import copy from torch.distributed.distributed_c10d import ProcessGroup from ...modeling_attn_mask_utils import ( @@ -28,13 +29,11 @@ from optimum.habana.distributed.tp import TPModule from optimum.habana import distributed from optimum.habana.distributed.tensorparallel import ( - copy_to_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, ) from optimum.habana.distributed.strategy import DistributedStrategy -from optimum.habana.distributed.strategy import NotDistributed -NoOpStrategy = NotDistributed() +from optimum.habana.distributed.strategy import NoOpStrategy try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -855,13 +854,14 @@ def __init__(self, config: LlamaConfig): super(LlamaModel, self).__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.distributed_strategy = config.distributed_strategy config.distributed_strategy = None self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) layers = [] - for i in range(config.num_hidden_layers): - layer = GaudiLlamaDecoderLayer(config, i) - layer = self.distributed_strategy.distribute_layer(layer, i) + for layer_idx in range(config.num_hidden_layers): + layer = GaudiLlamaDecoderLayer(config, layer_idx) + layer = self.distributed_strategy.distribute_layer(layer, layer_idx) layers.append(layer) self.layers = torch.nn.ModuleList(layers) @@ -871,7 +871,6 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -1099,6 +1098,7 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args attn_softmax_bf16 - add new args reuse_cache """ + def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): config.distributed_strategy = distributed_strategy super().__init__(config) From a73c6f67b2b49bc59c5db42b2e5e6959bb76b2d6 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Wed, 17 Jul 2024 18:54:33 +0300 Subject: [PATCH 03/10] make style, updated README for distributed_strategy="tp" Added test in tests/test_text_generation_example.py add a link to the original implementation for the referenced files --- examples/text-generation/README.md | 29 +++++ examples/text-generation/utils.py | 65 +++-------- optimum/habana/distributed/__init__.py | 7 +- optimum/habana/distributed/serialization.py | 78 ++++++------- optimum/habana/distributed/strategy.py | 68 ++++++++---- optimum/habana/distributed/tensorparallel.py | 37 ++++--- optimum/habana/distributed/tp.py | 23 +++- optimum/habana/distributed/tp_wrapping.py | 33 ++++-- .../models/llama/modeling_llama.py | 103 +++++++++--------- tests/test_text_generation_example.py | 33 +++++- 10 files changed, 281 insertions(+), 195 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 440b18713c..efee8aac45 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -264,6 +264,35 @@ set the following environment variables before running the command: `PT_ENABLE_I You will also need to add `--torch_compile` in your command. +### Running with Tesor parallel strategy +#### Attribution + +This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details. + +torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. To enable +torch.compile with tensor parallel strategy, please set the following environment variables before running the +command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed. + +You will also need to add `--torch_compile` and `--distributed_strategy="tp"` in your command. + +Here is an example: +```bash +python ../gaudi_spawn.py --world_size 8 run_generation.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--trim_logits \ +--use_kv_cache \ +--attn_softmax_bf16 \ +--bf16 \ +--bucket_internal \ +--bucket_size=128 \ +--use_flash_attention \ +--flash_attention_recompute \ +--batch_size 246 \ +--max_input_tokens 2048 \ +--max_new_tokens 2048 \ +--torch_compile \ +--distributed_strategy="tp" +``` ### Running with FP8 diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index c38f47821f..0c9d86cfc7 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -247,59 +247,35 @@ def setup_model(args, model_dtype, model_kwargs, logger): # assistant_model = get_torch_compiled_model(assistant_model) return model, assistant_model -def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): - from optimum.habana.distributed import serialization +def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): from typing import Any, MutableMapping - from optimum.habana.distributed import tp_wrapping - from optimum.habana.distributed.strategy import DistributedStrategy - from torch import nn - - class TensorParallelStrategy(DistributedStrategy): - def __init__(self, group=None, from_meta=False): - super().__init__(from_meta) - assert torch.distributed.is_initialized(), "must initialize a process group" - self.group = group if group is not None else torch.distributed.GroupMember.WORLD - - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: - return tp_wrapping.apply_tp(module, self.group) - - def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: - return tp_wrapping.apply_tp(block, layer, self.group) - - def __getstate__(self): - state = self.__dict__.copy() - state['group'] = None # Remove ProcessGroup from state - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.group = None # Restore to default state or reinitialize + from optimum.habana.distributed import serialization + from optimum.habana.distributed.strategy import TensorParallelStrategy logger.info("Multi-device run.") + assert args.quant_config == "", "Fp8 is not enabled, unset QUANT_CONFIG" assert args.assistant_model is None, "Assistant model must be None" - + from torch import distributed as dist - if args.device == 'hpu': - import habana_frameworks.torch.distributed.hccl - dist.init_process_group(backend='hccl') + + if args.device == "hpu": + dist.init_process_group(backend="hccl") else: - dist.init_process_group() - + assert False, "Supports TP only on HPU" + torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD) logger.info("Creating Model") - config = AutoConfig.from_pretrained(args.model_name_or_path,torch_dtype=model_dtype, **model_kwargs) - model_kwargs={} + config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + model_kwargs = {} model_kwargs["distributed_strategy"] = TensorParallelStrategy() model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) initial_device = torch.device("cpu") - source="hf" - checkpoint_sharding=None + source = "hf" + checkpoint_sharding = None lazy_sd: MutableMapping[str, Any] = {} logger.info("Loading Checkpoints") lazy_sd = serialization.load_state_dict( @@ -311,7 +287,7 @@ def __setstate__(self, state): rank=args.global_rank, world_size=args.world_size, ) - architecture="llama" + architecture = "llama" if len(lazy_sd): serialization.load_state_dict_into_model( model, @@ -325,18 +301,12 @@ def __setstate__(self, state): args.world_size, ) - if args.quant_config: - model = setup_quantization(model, args) - model = model.eval().to(args.device) if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph - if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon": - model = wrap_in_hpu_graph(model, hash_with_views=False) - else: - model = wrap_in_hpu_graph(model) + model = wrap_in_hpu_graph(model) if args.torch_compile and model.config.model_type == "llama": model = get_torch_compiled_model(model) @@ -617,7 +587,8 @@ def initialize_model(args, logger): model, assistant_model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed - else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.distributed_strategy == "tp" + else setup_distributed_model(args, model_dtype, model_kwargs, logger) + if not args.distributed_strategy == "tp" else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) diff --git a/optimum/habana/distributed/__init__.py b/optimum/habana/distributed/__init__.py index 12edd6620e..af269ee68c 100644 --- a/optimum/habana/distributed/__init__.py +++ b/optimum/habana/distributed/__init__.py @@ -1,8 +1,11 @@ -from .distributed_runner import DistributedRunner -from .fast_ddp import all_reduce_gradients import os + import torch +from .distributed_runner import DistributedRunner +from .fast_ddp import all_reduce_gradients + + def rank_and_world(group=None): """ Returns (rank, world_size) from the optionally-specified group, otherwise diff --git a/optimum/habana/distributed/serialization.py b/optimum/habana/distributed/serialization.py index c543ab20bd..bf59fb2445 100644 --- a/optimum/habana/distributed/serialization.py +++ b/optimum/habana/distributed/serialization.py @@ -1,3 +1,20 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + import collections import os from collections import ChainMap @@ -7,7 +24,7 @@ import torch -from optimum.habana.distributed.tp import TPModule +from .tp import TPModule __adapters: MutableMapping[str, MutableMapping[str, Callable[[Mapping], Mapping]]] = {} @@ -34,9 +51,7 @@ def register_adapter( sources = __adapters[architecture] if source in sources: - raise KeyError( - f"Variant {source} already registered for architecture {architecture}" - ) + raise KeyError(f"Variant {source} already registered for architecture {architecture}") sources[source] = adapter __adapters[architecture] = sources @@ -55,14 +70,8 @@ def list_sources(architecture: str): return list(__adapters[architecture].keys()) -def _get_adapter( - architecture: str, source: Optional[str] -) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: - if ( - source is None - or architecture not in __adapters - or source not in __adapters[architecture] - ): +def _get_adapter(architecture: str, source: Optional[str]) -> Callable[[Mapping[str, Any]], Mapping[str, Any]]: + if source is None or architecture not in __adapters or source not in __adapters[architecture]: # if no adapter is registered, assume the attributes are already in # fms format. # should we raise an error here instead? @@ -71,9 +80,7 @@ def _get_adapter( return __adapters[architecture][source] -def get_adapted( - architecture: str, source: Optional[str], state_dict: Mapping[str, Any] -) -> Mapping[str, Any]: +def get_adapted(architecture: str, source: Optional[str], state_dict: Mapping[str, Any]) -> Mapping[str, Any]: """ Convert a state dict to FMS format, using an adapter specified by name. @@ -91,18 +98,11 @@ def get_adapted( return adapted -# `models` imports each model class, causing models and adapters to be registered. -# down here to avoid circular dependencies. -# from fms import models - - def _get_safetensors_item(key, file: Path, device: torch.device) -> torch.Tensor: from safetensors import safe_open # type: ignore[import-untyped] with torch.no_grad(): - with safe_open( - file, framework="pt", device=str(device) - ) as model_weights: # type: ignore[attr-defined] + with safe_open(file, framework="pt", device=str(device)) as model_weights: # type: ignore[attr-defined] return model_weights.get_tensor(key) @@ -153,7 +153,7 @@ def load_state_dict( if model_path is None or initial_device.type == "meta": return {} if checkpoint_sharding == "fsdp" and distributed_strategy not in ["fsdp", "hsdp"]: - raise ValueError(f"FSDP checkpoints can only be loaded into an FSDP model") + raise ValueError("FSDP checkpoints can only be loaded into an FSDP model") if checkpoint_sharding == "tp" and distributed_strategy != "tp": raise ValueError("TP checkpoints can only be loaded into a TP model") @@ -188,13 +188,11 @@ def load_state_dict( checkpoints = [model_path] # Check if we found some files - assert ( - len(checkpoints) > 0 - ), f"Can't find the requested checkpoint data at {model_path}" + assert len(checkpoints) > 0, f"Can't find the requested checkpoint data at {model_path}" if checkpoint_sharding is not None and checkpoint_sharding != "layer": - assert world_size == len( - checkpoints + assert ( + world_size == len(checkpoints) ), f"Loading a {checkpoint_sharding}-sharded checkpoint with len={len(checkpoints)} but world size is {world_size}" checkpoints = [checkpoints[rank]] @@ -304,13 +302,9 @@ def load_state_dict_into_model( used_keys.add(weight) partial_sd[weight] = state_dict[weight] if partial_sd[weight].device != initial_device: - partial_sd[weight] = partial_sd[weight].to( - device=initial_device - ) + partial_sd[weight] = partial_sd[weight].to(device=initial_device) fms_partial_sd = adapter(partial_sd) - _load_partial_state_dict( - model, fms_partial_sd, needs_tp_sharding, rank, world_size - ) + _load_partial_state_dict(model, fms_partial_sd, needs_tp_sharding, rank, world_size) for p_key in partial_sd.keys(): if isinstance(state_dict, ChainMap): for child_sd in state_dict.maps: @@ -341,17 +335,11 @@ def _copy_colwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_ output_size_per_partition = param.shape[0] if not is_bias: tensor = tensor_value[ - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ), + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), :, ] else: - tensor = tensor_value[ - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ) - ] + tensor = tensor_value[(rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition)] param.copy_(tensor, non_blocking=True) @@ -376,9 +364,7 @@ def _copy_rowwise(param: torch.nn.Parameter, tensor_value, is_bias, rank, world_ output_size_per_partition = param.shape[1] tensor = tensor_value[ :, - (rank * output_size_per_partition) : ( - (rank + 1) * output_size_per_partition - ), + (rank * output_size_per_partition) : ((rank + 1) * output_size_per_partition), ] param.copy_(tensor, non_blocking=True) else: diff --git a/optimum/habana/distributed/strategy.py b/optimum/habana/distributed/strategy.py index 4bc68bbca7..91b3f00232 100644 --- a/optimum/habana/distributed/strategy.py +++ b/optimum/habana/distributed/strategy.py @@ -1,17 +1,33 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + from abc import abstractmethod -from typing import Any, List, Mapping +from typing import List import torch import torch.distributed from torch import nn + class DistributedStrategy: def __init__(self, from_meta=False): self.from_meta = from_meta - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: """ Optionally a distributed strategy may distribute modules that are not numbered layers @@ -30,9 +46,7 @@ class NotDistributed(DistributedStrategy): def __init__(self, from_meta=False): super().__init__(from_meta) - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: return module def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: @@ -54,17 +68,8 @@ def __init__(self, module: nn.Module, device): def forward(self, *args, **kwargs): device = self.device - args = [ - arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args - ] - kwargs = { - k: ( - kwargs[k].to(device) - if isinstance(kwargs[k], torch.Tensor) - else kwargs[k] - ) - for k in kwargs - } + args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: (kwargs[k].to(device) if isinstance(kwargs[k], torch.Tensor) else kwargs[k]) for k in kwargs} return self.module(*args, **kwargs) @@ -92,9 +97,7 @@ def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: wrapped = DeviceMover(block, device) return wrapped - def distribute_module( - self, module: nn.Module, final_layers: bool = False - ) -> nn.Module: + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: if final_layers: device = self.layer_to_device[len(self.layer_to_device) - 1] else: @@ -105,4 +108,27 @@ def distribute_module( return wrapped +class TensorParallelStrategy(DistributedStrategy): + def __init__(self, group=None, from_meta=False): + super().__init__(from_meta) + assert torch.distributed.is_initialized(), "must initialize a process group" + self.group = group if group is not None else torch.distributed.GroupMember.WORLD + + def distribute_module(self, module: nn.Module, final_layers: bool = False) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(module, self.group) + + def distribute_layer(self, block: nn.Module, layer: int) -> nn.Module: + from optimum.habana.distributed import tp_wrapping + + return tp_wrapping.apply_tp(block, layer, self.group) + + def __getstate__(self): + state = self.__dict__.copy() + state["group"] = None # Remove ProcessGroup from state + return state + def __setstate__(self, state): + self.__dict__.update(state) + self.group = None # Restore to default state or reinitialize diff --git a/optimum/habana/distributed/tensorparallel.py b/optimum/habana/distributed/tensorparallel.py index 5d484b2fc5..09a205e451 100644 --- a/optimum/habana/distributed/tensorparallel.py +++ b/optimum/habana/distributed/tensorparallel.py @@ -1,10 +1,24 @@ -# mypy: disable-error-code="method-assign,misc" +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack import torch import torch._inductor.ir as ir import torch._inductor.lowering as lowering import torch.distributed as dist -import torch.distributed._functional_collectives as funcol from torch import nn @@ -12,9 +26,7 @@ def apply_colwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): # Divide the weight matrix along the last dimension. output_size_per_partition = mod.out_features // world_size with torch.no_grad(): - par_mod.weight.copy_( - torch.split(mod.weight, output_size_per_partition, dim=0)[rank] - ) + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=0)[rank]) if par_mod.bias is not None: par_mod.bias.copy_(torch.split(mod.bias, output_size_per_partition)[rank]) @@ -23,9 +35,7 @@ def apply_rowwise_tp(par_mod: nn.Linear, mod: nn.Linear, world_size, rank): # Divide the weight matrix along the last dimension. output_size_per_partition = mod.in_features // world_size with torch.no_grad(): - par_mod.weight.copy_( - torch.split(mod.weight, output_size_per_partition, dim=1)[rank] - ) + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=1)[rank]) if par_mod.bias is not None: if rank == 0: par_mod.bias.copy_(mod.bias) @@ -37,9 +47,7 @@ def apply_embedding_tp(par_mod: nn.Embedding, mod: nn.Embedding, world_size, ran # Divide the weight matrix along the last dimension. output_size_per_partition = mod.embedding_dim // world_size with torch.no_grad(): - par_mod.weight.copy_( - torch.split(mod.weight, output_size_per_partition, dim=1)[rank] - ) + par_mod.weight.copy_(torch.split(mod.weight, output_size_per_partition, dim=1)[rank]) ## Fixes for PT 2.2 collectives until PT 2.3 is released @@ -72,6 +80,7 @@ def get_volatile_reads_fixed(self): if other_fn in lowering.lowerings: del lowering.lowerings[other_fn] + def _all_reduce(input_: torch.Tensor) -> torch.Tensor: """All-reduce the input tensor across model parallel group.""" world_size = dist.get_world_size() @@ -80,9 +89,8 @@ def _all_reduce(input_: torch.Tensor) -> torch.Tensor: return input_ # Starting PT 2.3, we can go back to funcol.all_reduce - return torch.ops._c10d_functional.wait_tensor( - torch.ops._c10d_functional.all_reduce(input_, "sum", "default") - ) + return torch.ops._c10d_functional.wait_tensor(torch.ops._c10d_functional.all_reduce(input_, "sum", "default")) + class _ReduceFromModelParallelRegion(torch.autograd.Function): """All-reduce the input from the model parallel region.""" @@ -99,5 +107,6 @@ def forward(ctx, input_): def backward(ctx, grad_output): return grad_output + def reduce_from_tensor_model_parallel_region(input_): return _ReduceFromModelParallelRegion.apply(input_) diff --git a/optimum/habana/distributed/tp.py b/optimum/habana/distributed/tp.py index 31f33a79cc..c4f156fa61 100644 --- a/optimum/habana/distributed/tp.py +++ b/optimum/habana/distributed/tp.py @@ -1,12 +1,29 @@ +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + import itertools from abc import ABCMeta, abstractmethod -from typing import List, Type +from typing import List import torch import torch.nn as nn from torch.distributed.distributed_c10d import ProcessGroup -from optimum.habana.distributed.tensorparallel import ( +from .tensorparallel import ( apply_colwise_tp, apply_embedding_tp, apply_rowwise_tp, @@ -76,7 +93,7 @@ def import_weights(self, module: nn.Module): ) with torch.no_grad(): for mod_name, module in self.named_children(): - if not mod_name in tp_sharded_modules: + if mod_name not in tp_sharded_modules: for param_name, param in module.named_parameters(recurse=False): param.copy_( getattr(getattr(module, mod_name), param_name), diff --git a/optimum/habana/distributed/tp_wrapping.py b/optimum/habana/distributed/tp_wrapping.py index 402accb342..761fa7bff4 100644 --- a/optimum/habana/distributed/tp_wrapping.py +++ b/optimum/habana/distributed/tp_wrapping.py @@ -1,21 +1,36 @@ -import os +# Copyright 2024 The Foundation Model Stack Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file has been modified from its original version. +# The original version can be found at https://github.com/foundation-model-stack/foundation-model-stack + from torch import nn from torch.distributed.distributed_c10d import ProcessGroup -from optimum.habana.transformers.models.llama.modeling_llama import ( - GaudiLlamaMLP, - TPGaudiLlamaMLP, +from ..transformers.models.llama.modeling_llama import ( GaudiLlamaAttention, - TPGaudiLlamaAttention + GaudiLlamaMLP, + TPGaudiLlamaAttention, + TPGaudiLlamaMLP, ) -# this probably belongs somewhere else but can't go in fms.distribtued b/c -# circular dependency. + def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): if hasattr(module, "to_tp"): return module.to_tp(group) elif isinstance(module, GaudiLlamaAttention): - return TPGaudiLlamaAttention.import_module(module,layer, group) + return TPGaudiLlamaAttention.import_module(module, layer, group) elif isinstance(module, GaudiLlamaMLP): return TPGaudiLlamaMLP.import_module(module, group) else: @@ -23,7 +38,7 @@ def _tp_wrapped(module: nn.Module, layer: int, group: ProcessGroup): def apply_tp(model: nn.Module, layer_idx: int, group: ProcessGroup): - wrapped = _tp_wrapped(model, layer_idx, group) + wrapped = _tp_wrapped(model, layer_idx, group) if wrapped is not model: return wrapped diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index d7244c632b..c071dea0e2 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -1,3 +1,4 @@ +import copy import math import os import warnings @@ -5,6 +6,7 @@ import torch import torch.nn.functional as F +from torch.distributed.distributed_c10d import ProcessGroup from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -20,20 +22,16 @@ logger, ) -import copy -from torch.distributed.distributed_c10d import ProcessGroup +from .... import distributed +from ....distributed.strategy import DistributedStrategy, NoOpStrategy +from ....distributed.tensorparallel import ( + reduce_from_tensor_model_parallel_region, +) +from ....distributed.tp import TPModule from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) -from optimum.habana.distributed.tp import TPModule -from optimum.habana import distributed -from optimum.habana.distributed.tensorparallel import ( - reduce_from_tensor_model_parallel_region, -) - -from optimum.habana.distributed.strategy import DistributedStrategy -from optimum.habana.distributed.strategy import NoOpStrategy try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -198,6 +196,7 @@ def post_mlp_forward(self, x): return self.down_proj.post_all_reduce(x) return x + class TPGaudiLlamaMLP(GaudiLlamaMLP, TPModule): def __init__( self, @@ -207,16 +206,11 @@ def __init__( assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) hidden_dim = int(config.hidden_grow_factor * config.hidden_size) - assert ( - hidden_dim % world_size == 0 - ), "Hidden dim must be divisible by world size" + assert hidden_dim % world_size == 0, "Hidden dim must be divisible by world size" self.config = copy.deepcopy(config) self.config.intermediate_size = int((config.hidden_grow_factor / world_size) * config.hidden_size) - GaudiLlamaMLP.__init__( - self, - self.config - ) + GaudiLlamaMLP.__init__(self, self.config) self.setup_tp(rank, world_size) def colwise_param_names(self) -> List[str]: @@ -229,16 +223,14 @@ def rowwise_param_names(self) -> List[str]: def import_module(glu: GaudiLlamaMLP, group: ProcessGroup) -> "TPGaudiLlamaMLP": config = copy.deepcopy(glu.config) config.hidden_grow_factor = glu.config.intermediate_size / glu.config.hidden_size - tp_glu = TPGaudiLlamaMLP( - config = config, - group=group - ) + tp_glu = TPGaudiLlamaMLP(config=config, group=group) return tp_glu def pre_mlp_forward(self, x): out_par = GaudiLlamaMLP.pre_mlp_forward(self, x) return reduce_from_tensor_model_parallel_region(out_par) + def gaudi_llama_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -596,28 +588,43 @@ def post_attn_forward(self, attn_output): class TPGaudiLlamaAttention(GaudiLlamaAttention, TPModule): - def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, group: Optional[ProcessGroup] = None,): + def __init__( + self, + config: LlamaConfig, + layer_idx: Optional[int] = None, + group: Optional[ProcessGroup] = None, + ): super().__init__(config, layer_idx) assert torch.distributed.is_initialized() rank, world_size = distributed.rank_and_world(group) - assert ( - config.num_attention_heads % world_size == 0 - ), "The number of heads must be divisible by world size" + assert config.num_attention_heads % world_size == 0, "The number of heads must be divisible by world size" self.config = copy.deepcopy(config) self.pre_tp_kvheads = config.num_key_value_heads - GaudiLlamaAttention.__init__(self, self.config , layer_idx) - self.config.num_attention_heads = self.config.num_attention_heads // world_size - self.config.num_key_value_heads = ( self.config.num_key_value_heads // world_size) if self.config.num_key_value_heads > 1 else self.config.num_key_value_heads + GaudiLlamaAttention.__init__(self, self.config, layer_idx) + self.config.num_attention_heads = self.config.num_attention_heads // world_size + self.config.num_key_value_heads = ( + (self.config.num_key_value_heads // world_size) + if self.config.num_key_value_heads > 1 + else self.config.num_key_value_heads + ) self.head_dim = config.hidden_size // config.num_attention_heads self.hidden_size = self.config.hidden_size // world_size self.num_heads = self.config.num_attention_heads - self.q_proj = torch.nn.Linear(config.hidden_size, self.config.num_attention_heads * self.head_dim , bias=config.attention_bias) - self.k_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) - self.v_proj = torch.nn.Linear(config.hidden_size, self.config.num_key_value_heads * self.head_dim , bias=config.attention_bias) - self.o_proj = torch.nn.Linear(self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias) + self.q_proj = torch.nn.Linear( + config.hidden_size, self.config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj = torch.nn.Linear( + config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = torch.nn.Linear( + self.config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) self.norm_factor = 1.0 / math.sqrt(self.head_dim) self.setup_tp(rank, world_size) @@ -632,14 +639,8 @@ def rowwise_param_names(self) -> List[str]: return ["o_proj"] @staticmethod - def import_module( - mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup - ) -> "TPGaudiLlamaAttention": - tp_mha = TPGaudiLlamaAttention( - config = mha.config, - layer_idx=layer_idx, - group=group - ) + def import_module(mha: GaudiLlamaAttention, layer_idx, group: ProcessGroup) -> "TPGaudiLlamaAttention": + tp_mha = TPGaudiLlamaAttention(config=mha.config, layer_idx=layer_idx, group=group) return tp_mha def pre_attn_forward( @@ -661,8 +662,8 @@ def pre_attn_forward( cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - - hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward(self, + hidden_states, attn_weights, present_key_value = GaudiLlamaAttention.pre_attn_forward( + self, hidden_states, attention_mask, position_ids, @@ -678,11 +679,13 @@ def pre_attn_forward( flash_attention_causal_mask, flash_attention_fast_softmax, cache_idx, - **kwargs - ) + **kwargs, + ) hidden_states = reduce_from_tensor_model_parallel_region(hidden_states) return hidden_states, attn_weights, present_key_value + + class GaudiLlamaDecoderLayer(LlamaDecoderLayer): def __init__(self, config: LlamaConfig, layer_idx: int): super(LlamaDecoderLayer, self).__init__() @@ -854,17 +857,14 @@ def __init__(self, config: LlamaConfig): super(LlamaModel, self).__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - self.distributed_strategy = config.distributed_strategy - config.distributed_strategy = None self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) layers = [] for layer_idx in range(config.num_hidden_layers): layer = GaudiLlamaDecoderLayer(config, layer_idx) - layer = self.distributed_strategy.distribute_layer(layer, layer_idx) + layer = config.distributed_strategy.distribute_layer(layer, layer_idx) layers.append(layer) self.layers = torch.nn.ModuleList(layers) - + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1100,9 +1100,9 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): """ def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): - config.distributed_strategy = distributed_strategy + config.distributed_strategy = distributed_strategy super().__init__(config) - + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) @@ -1146,7 +1146,6 @@ def forward( global has_fused_rope has_fused_rope = False - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 4e116242f5..aba9e0bba9 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -71,6 +71,9 @@ "torch_compile_distributed": [ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], + "distributed_tp": [ + ("/mnt/weka/data/pytorch//llama2/Llama-2-7b-hf/", 1856.8140409694543), + ], } else: # Gaudi1 CI baselines @@ -101,6 +104,7 @@ ], "torch_compile": [], "torch_compile_distributed": [], + "distributed_tp": [], } @@ -116,6 +120,7 @@ def _test_text_generation( fp8: bool = False, max_input_tokens: int = 0, max_output_tokens: int = 100, + distributed_strategy: str = None, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -127,6 +132,11 @@ def _test_text_generation( "--use_deepspeed", f"--world_size {world_size}", ] + elif distributed_strategy == "tp": + command += [ + f"{path_to_example_dir / 'gaudi_spawn.py'}", + f"--world_size {world_size}", + ] command += [ f"{path_to_example_dir / 'text-generation' / 'run_generation.py'}", @@ -148,11 +158,13 @@ def _test_text_generation( if "starcoder2" in model_name.lower(): command += ["--flash_attention_recompute"] - if reuse_cache or torch_compile: + if (reuse_cache or torch_compile) and not distributed_strategy == "tp": command += ["--reuse_cache"] if torch_compile: command += ["--torch_compile"] + command += ["--use_flash_attention"] + command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" env_variables["PT_HPU_LAZY_MODE"] = "0" else: @@ -194,6 +206,10 @@ def _test_text_generation( f"--max_input_tokens {max_input_tokens}", "--limit_hpu_graphs", ] + if distributed_strategy is not None: + command += [ + f"--distributed_strategy={distributed_strategy}", + ] with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") @@ -294,6 +310,21 @@ def test_text_generation_torch_compile_distributed(model_name: str, baseline: fl _test_text_generation(model_name, baseline, token, deepspeed=True, world_size=world_size, torch_compile=True) +@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["distributed_tp"]) +def test_text_generation_distributed_tp(model_name: str, baseline: float, token: str): + world_size = 8 + _test_text_generation( + model_name, + baseline, + token, + batch_size=64, + max_input_tokens=128, + world_size=world_size, + torch_compile=True, + distributed_strategy="tp", + ) + + class TextGenPipeline(TestCase): def test_text_generation_pipeline_script(self): path_to_script = ( From cca19c293258644688261147e9e785ddd46323cb Mon Sep 17 00:00:00 2001 From: Kalyan Date: Tue, 23 Jul 2024 09:40:08 +0300 Subject: [PATCH 04/10] Updated LlamaConfig with distributed_strategy --- optimum/habana/transformers/models/llama/configuration_llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index dcba1c0738..82cdcce924 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -27,6 +27,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, fused_qkv=False, + distributed_strategy=None, **kwargs, ): super().__init__( @@ -55,3 +56,4 @@ def __init__( self.mlp_bias = mlp_bias self.fused_qkv = fused_qkv + self.distributed_strategy = distributed_strategy From dbb316b1a69d1a86c4ea890c5f05e65834cf60b2 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Thu, 25 Jul 2024 10:06:19 +0300 Subject: [PATCH 05/10] Updated README.md and test data-set path --- examples/text-generation/README.md | 13 ++++++++----- tests/test_text_generation_example.py | 7 ++++--- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index efee8aac45..de27cf4caa 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -264,13 +264,16 @@ set the following environment variables before running the command: `PT_ENABLE_I You will also need to add `--torch_compile` in your command. -### Running with Tesor parallel strategy -#### Attribution +### Running with tesor-parallel strategy -This repository includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details. +```bash +NOTE: This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details. +``` -torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. To enable -torch.compile with tensor parallel strategy, please set the following environment variables before running the +```bash +WARNING: torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. +``` +To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed. You will also need to add `--torch_compile` and `--distributed_strategy="tp"` in your command. diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index aba9e0bba9..e28d671626 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -72,7 +72,7 @@ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], "distributed_tp": [ - ("/mnt/weka/data/pytorch//llama2/Llama-2-7b-hf/", 1856.8140409694543), + ("meta-llama/Llama-2-7b-hf", 1856.8140409694543), ], } else: @@ -163,8 +163,9 @@ def _test_text_generation( if torch_compile: command += ["--torch_compile"] - command += ["--use_flash_attention"] - command += ["--flash_attention_recompute"] + if distributed_strategy == "tp": + command += ["--use_flash_attention"] + command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" env_variables["PT_HPU_LAZY_MODE"] = "0" else: From 182baf553e50e1d0c85549e1a9ccd67cb70c10c0 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Fri, 26 Jul 2024 14:41:44 +0300 Subject: [PATCH 06/10] distributed_strategy is not JSON serializable --- optimum/habana/transformers/models/llama/modeling_llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index c071dea0e2..6e0979afb2 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -861,9 +861,12 @@ def __init__(self, config: LlamaConfig): layers = [] for layer_idx in range(config.num_hidden_layers): layer = GaudiLlamaDecoderLayer(config, layer_idx) - layer = config.distributed_strategy.distribute_layer(layer, layer_idx) + if config.distributed_strategy is not None: + layer = config.distributed_strategy.distribute_layer(layer, layer_idx) layers.append(layer) self.layers = torch.nn.ModuleList(layers) + #distributed_strategy is not JSON serializable + config.distributed_strategy = None self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False From 524d1ccf08fcaaa769da3a5ebabe54a2e2619031 Mon Sep 17 00:00:00 2001 From: Kalyan Date: Mon, 29 Jul 2024 13:15:14 +0300 Subject: [PATCH 07/10] Updated README.md make style changes --- examples/text-generation/README.md | 15 +++++++-------- .../transformers/models/llama/modeling_llama.py | 2 +- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index de27cf4caa..b122091f27 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -264,15 +264,14 @@ set the following environment variables before running the command: `PT_ENABLE_I You will also need to add `--torch_compile` in your command. -### Running with tesor-parallel strategy +### Running with tensor-parallel strategy -```bash -NOTE: This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details. -``` +> [!NOTE] +> This strategy includes code from the [foundation-model-stack](https://github.com/foundation-model-stack/foundation-model-stack) repository, which is licensed under the Apache License 2.0. See the `LICENSE` file for more details. + +> [!WARNING] +> torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. -```bash -WARNING: torch.compile with tensor parallel strategy is an experimental feature. It has not been validated for all models. -``` To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed. @@ -280,7 +279,7 @@ You will also need to add `--torch_compile` and `--distributed_strategy="tp"` in Here is an example: ```bash -python ../gaudi_spawn.py --world_size 8 run_generation.py \ +PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_size 8 run_generation.py \ --model_name_or_path meta-llama/Llama-2-70b-hf \ --trim_logits \ --use_kv_cache \ diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 6e0979afb2..30ab4d4aa2 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -865,7 +865,7 @@ def __init__(self, config: LlamaConfig): layer = config.distributed_strategy.distribute_layer(layer, layer_idx) layers.append(layer) self.layers = torch.nn.ModuleList(layers) - #distributed_strategy is not JSON serializable + # distributed_strategy is not JSON serializable config.distributed_strategy = None self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) From bb19c4eda13c2c053ddda22dff293ed8bbcb72fe Mon Sep 17 00:00:00 2001 From: Kalyan Date: Tue, 30 Jul 2024 11:18:10 +0300 Subject: [PATCH 08/10] Renamed the distributed_strategy parameter to parallel_strategy --- examples/text-generation/README.md | 4 ++-- examples/text-generation/run_generation.py | 4 ++-- examples/text-generation/utils.py | 8 ++++---- .../models/llama/configuration_llama.py | 4 ++-- .../transformers/models/llama/modeling_llama.py | 12 ++++++------ tests/test_text_generation_example.py | 14 +++++++------- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index b122091f27..fad94bdbcb 100755 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -275,7 +275,7 @@ You will also need to add `--torch_compile` in your command. To enable torch.compile with tensor parallel strategy, please set the following environment variables before running the command: `PT_ENABLE_INT64_SUPPORT=1` and `PT_HPU_LAZY_MODE=0`. This will enable tensor parallel strategy without deepspeed. -You will also need to add `--torch_compile` and `--distributed_strategy="tp"` in your command. +You will also need to add `--torch_compile` and `--parallel_strategy="tp"` in your command. Here is an example: ```bash @@ -293,7 +293,7 @@ PT_ENABLE_INT64_SUPPORT=1 PT_HPU_LAZY_MODE=0 python ../gaudi_spawn.py --world_s --max_input_tokens 2048 \ --max_new_tokens 2048 \ --torch_compile \ ---distributed_strategy="tp" +--parallel_strategy="tp" ``` ### Running with FP8 diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 2d1c4bcc07..c41664ebf3 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -288,11 +288,11 @@ def setup_parser(parser): help="Whether to trust the execution of code from datasets/models defined on the Hub. This option should only be set to `True` for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.", ) parser.add_argument( - "--distributed_strategy", + "--parallel_strategy", type=str, choices=["tp", "none"], # Add other strategies as needed default="none", - help="Run multi card with the specified distributed strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.", + help="Run multi card with the specified parallel strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.", ) args = parser.parse_args() diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 0c9d86cfc7..9e766f5626 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -270,7 +270,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): logger.info("Creating Model") config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) model_kwargs = {} - model_kwargs["distributed_strategy"] = TensorParallelStrategy() + model_kwargs["parallel_strategy"] = TensorParallelStrategy() model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype, **model_kwargs) initial_device = torch.device("cpu") @@ -281,7 +281,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): lazy_sd = serialization.load_state_dict( args.model_name_or_path, source=source, - distributed_strategy=args.distributed_strategy, + distributed_strategy=args.parallel_strategy, checkpoint_sharding=None, initial_device=initial_device, rank=args.global_rank, @@ -294,7 +294,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): lazy_sd, architecture, source, - args.distributed_strategy, + args.parallel_strategy, checkpoint_sharding, initial_device, args.local_rank, @@ -588,7 +588,7 @@ def initialize_model(args, logger): setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed else setup_distributed_model(args, model_dtype, model_kwargs, logger) - if not args.distributed_strategy == "tp" + if not args.parallel_strategy == "tp" else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) diff --git a/optimum/habana/transformers/models/llama/configuration_llama.py b/optimum/habana/transformers/models/llama/configuration_llama.py index 82cdcce924..12ad78e29a 100644 --- a/optimum/habana/transformers/models/llama/configuration_llama.py +++ b/optimum/habana/transformers/models/llama/configuration_llama.py @@ -27,7 +27,7 @@ def __init__( attention_dropout=0.0, mlp_bias=False, fused_qkv=False, - distributed_strategy=None, + parallel_strategy=None, **kwargs, ): super().__init__( @@ -56,4 +56,4 @@ def __init__( self.mlp_bias = mlp_bias self.fused_qkv = fused_qkv - self.distributed_strategy = distributed_strategy + self.parallel_strategy = parallel_strategy diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 30ab4d4aa2..4630678a97 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -861,12 +861,12 @@ def __init__(self, config: LlamaConfig): layers = [] for layer_idx in range(config.num_hidden_layers): layer = GaudiLlamaDecoderLayer(config, layer_idx) - if config.distributed_strategy is not None: - layer = config.distributed_strategy.distribute_layer(layer, layer_idx) + if config.parallel_strategy is not None: + layer = config.parallel_strategy.distribute_layer(layer, layer_idx) layers.append(layer) self.layers = torch.nn.ModuleList(layers) - # distributed_strategy is not JSON serializable - config.distributed_strategy = None + # parallel_strategy is not JSON serializable + config.parallel_strategy = None self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False @@ -1102,8 +1102,8 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM): - add new args reuse_cache """ - def __init__(self, config, distributed_strategy: DistributedStrategy = NoOpStrategy): - config.distributed_strategy = distributed_strategy + def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy): + config.parallel_strategy = parallel_strategy super().__init__(config) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index e28d671626..29189aba32 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -120,7 +120,7 @@ def _test_text_generation( fp8: bool = False, max_input_tokens: int = 0, max_output_tokens: int = 100, - distributed_strategy: str = None, + parallel_strategy: str = None, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -132,7 +132,7 @@ def _test_text_generation( "--use_deepspeed", f"--world_size {world_size}", ] - elif distributed_strategy == "tp": + elif parallel_strategy == "tp": command += [ f"{path_to_example_dir / 'gaudi_spawn.py'}", f"--world_size {world_size}", @@ -158,12 +158,12 @@ def _test_text_generation( if "starcoder2" in model_name.lower(): command += ["--flash_attention_recompute"] - if (reuse_cache or torch_compile) and not distributed_strategy == "tp": + if (reuse_cache or torch_compile) and not parallel_strategy == "tp": command += ["--reuse_cache"] if torch_compile: command += ["--torch_compile"] - if distributed_strategy == "tp": + if parallel_strategy == "tp": command += ["--use_flash_attention"] command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" @@ -207,9 +207,9 @@ def _test_text_generation( f"--max_input_tokens {max_input_tokens}", "--limit_hpu_graphs", ] - if distributed_strategy is not None: + if parallel_strategy is not None: command += [ - f"--distributed_strategy={distributed_strategy}", + f"--parallel_strategy={parallel_strategy}", ] with TemporaryDirectory() as tmp_dir: @@ -322,7 +322,7 @@ def test_text_generation_distributed_tp(model_name: str, baseline: float, token: max_input_tokens=128, world_size=world_size, torch_compile=True, - distributed_strategy="tp", + parallel_strategy="tp", ) From 6a04a8c1c872a297b8e9ad77aa5d009697e2974a Mon Sep 17 00:00:00 2001 From: Kalyan Date: Tue, 30 Jul 2024 16:03:45 +0300 Subject: [PATCH 09/10] Updated cache_dir for parallel_strategy = tp --- examples/text-generation/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 9e766f5626..e7c3bb1d46 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -248,7 +248,7 @@ def setup_model(args, model_dtype, model_kwargs, logger): return model, assistant_model -def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): +def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir): from typing import Any, MutableMapping from optimum.habana.distributed import serialization @@ -279,7 +279,7 @@ def setup_distributed_model_tp(args, model_dtype, model_kwargs, logger): lazy_sd: MutableMapping[str, Any] = {} logger.info("Loading Checkpoints") lazy_sd = serialization.load_state_dict( - args.model_name_or_path, + cache_dir, source=source, distributed_strategy=args.parallel_strategy, checkpoint_sharding=None, @@ -566,7 +566,7 @@ def initialize_model(args, logger): setup_env(args) setup_device(args) set_seed(args.seed) - get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token) + cache_dir = get_repo_root(args.model_name_or_path, local_rank=args.local_rank, token=args.token) if args.assistant_model is not None: get_repo_root(args.assistant_model, local_rank=args.local_rank, token=args.token) use_deepspeed = args.world_size > 0 @@ -589,7 +589,7 @@ def initialize_model(args, logger): if not use_deepspeed else setup_distributed_model(args, model_dtype, model_kwargs, logger) if not args.parallel_strategy == "tp" - else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger) + else setup_distributed_model_tp(args, model_dtype, model_kwargs, logger, cache_dir) ) tokenizer, model, assistant_model = setup_tokenizer(args, model, assistant_model) generation_config = setup_generation_config(args, model, assistant_model, tokenizer) From 3fd17e7c25917ca96b5137b5b21890e8ba50ac1d Mon Sep 17 00:00:00 2001 From: Kalyan Date: Tue, 30 Jul 2024 17:43:28 +0300 Subject: [PATCH 10/10] Updated the perf number for test test_text_generation_distributed_tp --- tests/test_text_generation_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 29189aba32..b32288ffac 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -72,7 +72,7 @@ ("meta-llama/Llama-2-7b-hf", 39.72973199515235), ], "distributed_tp": [ - ("meta-llama/Llama-2-7b-hf", 1856.8140409694543), + ("meta-llama/Llama-2-7b-hf", 1345.2369318328463), ], } else: