OptimizedLinear implementation #5355
Merged

Changes from all commits (13 commits):
97a230e  optimized linear + tests (jeffra)
7afae40  Merge branch 'master' into optim-linear (jeffra)
1f7006a  some fixes to make lora training work (sfc-gh-rsamdani)
78e763d  clean-up (jeffra)
fa0b032  Merge branch 'master' into optim-linear (jeffra)
16d72d4  Merge branch 'master' into optim-linear (jeffra)
980db87  address comments, new tests, new fixes (jeffra)
ffe2223  Merge branch 'master' into optim-linear (jeffra)
185a68f  add type check for configs (jeffra)
dc9258c  formatting (jeffra)
fce174b  Merge branch 'master' into optim-linear (sfc-gh-jrasley)
f2cbcae  Merge branch 'master' into optim-linear (jeffra)
efeed39  loosen typechecking of config (jeffra)
@@ -0,0 +1,7 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .optimized_linear import OptimizedLinear
from .config import LoRAConfig, QuantizationConfig
@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from dataclasses import dataclass


@dataclass
class LoRAConfig:
    """
    Configuration settings for LoRAOptimizedLinear.

    Attributes:
        lora_r (int): LoRA attention dimension, also known as the rank. Default is 64.
        lora_alpha (float): LoRA scaling factor, default is 16.
        base_weight_sharding (int): The degree to which the base weights are sharded,
            should typically be set to the data-parallel world size to maximize the memory
            reduction benefits. Defaults to 1, which means this feature is disabled.
    """
    lora_r: int = 64
    lora_alpha: float = 16.
    base_weight_sharding: int = 1


@dataclass
class QuantizationConfig:
    """
    Configuration settings for quantization for LoRAOptimizedLinear, QuantizedLinear,
    and QuantizedParameter.

    Attributes:
        q_bits (int): The number of bits used for quantization. Default is 8.
        mantissa_bits (int): The number of bits reserved for the mantissa in fixed-point quantization. Default is 3.
        group_size (int): The size of the group used for quantization. Default is 512.
    """
    q_bits: int = 8
    mantissa_bits: int = 3
    group_size: int = 512
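For orientation, here is a small, hedged sketch of how these two dataclasses are meant to be filled in. The field names and defaults come from the code above; the concrete values and the `deepspeed.linear` import path are illustrative assumptions, not part of this diff.

```python
# Sketch only: field names/defaults are from the config module above; the import
# path and the concrete values are assumptions for illustration.
from dataclasses import asdict

from deepspeed.linear import LoRAConfig, QuantizationConfig  # assumed package path

lora_cfg = LoRAConfig(
    lora_r=32,               # LoRA rank (attention dimension)
    lora_alpha=16.0,         # scaling numerator; the layer scales by lora_alpha / sqrt(lora_r)
    base_weight_sharding=8,  # shard the frozen base weight across 8 data-parallel ranks
)
quant_cfg = QuantizationConfig(
    q_bits=8,                # total bits per quantized value
    mantissa_bits=3,         # bits reserved for the mantissa
    group_size=512,          # elements per quantization group
)

print(asdict(lora_cfg))
print(asdict(quant_cfg))
```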
@@ -0,0 +1,150 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import is_dataclass
from deepspeed.accelerator import get_accelerator
import deepspeed.comm as dist

from .config import LoRAConfig, QuantizationConfig
from .quantization import QuantizedParameter, QuantizedLinear


class OptimizedLinear(nn.Module):
    """
    Optimized version of nn.Linear that adds features such as:
      * LoRA w. base weight sharding
      * FP [6,8,12] quantization

    Arguments:
        input_dim: Required: size of each input sample
        output_dim: Required: size of each output sample
        bias: Optional: If set to False, the layer will not learn an additive bias. Default: False
        lora_config: Optional: LoRAConfig defining lora features and base-weight-sharding degree
        quantization_config: Optional: QuantizationConfig defining quantization features
        dtype: Optional: parameter dtype, only supports bfloat16 currently

    Returns:
        Returns a new nn.Module depending on the input config. Either native
        torch.nn.Linear, QuantizedLinear, or the full-featured DSOptimizedLinear.
    """

    def __new__(self,
                input_dim: int,
                output_dim: int,
                bias: bool = False,
                lora_config: LoRAConfig = None,
                quantization_config: QuantizationConfig = None,
                dtype=torch.bfloat16):

        if quantization_config is not None and not is_dataclass(quantization_config):
            raise ValueError(f"Expecting QuantizationConfig but received {type(quantization_config)}")
        if lora_config is not None and not is_dataclass(lora_config):
            raise ValueError(f"Expecting LoRAConfig but received {type(lora_config)}")
        if lora_config is None and quantization_config is None:
            # Everything disabled, fall back to normal nn.Linear
            self = nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype)

        elif lora_config:
            # lora enabled, quantization may or may not be
            self = LoRAOptimizedLinear(input_dim=input_dim,
                                       output_dim=output_dim,
                                       bias=bias,
                                       lora_config=lora_config,
                                       quantization_config=quantization_config,
                                       dtype=dtype)

        elif quantization_config:
            # only quantization enabled, no lora
            self = QuantizedLinear(input_dim=input_dim,
                                   output_dim=output_dim,
                                   bias=bias,
                                   quantization_config=quantization_config,
                                   dtype=dtype)
        return self
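Because `__new__` dispatches on the two configs, `OptimizedLinear(...)` can hand back three different module types. Below is a hedged usage sketch of that dispatch: the class names are as defined in this file, the `deepspeed.linear` import path is assumed, and the quantized/LoRA variants additionally need the FP quantizer kernels and an accelerator at runtime. The `LoRAOptimizedLinear` class referenced here is defined next in the same file.

```python
# Dispatch sketch: which module type OptimizedLinear returns for each config combination.
# Import path is assumed; the quantized/LoRA variants require a supported accelerator.
import torch.nn as nn

from deepspeed.linear import OptimizedLinear, LoRAConfig, QuantizationConfig

plain = OptimizedLinear(1024, 4096)                        # no configs -> plain nn.Linear (bf16)
assert isinstance(plain, nn.Linear)

quant = OptimizedLinear(1024, 4096,                        # quantization only -> QuantizedLinear
                        quantization_config=QuantizationConfig(q_bits=8))

lora = OptimizedLinear(1024, 4096,                         # any LoRA config -> LoRAOptimizedLinear,
                       lora_config=LoRAConfig(lora_r=64),  # optionally combined with quantization
                       quantization_config=None)           # of the frozen base weight
```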
class LoRAOptimizedLinear(nn.Module):

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 lora_config: LoRAConfig = None,
                 quantization_config: QuantizationConfig = None,
                 device=None,
                 dtype=torch.bfloat16):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.lora_config = lora_config
        self.quantization_config = quantization_config
        device = get_accelerator().current_device() if device is None else device
        assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

        self.zero_shards = self.lora_config.base_weight_sharding
        self.sharded_weight_size = int(float(self.input_dim) // self.zero_shards)
        w = torch.nn.Parameter(torch.empty((self.output_dim, self.sharded_weight_size), dtype=dtype))
        torch.nn.init.xavier_uniform_(w)

        if self.quantization_config is not None:
            assert dtype == torch.bfloat16, "only bfloat16 is supported when using quantization"
            self.base_weight = QuantizedParameter(w, quantization_config=quantization_config)
        else:
            self.base_weight = w

        self.base_weight.requires_grad = False

        # Use RS lora for now.
        self.lora_scaling_factor = self.lora_config.lora_alpha / math.sqrt(self.lora_config.lora_r)
        # Keeping lora weights in bf16 precision for ease of training.
        self.lora_weight_1 = nn.Linear(self.input_dim,
                                       self.lora_config.lora_r,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_2 = nn.Linear(self.lora_config.lora_r,
                                       self.output_dim,
                                       bias=self.bias,
                                       device=device,
                                       dtype=dtype)
        self.lora_weight_1.weight.requires_grad = True
        self.lora_weight_2.weight.requires_grad = True

    def full_weight(self):
        # This assumes weights are evenly sharded across GPUs, which might not be correct.
        # In that case, we should flatten before all_gather.
        local_weight = self.base_weight.dequantized() if isinstance(self.base_weight,
                                                                    QuantizedParameter) else self.base_weight
        tensor_list = [
            torch.zeros_like(local_weight, device=local_weight.device, dtype=local_weight.dtype)
            for _ in range(self.zero_shards)
        ]
        dist.all_gather(tensor_list, local_weight)
        weight = nn.Parameter(torch.cat([tensor for tensor in tensor_list], dim=1))
        return weight

    def linear_without_F_linear(self, input, weight):
        output = torch.mm(input.reshape(-1, input.shape[-1]), weight)
        output = output.view(*input.shape[:-1], weight.shape[1])
        return output

    def forward(self, input_tensor):
        # Gather the sharded base weight
        if self.zero_shards > 1:
            with torch.no_grad():
                base_weight = self.full_weight()
        elif self.quantization_config:
            base_weight = self.base_weight.dequantized()
        else:
            base_weight = self.base_weight

        base_weight_output = F.linear(input_tensor, base_weight)
        lora_output = self.lora_weight_2(self.lora_weight_1(input_tensor))
        return base_weight_output + self.lora_scaling_factor * lora_output
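To make the forward pass concrete: the layer computes a frozen base projection plus an RS-LoRA correction scaled by `lora_alpha / sqrt(lora_r)`, and when `base_weight_sharding > 1` each rank holds a 1/N column slice of the base weight that `full_weight()` all-gathers (concatenating along dim=1) before the matmul. Below is a minimal, self-contained sketch of the same arithmetic in plain PyTorch; it uses float32 and skips sharding and quantization (the real module defaults to bfloat16), and the shapes are arbitrary.

```python
# Minimal sketch of what LoRAOptimizedLinear.forward computes (no sharding/quantization).
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, input_dim, output_dim, r, alpha = 2, 16, 32, 4, 16.0

x = torch.randn(batch, input_dim)
base_weight = torch.randn(output_dim, input_dim)  # frozen (requires_grad=False in the module)
lora_a = nn.Linear(input_dim, r, bias=False)      # trainable low-rank factors
lora_b = nn.Linear(r, output_dim, bias=False)

scaling = alpha / math.sqrt(r)                    # RS-LoRA scaling (lora_scaling_factor above)
y = F.linear(x, base_weight) + scaling * lora_b(lora_a(x))
print(y.shape)  # torch.Size([2, 32])
```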
@@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional

from deepspeed.accelerator import get_accelerator
from deepspeed.ops.fp_quantizer import Quantizer, FP_Quantize
from .config import QuantizationConfig


class QuantizedParameter(nn.Parameter):
    """
    Quantized parameter class that implements weight quantization. Weights
    are stored in quantized form on GPUs, and can be dequantized on-the-fly when
    needed by the model. The weights are actually quantized during any `.to(device)`.

    Arguments:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. Defaults
            to False and is not supported to be True. Argument provided only for interface
            compatibility with torch.nn.Parameter.
        quantization_config (QuantizationConfig, optional):
        quantizer (Quantizer, optional): Defaults to FP_Quantize but can be any quantizer
            that implements deepspeed.ops.fp_quantizer.Quantizer. This argument is also
            required since the quantizer is stashed in the Parameter itself; some models
            may clone the Parameter by passing an attribute __dict__. For an example, see
            tests/unit/linear/test_quant_param.py::TestQuantParam::test_hf_clone
    """

    def __new__(
        cls,
        data: Optional[torch.Tensor] = None,
        requires_grad: bool = False,  # quantized weights must be frozen
        quantization_config: QuantizationConfig = None,
        quantizer: Quantizer = None,
    ):
        if requires_grad:
            raise ValueError(f"requires_grad=True is not supported with QuantizedParameter")
        if data is None:
            data = torch.empty(0)
        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.quantization_config = QuantizationConfig() if quantization_config is None else quantization_config
        if quantizer is not None:
            self.quantizer = quantizer
        else:
            # if FPQuantizerBuilder is not compatible in this env this init will fail
            self.quantizer = FP_Quantize(group_size=self.quantization_config.group_size)
        self._ensure_quantized(self)
        return self

    def _ensure_quantized(self, tensor: torch.Tensor):
        # If the tensor is on the accelerator and is not quantized, then quantize it in-place.
        if get_accelerator().on_accelerator(tensor) and tensor.dtype != torch.int8:
            with get_accelerator().stream(get_accelerator().current_stream(tensor.device)):
                tensor.data = self.quantizer.quantize(tensor.data,
                                                      q_bits=self.quantization_config.q_bits,
                                                      q_mantisa_bits=self.quantization_config.mantissa_bits)
            assert tensor.dtype == torch.int8

    def dequantized(self) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.
        """
        if get_accelerator().on_accelerator(self.data) and self.data.dtype == torch.int8:
            with get_accelerator().stream(get_accelerator().current_stream(self.data.device)):
                return self.quantizer.dequantize(self.data,
                                                 q_bits=self.quantization_config.q_bits,
                                                 q_mantisa_bits=self.quantization_config.mantissa_bits)
        return self.data

    def __getstate__(self):
        state = self.__dict__
        state["data"] = self.data
        state["quantization_config"] = self.quantization_config
        state["requires_grad"] = self.requires_grad
        return state

    def __setstate__(self, state):
        self.quantizer = state["quantizer"]
        self.quantization_config = state["quantization_config"]
        self.data = state["data"]
        self.requires_grad = state["requires_grad"]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        new_instance.quantizer = copy.deepcopy(state["quantizer"])
        new_instance.quantization_config = copy.deepcopy(state["quantization_config"])
        new_instance.data = copy.deepcopy(state["data"])
        return new_instance

    def __copy__(self):
        new_instance = type(self).__new__(type(self))
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def cuda(self, device=None, non_blocking=False):
        return self.to(device="cuda" if device is None else device, non_blocking=non_blocking)

    def to(self, *args, **kwargs):
        """
        Move the parameter to the given device. Then, if the device is a cuda device,
        quantize it.
        """
        tensor = super().to(*args, **kwargs)
        self._ensure_quantized(tensor)
        return tensor
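A short, hedged sketch of the `QuantizedParameter` lifecycle described above: the parameter is quantized to int8-backed storage as soon as it lives on the accelerator, `dequantized()` reconstructs a full-precision tensor on demand, and `requires_grad=True` raises. The import paths are assumptions (they are not shown in this diff), and the snippet needs an accelerator with the FP quantizer op built. The same file continues below with `QuantizedLinear`.

```python
# Sketch: quantization behavior of QuantizedParameter.
# Import paths assumed; requires an accelerator with the FP quantizer kernels.
import torch

from deepspeed.linear.quantization import QuantizedParameter
from deepspeed.linear.config import QuantizationConfig

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
qp = QuantizedParameter(w, quantization_config=QuantizationConfig())

assert qp.dtype == torch.int8   # already on the accelerator, so quantized in __new__
w_full = qp.dequantized()       # full-precision copy reconstructed on the fly
```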
class QuantizedLinear(nn.Linear):
    """
    Linear layer that implements weight quantization. Parameters
    are stored via `QuantizedParameter` and are dequantized on-the-fly during any
    forward pass.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 bias: bool = False,
                 quantization_config: QuantizationConfig = None,
                 dtype=torch.bfloat16):
        super().__init__(input_dim, output_dim, bias=bias, dtype=dtype)
        assert dtype == torch.bfloat16, "currently only supports bfloat16 dtype"
        self.weight = QuantizedParameter(self.weight.data, quantization_config=quantization_config)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.dequantized(), self.bias)
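And a correspondingly small sketch of `QuantizedLinear` used as a drop-in, frozen replacement for `nn.Linear`: the weight is stored as a `QuantizedParameter` and dequantized inside every forward. Again, the import paths are assumed and the snippet presumes a CUDA accelerator with the FP quantizer op available, so treat it as illustrative rather than a tested recipe.

```python
# Sketch: QuantizedLinear with a quantized, frozen weight.
# Import paths assumed; requires a CUDA accelerator with the FP quantizer kernels.
import torch

from deepspeed.linear.quantization import QuantizedLinear
from deepspeed.linear.config import QuantizationConfig

layer = QuantizedLinear(4096, 11008, quantization_config=QuantizationConfig()).to("cuda")
x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
y = layer(x)      # weight is dequantized on the fly inside forward
print(y.shape)    # torch.Size([8, 11008])
```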
@@ -3,4 +3,4 @@
 
 # DeepSpeed Team
 
-from .quantize import FP_Quantize
+from .quantize import FP_Quantize, Quantizer