3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status
--|--|--|--|--
[framework](./plugins/framework/README.md) | The acceleration framework for integration with Hugging Face trainers | | | Beta
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Beta
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon
[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks)-inspired Triton kernels and accelerations for Mixture-of-Experts models | | Apache 2.0 | Coming Soon

## Usage with FMS HF Tuning
@@ -174,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark

See the CSV files below for various results:
- [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv).
- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv).

### Code Architecture

Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
# consider making a map if patching more kernels
PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]


# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
@@ -123,6 +124,7 @@ def create_new_module_peft(
# if the module cannot be found, return None, which results in a raise further up the call stack
return new_module


# consider moving this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
@@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call(
):
# patch old_forward to view the attributes as torch_dtype
# before the call

if submodule_names is None:
submodule_names = ''
submodule_names = ""
if isinstance(submodule_names, str):
submodule_names = [submodule_names]
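
# a rough sketch of the idea (an assumption; the rest of this function is
# elided in this hunk, and parameter names beyond those visible here are
# hypothetical): return a wrapper that views the listed attributes, e.g. the
# PATCH_FOR_FSDP_TRITON_V2 list above, as `torch_dtype` before delegating to
# `old_forward`:
#
#   def _forward(self, *args, **kwargs):
#       for name in submodule_names:
#           mod = self.get_submodule(name) if name else self
#           for attr_name in attribute_names:
#               setattr(mod, attr_name, getattr(mod, attr_name).view(torch_dtype))
#       return old_forward(*args, **kwargs)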

Original file line number Diff line number Diff line change
@@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
AutoGPTQForCausalLM,
BaseQuantizeConfig,
)
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)

# Local
from .autogptq_utils import ( #pylint: disable=import-outside-toplevel
patch_forward_to_view_attributes_before_call,
PATCH_FOR_FSDP_TRITON_V2
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
PATCH_FOR_FSDP_TRITON_V2,
patch_forward_to_view_attributes_before_call,
)

# Currently we allow only a quantized checkpoint to be loaded, we do not
@@ -214,8 +219,14 @@ def augmentation(
):
# guarded imports
# Third Party
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
QuantLinear,
)
from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
GPTQLoraModel,
get_gptq_peft_model,
)

# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
36 changes: 11 additions & 25 deletions plugins/fused-ops-and-kernels/README.md
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:


1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth).
1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth).
- Low-Rank Adapter Fused Operations
- Fast RoPE Triton Kernels
- Fast RMS LayerNorm Triton Kernels
@@ -13,42 +13,28 @@ This library contains fused operations and custom kernels, to be expanded over t

Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅
[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅
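
As a rough orientation, the sketch below shows the nested configuration keys this plugin validates; the key path and accepted values are taken from the plugin source later in this PR, while the surrounding file format used by the acceleration framework (e.g. YAML) is not reproduced here.

```python
# minimal sketch: configuration keys checked by the fast_quantized_peft plugin
configurations = {
    "peft": {
        "quantization": {
            "fused_ops_and_kernels": {
                # accepted values per the plugin's config check:
                # "auto_gptq" or "bitsandbytes"
                "base_layer": "auto_gptq",
            }
        }
    }
}
```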

### Code Extracted from Unsloth

<!--
NOTE: the
- fused_ops/unsloth_lora -> unsloth main
* utils (fast_dequant, fast_gemv, fast_linear_forward, matmul_lora)
* geglu, swiglu (this can be reused across other models, but currently used inside MLP fused ops only)
* bnb (fast_lora.py)
* gtqp (fast_lora, triton) -> jeromeku
- kernels
* cross_ent, rms, rope -> unsloth main
-->

Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth):
- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base; see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
```
it would require a commercial license if used to run on more than 4 GPUs, see
https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143
it would require a commercial license if used to run on more than 4 GPUs ...
```
- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc).
* These model files are **not extracted**.
- All code extracted here before the Feb 2024 Release, see table below.
- These exceptions appear to be located around the trainer improvements; see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183).
- These exceptions appear around the [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in a file containing such an exception **is not extracted**.
- In its place we adopt a different approach: rather than rewriting the model as unsloth does, we patch the model. Our model-patching approach is novel and written **completely from scratch**.
- All code extracted here predates the Feb 2024 Release.
- In the table below we record what was extracted, and the exact commit from which it was taken.

Path | Description | Extracted From | Modifications | Date
--|--|--|--|--
[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024

<!--
[models/](./src/fms_accelerate_unsloth/models/) | Model Forwards | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc)<br><br>`tohrnii/mixtral` @ [a55b7400](https://github.com/tohrnii/unsloth/commit/a55b740062b4fc8ce8f5196bfabe3cf860020ca7) | `llama.py`<br>`mistral.py`<br>`mixtral.py`| `llama.py`<br>`mistral.py`<br>`mixtral.py` | 6 Feb 2024<br><br> 22 Feb 2024
-->

[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024

## Known Issues

Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
from typing import Callable, Dict, Tuple

# Third Party
from accelerate.utils import set_module_tensor_to_device
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer
@@ -63,9 +64,20 @@ def _all_reduce_hook(grad):
return grad

for mod in modules:
# NOTE: assuming lora has no bias
A = mod.lora_A.default
B = mod.lora_B.default

# install hooks on the adapters
mod.lora_A.default.weight.register_hook(_all_reduce_hook)
mod.lora_B.default.weight.register_hook(_all_reduce_hook)
A.weight.register_hook(_all_reduce_hook)
B.weight.register_hook(_all_reduce_hook)

# because these adapters will be ignored by FSDP, we need to manually
# move them to GPU if they are not already there
if not A.weight.is_cuda:
set_module_tensor_to_device(A, "weight", "cuda")
if not B.weight.is_cuda:
set_module_tensor_to_device(B, "weight", "cuda")
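
# a minimal sketch of the hook attached above (an assumption; only the
# `return grad` line of `_all_reduce_hook` is visible in this hunk): average
# the adapter gradients across ranks, since FSDP no longer reduces parameters
# it has been told to ignore
#
#   import torch.distributed as dist
#
#   def _all_reduce_hook(grad):
#       if grad is not None:
#           grad = grad / dist.get_world_size()  # pre-divide, then sum = mean
#           dist.all_reduce(grad, op=dist.ReduceOp.SUM)
#       return grad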


class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
@@ -82,10 +94,7 @@ def __init__(self, configurations: Dict[str, Dict]):

self._base_layer = self._check_config_and_maybe_check_values(
key="peft.quantization.fused_ops_and_kernels.base_layer",
values=[
"auto_gptq",
# "bitsandbytes" # enable later when we have BNB implemented
],
values=["auto_gptq", "bitsandbytes"],
)

# only support these at the moment
Original file line number Diff line number Diff line change
@@ -394,3 +394,10 @@ def apply_lora_o(self, X):
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass

# added by [email protected]
# this will be patchable on the actual module
def apply_lora_o_v2(self, X):
OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
Original file line number Diff line number Diff line change
@@ -735,3 +735,10 @@ def apply_lora_o(self, X):
Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
return O

# added by [email protected]
# this version can be directly patched on the output linear
def apply_lora_o_v2(self, X):
Oqstate, OA, OB, OS = get_lora_parameters(self)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
return O
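
# a usage sketch (an assumption, not shown in this diff): because `self` here
# is the output-projection module itself rather than the whole attention
# block, the function can be bound directly onto that module, e.g.
#
#   from types import MethodType
#   o_proj.forward = MethodType(apply_lora_o_v2, o_proj)  # o_proj: the LoRA-wrapped output linear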
Original file line number Diff line number Diff line change
@@ -130,8 +130,9 @@ def backward(ctx, dY):
pass
pass


def fast_rope_embedding(Q, K, cos, sin):
# modified by [email protected]
# NOTE: fast_rope_embedding currently does not account for position ids
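# (position_ids is presumably accepted so this function can be patched in for callers that pass it)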
def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
return Q, K
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
# Local
from .model_patcher import ModelPatcher

PATCHES = [".models.llama", ".models.mistral"]
PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
PLUGIN_PREFIX = "fms_acceleration_foak"

# TODO: remove the need for the prefix
Original file line number Diff line number Diff line change
@@ -16,14 +16,24 @@
from functools import partial

# Third Party
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaMLP,
LlamaRMSNorm,
)

# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
from .utils import build_lora_fused_ops, trigger_fused_ops
from .model_patcher import (
ModelPatcher,
ModelPatcherRule,
ModelPatcherTrigger,
combine_functions,
combine_triggers,
)
from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops

# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
@@ -42,18 +52,54 @@
ModelPatcher.register(
ModelPatcherRule(
rule_id="llama-qkvo",
trigger=combine_triggers(
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["q_proj", "k_proj", "v_proj"],
)
),
ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
submodule_names=["o_proj"],
)
),
logic="OR",
),
forward_builder=combine_functions(
partial(
build_lora_fused_ops,
submodule_names=["q_proj", "k_proj", "v_proj"],
fused_op=KEY_QKV,
),
partial(
build_lora_fused_ops,
submodule_names=["o_proj"],
fused_op=KEY_O,
),
logic="APPEND",
),
forward_builder_args=["base_type"],
)
)

ModelPatcher.register(
ModelPatcherRule(
rule_id="llama-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
attn_cls=LlamaAttention,
qkv_module_names=["q_proj", "k_proj", "v_proj"],
o_module_name="o_proj",
attn_cls=LlamaMLP,
submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
qkv_module_names=["q_proj", "k_proj", "v_proj"],
o_module_name="o_proj",
submodule_names=["up_proj", "down_proj", "gate_proj"],
fused_op=KEY_MLP,
),
forward_builder_args=["base_type"],
)