diff --git a/README.md b/README.md
index a7534ed1..707c8662 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ Plugin | Description | Depends | License | Status
--|--|--|--|--
[framework](./plugins/framework/README.md) | The acceleration framework for integration with huggingface trainers | | | Beta
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Beta
-[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon
+[fused-op-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth)| Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton kernels and accelerations for Mixture-of-Expert models | | Apache 2.0 | Coming Soon
## Usage with FMS HF Tuning
@@ -174,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark
See below CSV files for various results:
- [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv).
-- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv).
### Code Architecture
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
index b8a7558d..913a6b7e 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
@@ -28,6 +28,7 @@
# consider making a map if patching more kernels
PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]
+
# This function may be moved after merging
# https://github.com/foundation-model-stack/fms-acceleration/pull/25
def _patch_target_module(
@@ -123,6 +124,7 @@ def create_new_module_peft(
# if module cannot be found, return None which results in a raise in the call-stack
return new_module
+
# consider moving this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
@@ -133,9 +135,9 @@ def patch_forward_to_view_attributes_before_call(
):
# patch old_forward to view attributes as torch_dtype
# before call
-
+
if submodule_names is None:
- submodule_names = ''
+ submodule_names = ""
if isinstance(submodule_names, str):
submodule_names = [submodule_names]
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index 30492a2b..7928d9a9 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -51,13 +51,18 @@ def __init__(self, configurations: Dict[str, Dict]):
def model_loader(self, model_name: str, **kwargs):
# guarded imports
# Third Party
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig #pylint: disable=import-outside-toplevel,import-error
- from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
+ from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
+ AutoGPTQForCausalLM,
+ BaseQuantizeConfig,
+ )
+ from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+ QuantLinear,
+ )
# Local
- from .autogptq_utils import ( #pylint: disable=import-outside-toplevel
- patch_forward_to_view_attributes_before_call,
- PATCH_FOR_FSDP_TRITON_V2
+ from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+ PATCH_FOR_FSDP_TRITON_V2,
+ patch_forward_to_view_attributes_before_call,
)
# Currently we allow only a quantized checkpoint to be loaded, we do not
@@ -214,8 +219,14 @@ def augmentation(
):
# guarded imports
# Third Party
- from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear #pylint: disable=import-outside-toplevel,import-error
- from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model #pylint: disable=import-outside-toplevel,import-error
+ from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+ QuantLinear,
+ )
+ from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
+ GPTQLoraModel,
+ get_gptq_peft_model,
+ )
+
# Local
from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
create_new_module_peft,
diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md
index a1b01d94..a1777671 100644
--- a/plugins/fused-ops-and-kernels/README.md
+++ b/plugins/fused-ops-and-kernels/README.md
@@ -3,7 +3,7 @@
This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:
-1. Fused operations and kernels are extracted from [unsloth](#extracted-code-from-unsloth).
+1. Fused operations and kernels extracted from [unsloth](#extracted-code-from-unsloth).
- Low-Rank Adapter Fused Operations
- Fast RoPE Triton Kernels
- Fast RMS LayerNorm Triton Kernels
@@ -13,42 +13,28 @@ This library contains fused operations and custom kernels, to be expanded over t
Plugin | Description | Depends | Loading | Augmentation | Callbacks
--|--|--|--|--|--
-[fast_quantized_peft](./src/fms_accelerate_foak/framework_plugin_fast_quantized_peft.py) | Loads fused lora, fast cross-entropy, fast rms, fast RoPE | | | ✅
+[fast_quantized_peft](./src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅ |
### Code Extracted from Unsloth
-
Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth):
-- while unsloth is released under Apache 2.0, there are [exceptions to the permissive licenses scattered in the code base](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
+- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base; see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
```
- it would require a commercial license if used to run on more than 4 GPUs, see
- https://github.com/unslothai/unsloth/blob/d215fd902cf28feb8abcfde2d25281d0fbf9d28c/unsloth/models/llama.py#L1140-L1143
+ it would require a commercial license if used to run on more than 4 GPUs ...
```
-- these exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55), around the model files (namely `llama.py`, `mistral.py`, etc).
- * These model files are **not extracted**.
-- All code extracted here before the Feb 2024 Release, see table below.
+- These exceptions appear to be located around the trainer improvements; see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183).
+- These exceptions appear around the [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
+- In their place we adopt a model patching approach, as opposed to unsloth's approach of rewriting the model; our approach is novel and **completely rewritten from scratch** (see the sketch after the table below).
+- All extracted code appears before the Feb 2024 Release.
+- In the table below we record what was extracted, and the exact commit from which it was taken.
Path | Description | Extracted From | Modifications | Date
--|--|--|--|--
[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
-[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
+[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
-[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py` | 28 Jan 2024
-
-
-
+[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024
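+
+To illustrate the model patching approach mentioned above, the extracted kernels are attached to models by registering rules with the `ModelPatcher`. Below is a minimal sketch based on the registrations in [models/mixtral.py](./src/fms_acceleration_foak/models/mixtral.py); the exact import paths and trigger arguments may differ:
+
+```python
+from transformers.models.mixtral.modeling_mixtral import MixtralRMSNorm
+
+from fms_acceleration_foak.kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from fms_acceleration_foak.models.model_patcher import (
+    ModelPatcher,
+    ModelPatcherRule,
+    ModelPatcherTrigger,
+)
+
+# replace the forward of every MixtralRMSNorm module with the fast triton kernel
+ModelPatcher.register(
+    ModelPatcherRule(
+        rule_id="mixtral-rms",
+        trigger=ModelPatcherTrigger(check=MixtralRMSNorm),
+        forward=fast_rms_layernorm,
+    ),
+)
+```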
## Known Issues
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
index ad0a399c..7eab87f0 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -16,6 +16,7 @@
from typing import Callable, Dict, Tuple
# Third Party
+from accelerate.utils import set_module_tensor_to_device
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig
from peft.tuners.lora.layer import LoraLayer
@@ -63,9 +64,20 @@ def _all_reduce_hook(grad):
return grad
for mod in modules:
+ # NOTE: assuming lora has no bias
+ A = mod.lora_A.default
+ B = mod.lora_B.default
+
# install hooks on the adapters
- mod.lora_A.default.weight.register_hook(_all_reduce_hook)
- mod.lora_B.default.weight.register_hook(_all_reduce_hook)
+ A.weight.register_hook(_all_reduce_hook)
+ B.weight.register_hook(_all_reduce_hook)
+
+ # because we will ignore these from FSDP, we need to manually
+ # move them to the gpu if they are not already there
+ if not A.weight.is_cuda:
+ set_module_tensor_to_device(A, "weight", "cuda")
+ if not B.weight.is_cuda:
+ set_module_tensor_to_device(B, "weight", "cuda")
class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
@@ -82,10 +94,7 @@ def __init__(self, configurations: Dict[str, Dict]):
self._base_layer = self._check_config_and_maybe_check_values(
key="peft.quantization.fused_ops_and_kernels.base_layer",
- values=[
- "auto_gptq",
- # "bitsandbytes" # enable later when we have BNB implemented
- ],
+ values=["auto_gptq", "bitsandbytes"],
)
# only support these at the moment
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
index 82f78f74..71d7070c 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
@@ -394,3 +394,10 @@ def apply_lora_o(self, X):
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
return O
pass
+
+# added by flim@sg.ibm.com
+# this will be patchable on the actual module
+def apply_lora_o_v2(self, X):
+ OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
+ O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
+ return O
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
index 3808fba7..ee5055ed 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
@@ -735,3 +735,10 @@ def apply_lora_o(self, X):
Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj)
O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
return O
+
+# added by flim@sg.ibm.com
+# this version can be directly patched on the output linear
+def apply_lora_o_v2(self, X):
+ Oqstate, OA, OB, OS = get_lora_parameters(self)
+ O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
+ return O
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
index 49b04fce..3577b586 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
@@ -130,8 +130,9 @@ def backward(ctx, dY):
pass
pass
-
-def fast_rope_embedding(Q, K, cos, sin):
+# modified by flim@sg.ibm.com
+# NOTE: fast_rope_embedding currently does not account for position_ids
+def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
return Q, K
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py
index 7d6df3bc..ebd49924 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py
@@ -15,7 +15,7 @@
# Local
from .model_patcher import ModelPatcher
-PATCHES = [".models.llama", ".models.mistral"]
+PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
PLUGIN_PREFIX = "fms_acceleration_foak"
# TODO: remove the need for the prefix
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
index 3d01311a..290d1217 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
@@ -16,14 +16,24 @@
from functools import partial
# Third Party
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaMLP,
+ LlamaRMSNorm,
+)
# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
from ..kernels.unsloth.rope_embedding import fast_rope_embedding
-from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
-from .utils import build_lora_fused_ops, trigger_fused_ops
+from .model_patcher import (
+ ModelPatcher,
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
# TODO: have a generic version of this rule
# - do regex on RMSNorm class name
@@ -42,18 +52,54 @@
ModelPatcher.register(
ModelPatcherRule(
rule_id="llama-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=LlamaAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=LlamaAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ ),
+ logic="APPEND",
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="llama-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
- attn_cls=LlamaAttention,
- qkv_module_names=["q_proj", "k_proj", "v_proj"],
- o_module_name="o_proj",
+ attn_cls=LlamaMLP,
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
- qkv_module_names=["q_proj", "k_proj", "v_proj"],
- o_module_name="o_proj",
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
+ fused_op=KEY_MLP,
),
forward_builder_args=["base_type"],
)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
index a8e6795f..37809fd1 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
@@ -18,22 +18,22 @@
# Third Party
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
+ MistralMLP,
MistralRMSNorm,
)
# Local
from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
-from ..kernels.unsloth.rope_embedding import fast_rope_embedding as _fast_rope_embedding
-from .model_patcher import ModelPatcher, ModelPatcherRule, ModelPatcherTrigger
-from .utils import build_lora_fused_ops, trigger_fused_ops
-
-
-# NOTE: fast_rope_embedding does not work with position_ids
-# currently they are ignored
-def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
- return _fast_rope_embedding(Q, K, cos, sin)
-
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .model_patcher import (
+ ModelPatcher,
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
# - do regex on RMSNorm class name
# - check on the tensors required for fast_rms_layernorm
@@ -45,29 +45,62 @@ def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
),
)
-# - do regex on Attention class name
-# - have a set of qkv / o module names and check on that
ModelPatcher.register(
ModelPatcherRule(
rule_id="mistral-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MistralAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MistralAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ ),
+ logic="APPEND",
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mistral-mlp",
trigger=ModelPatcherTrigger(
check=partial(
trigger_fused_ops,
- attn_cls=MistralAttention,
- qkv_module_names=["q_proj", "k_proj", "v_proj"],
- o_module_name="o_proj",
+ attn_cls=MistralMLP,
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
)
),
forward_builder=partial(
build_lora_fused_ops,
- qkv_module_names=["q_proj", "k_proj", "v_proj"],
- o_module_name="o_proj",
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
+ fused_op=KEY_MLP,
),
forward_builder_args=["base_type"],
)
)
-# - get the module_name and reload on that
ModelPatcher.register(
ModelPatcherRule(
rule_id="mistral-cross-ent",
@@ -79,9 +112,6 @@ def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
)
)
-# - get the module name
-# - check if "apply_rotary_pos_emb" exists
-# - patch
ModelPatcher.register(
ModelPatcherRule(
rule_id="mistral-rope",
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
new file mode 100644
index 00000000..1522ef8d
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
@@ -0,0 +1,104 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from transformers.models.mixtral.modeling_mixtral import (
+ MixtralAttention,
+ MixtralRMSNorm,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .model_patcher import (
+ ModelPatcher,
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+# - do regex on RMSNorm class name
+# - check on the tensors required for fast_rms_layernorm
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-rms",
+ trigger=ModelPatcherTrigger(check=MixtralRMSNorm),
+ forward=fast_rms_layernorm,
+ ),
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MixtralAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MixtralAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ ),
+ logic="APPEND",
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-cross-ent",
+ import_and_maybe_reload=(
+ "torch.nn.CrossEntropyLoss",
+ FastCrossEntropyLoss,
+ "transformers.models.mixtral.modeling_mixtral",
+ ),
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-rope",
+ import_and_maybe_reload=(
+ "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb",
+ fast_rope_embedding,
+ None,
+ ),
+ )
+)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py
index 3355aa67..7f803330 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py
@@ -468,3 +468,28 @@ def patch_model(model: torch.nn.Module, **kwargs):
def patch_model_summary():
return ModelPatcher.summary()
+
+
+def combine_triggers(*triggers: ModelPatcherTrigger, logic: str = "OR"):
+ assert logic == "OR", "only OR logic implemented for combining triggers"
+
+ # NOTE: this can probably be simplified
+ def _or_logic(*args, **kwargs):
+ for trig in triggers:
+ if trig.check(*args, **kwargs):
+ return True
+ return False
+
+ return ModelPatcherTrigger(check=_or_logic)
+
+
+def combine_functions(*funcs: Callable, logic: str = "APPEND"):
+ assert logic == "APPEND", "only APPEND logic implemented for combining functions"
+
+ def _append(*args, **kwargs):
+ results = []
+ for f in funcs:
+ results += f(*args, **kwargs)
+ return results
+
+ return _append
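+
+
+# Illustrative usage (placeholder names; mirrors how the *-qkvo rules in
+# models/llama.py, models/mistral.py and models/mixtral.py are registered):
+#
+#   trigger = combine_triggers(trigger_qkv, trigger_o, logic="OR")
+#   forward_builder = combine_functions(build_qkv, build_o, logic="APPEND")
+#
+# combine_triggers fires when any constituent trigger matches, while
+# combine_functions concatenates the lists returned by each builder.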
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
index b048b8e4..10819fc0 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
@@ -1,34 +1,40 @@
-# Copyright The FMS HF Tuning Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
# Standard
+from functools import partial
from typing import Callable, List, Type
+import os
# Third Party
import torch
-import os
# Local
-# GPTQ imports
-from ..fused_ops.unsloth_lora.gptq.fast_lora import LoRA_W as LoRA_W_gptq
-from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq
-from ..fused_ops.unsloth_lora.gptq.fast_lora import (
- get_lora_parameters as get_lora_parameters_gptq,
+# NOTE: the default activation is swiglu in both cases
+from ..fused_ops.unsloth_lora.bnb.fast_lora import (
+ apply_lora_mlp_swiglu as fused_op_mlp_bnb,
)
-from ..fused_ops.unsloth_lora.gptq.fast_lora import unpack_gptqstate
+from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_o_v2 as fused_op_o_bnb
+from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_qkv as fused_op_qkv_bnb
+from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq
+from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq
+from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq
from .model_patcher import ModelPatcherTrigger
+KEY_QKV = "qkv"
+KEY_O = "o"
+KEY_MLP = "mlp"
+
+FUSED_OPS = {
+ "auto_gptq": {
+ KEY_QKV: fused_op_qkv_gptq,
+ KEY_O: fused_op_o_gptq,
+ KEY_MLP: fused_op_mlp_gptq,
+ },
+ "bitsandbytes": {
+ KEY_QKV: fused_op_qkv_bnb,
+ KEY_O: fused_op_o_bnb,
+ KEY_MLP: fused_op_mlp_bnb,
+ },
+}
+
# simple utility function to guess if its lora layer
def _is_loralayer(module: torch.nn.Module, names: List[str] = None):
@@ -45,15 +51,15 @@ def _is_loralayer(module: torch.nn.Module, names: List[str] = None):
# modules are called q_proj, k_proj, and v_proj, respectively.
# the fused operation can be changed, depending on what the base layer is
# i.e. gptq or bnb
-def _build_qkv_forwards(
+def _build_fused_forwards(
attn: torch.nn.Module,
fused_operation: Callable = fused_op_qkv_gptq,
- module_names: List[str] = None,
+ submodule_names: List[str] = None,
):
- if module_names is None:
- module_names = ["q_proj", "k_proj", "v_proj"]
+ # fused ops are expected to produce either a single result or multiple results
+ # module names must be passed in the same order as the fused operation's outputs
- Q = K = V = None
+ outs = {}
# the fused operation will be called on first one that passes in the
# input X.
@@ -61,62 +67,52 @@ def _build_qkv_forwards(
# - subsequent calls will be a no-op until ALL Q, K, V get reset to None
def _fused_op(X):
- nonlocal Q, K, V
- if Q is None and K is None and V is None:
- Q, K, V = fused_operation(attn, X)
+ # if all of the outs are not yet populated
+ if all(x not in outs for x in submodule_names):
+ fused_outs = fused_operation(attn, X)
+ try:
+ fused_outs = list(fused_outs) # not sure if this is correct
+ except TypeError:
+ # if fused_outs is not iterable
+ fused_outs = [fused_outs]
+ for n, x in zip(submodule_names, fused_outs):
+ outs[n] = x
# each of these functions
# - calls the fused op
# - returns (and clears) the output cached for its own submodule
- error_msg = (
- "QKV fused_op needs to be first reset with sequential calls to each of them"
- )
-
- def _forward_q(self, X):
- nonlocal Q
- _fused_op(X)
- assert Q is not None, error_msg
- out, Q = Q, None # unload
- return out
-
- def _forward_k(self, X):
- nonlocal K
- _fused_op(X)
- assert K is not None, error_msg
- out, K = K, None # unload
- return out
- def _forward_v(self, X):
- nonlocal V
+ def _forward(self, X, name: str):
_fused_op(X)
- assert V is not None, error_msg
- out, V = V, None # unload
- return out
-
- return zip(module_names, [_forward_q, _forward_k, _forward_v])
-
+ assert (
+ name in outs
+ ), "Fused op must first be reset by sequential calls to each of the patched forwards"
+ V = outs[name]
+ del outs[name]
+ return V
-# fused ops for outputs for GPTQ
-def fused_op_o_gptq(self, X):
- Oqstate, OA, OB, OS = get_lora_parameters_gptq(self)
- O = LoRA_W_gptq.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
- return O
+ return zip(submodule_names, [partial(_forward, name=n) for n in submodule_names])
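+
+# NOTE: _build_fused_forwards returns one patched forward per submodule name
+# (e.g. q_proj, k_proj, v_proj). The first forward to be called runs
+# fused_operation(attn, X) once and caches all outputs in `outs`; each forward
+# then pops and returns only its own entry, so the fused operation executes
+# exactly once per set of sequential calls.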
-# TODO: add the MLP
def build_lora_fused_ops(
attn: torch.nn.Module,
base_type: str = "auto_gptq",
- qkv_module_names: List[str] = None,
- o_module_name: str = "o_proj",
+ submodule_names: List[str] = None,
+ fused_op: str = KEY_QKV,
):
- if qkv_module_names is None:
- qkv_module_names = ["q_proj", "k_proj", "v_proj"]
- # handle the QKVs
+ if submodule_names is None:
+ submodule_names = ["q_proj", "k_proj", "v_proj"]
+
+ assert (
+ len(submodule_names) > 0
+ ), "Building lora fused ops requires at least one submodule name."
+
+ # get the fused op
+ fused_operation = FUSED_OPS[base_type][fused_op]
+
+ # handle casting issues
if base_type == "auto_gptq":
- _qkv_fused_op = fused_op_qkv_gptq
- _o_fused_op = fused_op_o_gptq
# this is required due to this FSDP fix
# https://github.com/foundation-model-stack/fms-acceleration/pull/15
@@ -131,55 +127,60 @@ def build_lora_fused_ops(
):
# guarded import
- from fms_acceleration_peft.autogptq_utils import ( #pylint: disable=import-outside-toplevel
- patch_forward_to_view_attributes_before_call,
- PATCH_FOR_FSDP_TRITON_V2
+ # pylint: disable=import-outside-toplevel,import-error
+ # Third Party
+ from fms_acceleration_peft.autogptq_utils import (
+ PATCH_FOR_FSDP_TRITON_V2,
+ patch_forward_to_view_attributes_before_call,
)
# patch each of the fused ops to view the attributes
# back into torch.int32
- # TODO: add the MLP fused op also
- _qkv_fused_op = patch_forward_to_view_attributes_before_call(
- _qkv_fused_op,
- PATCH_FOR_FSDP_TRITON_V2, torch.int32,
- submodule_names=[
- n + '.base_layer' for n in qkv_module_names
- ],
- is_method_forward=False,
- )
- _o_fused_op = patch_forward_to_view_attributes_before_call(
- _o_fused_op,
- PATCH_FOR_FSDP_TRITON_V2, torch.int32,
- submodule_names='base_layer',
+ # - if there are multiple submodules, then we assume that
+ # 'fused_operation' will be called on the module that has
+ # the submodules specified in 'submodule_names'.
+ # - otherwise, if there is only a single 'submodule_name', then
+ # we assume that 'fused_operation' is called on the submodule
+ # specified by 'submodule_name' itself
+ if len(submodule_names) > 1:
+ patched_submodule_names = [n + ".base_layer" for n in submodule_names]
+ else:
+ # otherwise assume calling on the 'submodule_name' itself
+ # so its just the base layer.
+ patched_submodule_names = "base_layer"
+
+ fused_operation = patch_forward_to_view_attributes_before_call(
+ fused_operation,
+ PATCH_FOR_FSDP_TRITON_V2,
+ torch.int32,
+ submodule_names=patched_submodule_names,
is_method_forward=False,
)
- else:
- raise NotImplementedError(
- f"Cannot build fused ops for base type '{base_type}'."
- )
-
- trigger_and_forwards = [
- (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward)
- for name, forward in _build_qkv_forwards(
- attn,
- fused_operation=_qkv_fused_op,
- module_names=qkv_module_names,
- )
- ]
-
- # handle the self-attn output
- _output_module = getattr(attn, o_module_name)
- if _is_loralayer(_output_module):
- trigger_and_forwards.append(
+ if fused_op == KEY_QKV:
+ return [
+ (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward)
+ for name, forward in _build_fused_forwards(
+ attn,
+ fused_operation=fused_operation,
+ submodule_names=submodule_names,
+ )
+ ]
+ if fused_op == KEY_O:
+ # otherwise its just a single op
+ submodule_names = submodule_names[0]
+ return [
(
- ModelPatcherTrigger(check=_is_loralayer, module_name=o_module_name),
- _o_fused_op,
+ ModelPatcherTrigger(check=_is_loralayer, module_name=submodule_names),
+ fused_operation,
)
- )
+ ]
+ if fused_op == KEY_MLP:
+ # otherwise just return the fused_op that should be attached at the
+ # top MLP level
+ return fused_operation
- # return
- return trigger_and_forwards
+ raise NotImplementedError(f"Unknown fused op '{fused_op}'")
# trigger if either of the conditions are met
@@ -188,16 +189,10 @@ def build_lora_fused_ops(
def trigger_fused_ops(
module: torch.nn.Module,
attn_cls: Type,
- qkv_module_names: List[str] = None,
- o_module_name: str = "o_proj",
+ submodule_names: List[str],
):
- if qkv_module_names is None:
- qkv_module_names = ["q_proj", "k_proj", "v_proj"]
-
- _o = getattr(module, o_module_name)
- _qkv = [getattr(module, x) for x in qkv_module_names]
- # trigger on the attention layer
- return isinstance(module, attn_cls) and (
- all(_is_loralayer(x) for x in _qkv) or _is_loralayer(_o)
- )
+ # trigger if the module is an instance of the attn class and the
+ # submodules are all loralayers
+ _mods = [getattr(module, x) for x in submodule_names]
+ return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods)
diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml
index c43a5adf..75f7279b 100644
--- a/sample-configurations/CONTENTS.yaml
+++ b/sample-configurations/CONTENTS.yaml
@@ -25,4 +25,10 @@ framework_configs:
plugins:
- accelerated-peft
- fused-ops-and-kernels
- filename: accelerated-peft-autogptq-foak-sample-configuration.yaml
\ No newline at end of file
+ filename: accelerated-peft-autogptq-foak-sample-configuration.yaml
+
+ - shortname: accelerated-peft-bnb-foak
+ plugins:
+ - accelerated-peft
+ - fused-ops-and-kernels
+ filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
\ No newline at end of file
diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
new file mode 100644
index 00000000..fcb9bb14
--- /dev/null
+++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
@@ -0,0 +1,44 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # PEFT-related acceleration
+ peft:
+
+ # quantization-releated acceleration
+ # e.g., kernels for quantized base weights
+ quantization:
+
+ # For loading BitsAndBytes quantized layers
+ # to serve as 4bit base-weights for LoRA PEFT-tuning.
+ # NOTE: currently AutoGPTQ is not properly integrated into huggingface /
+ # bitsandbytes, thus the recommended quant_type is either "nf4"
+ # or "fp4".
+ # bitsandbytes:
+ bitsandbytes:
+ quant_type: nf4
+
+ # If True, then no get_peft_model and prepare_model_for_kbit_training
+ # will be called.
+ no_peft_model: false
+ fused_ops_and_kernels:
+
+ # load unsloth optimizations for these 4bit base layer weights.
+ # currently only support "auto_gptq" and "bitsandbytes"
+ base_layer: bitsandbytes
+
+ # activate various unsloth optimizations
+ # NOTE: currently supports only all-or-nothing.
+
+ # fused kernels for lora linear layers
+ fused_lora: true
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rsm_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py
index 651b227a..f5ff4a54 100644
--- a/scripts/benchmarks/benchmark.py
+++ b/scripts/benchmarks/benchmark.py
@@ -1,5 +1,6 @@
# Standard
from itertools import product
+from time import sleep
from typing import Any, Callable, Dict, List, Tuple, Union
import argparse
import json
@@ -88,6 +89,7 @@
HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics"
RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes"
RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes"
+ERROR_MESSAGES = "error_messages"
def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]:
@@ -357,6 +359,17 @@ def __init__(
self.results_filename = os.path.join(self.save_dir, FILE_RESULTS)
self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM)
+ @property
+ def is_completed(self):
+ if not os.path.exists(self.results_filename):
+ return False
+ # otherwise open it and check for errors
+ with open(self.results_filename) as f:
+ results = json.load(f)
+
+ # return complete only if no errors
+ return ERROR_MESSAGES not in results
+
def run(
self,
run_cmd: str,
@@ -552,7 +565,7 @@ def write_result(self):
**self.get_experiment_final_metrics(),
}
else:
- other_results = {"error_messages": maybe_error_messages}
+ other_results = {ERROR_MESSAGES: maybe_error_messages}
# combine the final thing
save_result = {**save_result, **other_results}
@@ -781,6 +794,14 @@ def main(args):
log_memory_in_trainer=args.log_memory_hf,
)
):
+ # store pointer to file for future result retrieval
+ experiment_stats[experiment.tag] = experiment.results_filename
+
+ if experiment.is_completed:
+ # if completed, don't proceed
+ sleep(0.1) # sleep a bit to allow the tqdm to update
+ continue
+
if experiment.num_gpus > 1:
prefix = COMMAND_ACCELERATE.format(
accelerate_config_path=args.accelerate_config,
@@ -806,10 +827,9 @@ def main(args):
log_nvidia_smi=args.log_nvidia_smi,
)
- # write results and store pointers to files
+ # write results
experiment.write_result()
experiment.write_shell_command()
- experiment_stats[experiment.tag] = experiment.results_filename
# 4. Consolidates the experiment results into a summary
for tag, path in experiment_stats.items():
diff --git a/scripts/benchmarks/display_bench_results.py b/scripts/benchmarks/display_bench_results.py
index 1de9b2a5..51ba5642 100644
--- a/scripts/benchmarks/display_bench_results.py
+++ b/scripts/benchmarks/display_bench_results.py
@@ -1,18 +1,21 @@
# Standard
+from typing import List
import argparse
# First Party
# import this because of a lot of internal constants
-from scripts.benchmarks.benchmark import gather_report, DIR_SAMP_CONFIGS
-from typing import List
+from scripts.benchmarks.benchmark import DIR_SAMP_CONFIGS, gather_report
-def main(*directories: str, output_filename: str = "results.csv", remove_columns: List[str] = None):
+
+def main(
+ *directories: str,
+ output_filename: str = "results.csv",
+ remove_columns: List[str] = None,
+ keep_columns: List[str] = None,
+):
"gather outputs from a list of directories and output to a csv"
- df, constant = gather_report(*directories, raw=False)
- # filter result columns to keep by the inverse of remove_columns
- if remove_columns:
- df = df[df.columns[~df.columns.isin(remove_columns)]]
+ df, constant = gather_report(directories, raw=False)
errors = []
try:
@@ -22,12 +25,25 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns
df = df.loc[df.error_messages.isna()]
except:
pass
+
+ # filter result columns to keep by the inverse of remove_columns
+ if remove_columns:
+ df = df[df.columns[~df.columns.isin(remove_columns)]]
+
+ # assume keep and remove are disjoint
+ kept = 0
+ if keep_columns:
+ for c in keep_columns:
+ if c in constant:
+ df[c] = constant[c]
+ kept += 1
+
df = df.reset_index(drop=True).drop("output_dir", axis=1)
df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False)
print("***************** Report Created ******************")
print(f"Total lines: '{len(df)}'")
print(f"Number columns included: '{len(df.columns)}'")
- print(f"Number columns excluded: '{len(constant)}'")
+ print(f"Number columns excluded: '{len(constant)-kept}'")
print(f"Excluding number of exceptions caught: '{len(errors)}'")
print(f"Written report to '{output_filename}'")
@@ -53,10 +69,16 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns
nargs="*",
help="list of columns to ignore from results.csv",
)
+ parser.add_argument(
+ "--keep_columns",
+ nargs="*",
+ help="list of columns to always include into results.csv",
+ )
args = parser.parse_args()
main(
- args.bench_outputs,
+ *args.bench_outputs,
output_filename=args.result_file,
remove_columns=args.remove_columns,
+ keep_columns=args.keep_columns,
)
diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv
index 4434d864..b83549a7 100644
--- a/scripts/benchmarks/refs/a100_80gb.csv
+++ b/scripts/benchmarks/refs/a100_80gb.csv
@@ -1,61 +1,82 @@
-epoch,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,nvidia_mem_reserved,peak_torch_mem_alloc_in_bytes,peft_method,per_device_train_batch_size,r,target_modules,torch_mem_alloc_in_bytes,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
-0.04,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,77705.0,72971724288.0,,4,,,44004763136.0,0.9278398831685384,177.1092,0.678,0.169,2775.237
-0.04,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,44706.0,36762859520.0,,2,,,29521119232.0,0.8970902442932129,91.086,1.317,0.329,2698.11
-0.09,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,74383.0,72972117504.0,,8,,,44005156352.0,0.9879656155904134,322.458,0.744,0.093,3048.583
-0.09,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,53907.0,36763056128.0,,4,,,29521315840.0,0.9259945551554362,167.7727,1.431,0.179,2929.678
-,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,4,,,,,,,,
-,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79353.0,,,2,,,,,,,,
-,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,8,,,,,,,,
-,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79827.0,,,4,,,,,,,,
-,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,4,,,,,,,,
-,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,80830.0,,,2,,,,,,,,
-,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,8,,,,,,,,
-,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,80834.5,,,4,,,,,,,,
-0.04,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,29731.0,26108963328.0,lora,4,16,q_proj k_proj v_proj o_proj,15119590912.0,0.9096682230631511,136.624,0.878,0.22,3597.611
-0.04,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,18697.0,15123161088.0,lora,2,16,q_proj k_proj v_proj o_proj,7850391552.0,0.8918854713439941,82.0311,1.463,0.366,2995.936
-0.09,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,43195.0,37098695168.0,lora,8,16,q_proj k_proj v_proj o_proj,15119984128.0,0.962119706471761,270.6301,0.887,0.111,3632.412
-0.09,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,26235.0,21433753600.0,lora,4,16,q_proj k_proj v_proj o_proj,7850588160.0,0.9218235015869141,143.8184,1.669,0.209,3417.643
-,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,62617.0,57540387840.0,lora,2,16,q_proj k_proj v_proj o_proj,47311452160.0,0.9361546834309896,179.3128,0.669,0.167,1370.566
-,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-0.09,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,69848.0,64347637760.0,lora,4,16,q_proj k_proj v_proj o_proj,47311648768.0,0.9383139928181966,280.8919,0.854,0.107,1749.855
-,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80894.0,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,
-,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80979.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,baseline-peft-bnb,24,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,27023.0,22825932800.0,lora,4,16,q_proj k_proj v_proj o_proj,5368221184.0,0.9589527130126954,178.8061,0.671,0.168,2748.9
-0.04,True,baseline-peft-bnb,25,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13530.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9154380798339844,87.3652,1.374,0.343,2813.02
-0.09,True,baseline-peft-bnb,26,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,47145.0,40278956032.0,lora,8,16,q_proj k_proj v_proj o_proj,5368614400.0,0.9702634493509928,341.2286,0.703,0.088,2880.884
-0.09,True,baseline-peft-bnb,27,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21502.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.914565912882487,149.9341,1.601,0.2,3278.241
-0.04,True,baseline-peft-bnb,28,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,48313.0,46419968512.0,lora,4,16,q_proj k_proj v_proj o_proj,25726225920.0,0.9744932492574055,351.8623,0.341,0.085,1396.91
-0.04,True,baseline-peft-bnb,29,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25549.0,21922782720.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9303209940592448,171.4299,0.7,0.175,1433.589
-0.09,True,baseline-peft-bnb,30,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,69931.0,67089150464.0,lora,8,16,q_proj k_proj v_proj o_proj,25726619136.0,0.9745417594909668,629.837,0.381,0.048,1560.785
-0.09,True,baseline-peft-bnb,31,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29384115200.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9310146331787109,300.5119,0.799,0.1,1635.609
-,True,baseline-peft-bnb,32,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80893.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,baseline-peft-bnb,33,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52634.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.0399916648864747,584.3145,0.205,0.051,420.595
-,True,baseline-peft-bnb,34,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,79557.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-,True,baseline-peft-bnb,35,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80749.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,accelerated-peft-bnb,36,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,19931.0,15860019712.0,lora,4,16,q_proj k_proj v_proj o_proj,4843384320.0,0.9652111371358235,143.3569,0.837,0.209,3428.645
-0.04,True,accelerated-peft-bnb,37,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13497.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9277165730794271,86.4307,1.388,0.347,2843.435
-0.09,True,accelerated-peft-bnb,38,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,34355.0,26849751552.0,lora,8,16,q_proj k_proj v_proj o_proj,4843777536.0,0.9493892669677735,279.7156,0.858,0.107,3514.427
-0.09,True,accelerated-peft-bnb,39,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21479.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.9110882759094239,149.3914,1.607,0.201,3290.15
-0.04,True,accelerated-peft-bnb,40,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,38405.0,36218024448.0,lora,4,16,q_proj k_proj v_proj o_proj,25201389056.0,0.9741149584452311,278.5888,0.431,0.108,1764.32
-0.04,True,accelerated-peft-bnb,41,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25592.0,21906697728.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9300654411315918,172.7359,0.695,0.174,1422.75
-0.09,True,accelerated-peft-bnb,42,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,50875.0,47207756288.0,lora,8,16,q_proj k_proj v_proj o_proj,25201782272.0,0.9748441060384114,512.2298,0.469,0.059,1919.139
-0.09,True,accelerated-peft-bnb,43,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29369087488.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9301350593566895,287.6381,0.834,0.104,1708.814
-0.04,True,accelerated-peft-bnb,44,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,72829.0,68159977472.0,lora,4,16,q_proj k_proj v_proj o_proj,37346815488.0,1.118430455525716,1075.2044,0.112,0.028,457.141
-0.04,True,accelerated-peft-bnb,45,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52632.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.040946865081787,586.651,0.205,0.051,418.92
-,True,accelerated-peft-bnb,46,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80405.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-,True,accelerated-peft-bnb,47,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80954.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,accelerated-peft-autogptq,48,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,20453.0,15890329088.0,lora,4,16,q_proj k_proj v_proj o_proj,4873693696.0,1.3805528958638509,151.0359,0.795,0.199,3254.326
-0.04,True,accelerated-peft-autogptq,49,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,17198.0,9952175616.0,lora,2,16,q_proj k_proj v_proj o_proj,3005709312.0,1.1706618309020995,87.4109,1.373,0.343,2811.548
-0.09,True,accelerated-peft-autogptq,50,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,34247.0,26880060928.0,lora,8,16,q_proj k_proj v_proj o_proj,4874086912.0,1.2741642634073893,282.6391,0.849,0.106,3478.076
-0.09,True,accelerated-peft-autogptq,51,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,24783.0,16262768128.0,lora,4,16,q_proj k_proj v_proj o_proj,3005905920.0,1.043952751159668,152.5473,1.573,0.197,3222.083
-0.04,True,accelerated-peft-autogptq,52,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,37461.0,35528093184.0,lora,4,16,q_proj k_proj v_proj o_proj,24511457792.0,0.9936613400777181,263.6066,0.455,0.114,1864.597
-0.04,True,accelerated-peft-autogptq,53,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,46641.0,25708175360.0,lora,2,16,q_proj k_proj v_proj o_proj,12788874240.0,0.9420519828796386,167.065,0.718,0.18,1471.045
-0.09,True,accelerated-peft-autogptq,54,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,49925.0,46517825024.0,lora,8,16,q_proj k_proj v_proj o_proj,24511851008.0,0.9855653127034505,498.9022,0.481,0.06,1970.406
-0.09,True,accelerated-peft-autogptq,55,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,52358.0,27739090432.0,lora,4,16,q_proj k_proj v_proj o_proj,12789070848.0,0.9389812151590983,281.8034,0.852,0.106,1744.195
-0.04,True,accelerated-peft-autogptq,56,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,71565.0,65895347200.0,lora,4,16,q_proj k_proj v_proj o_proj,36290144768.0,1.0755928039550782,1060.8387,0.113,0.028,463.331
-0.04,True,accelerated-peft-autogptq,57,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80387.0,45397678592.0,lora,2,16,q_proj k_proj v_proj o_proj,18649885696.0,1.0256956418355305,576.0422,0.208,0.052,426.635
-,True,accelerated-peft-autogptq,58,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,80293.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-0.08,True,accelerated-peft-autogptq,59,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80363.0,70667573760.0,lora,4,16,q_proj k_proj v_proj o_proj,18650082304.0,1.0266701062520345,1089.3291,0.22,0.028,451.214
+epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,25995.0,22825932800,5368221184,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8676117706298828,584.6749,0.684,0.171,2802.241
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,12512.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8593511199951172,279.9917,1.429,0.357,2925.801
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,46117.0,40278956032,5368614400,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.86837890625,1149.6017,0.696,0.087,2850.378
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,20435.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8526134586334229,496.2449,1.612,0.202,3301.596
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,47079.0,46427906560,25726225920,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8966263771057129,1169.4078,0.342,0.086,1401.051
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,24609.0,21937980416,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8650046825408936,564.3075,0.709,0.177,1451.691
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,68071.0,67121147392,25726619136,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8866284656524658,2118.0176,0.378,0.047,1547.107
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,32054.0,29375012352,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8636721038818359,959.452,0.834,0.104,1707.641
+,True,baseline-peft-bnb,2e-4,16,0.0,80631.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.14,True,baseline-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9462522315979004,1951.2462,0.205,0.051,419.834
+,True,baseline-peft-bnb,2e-4,16,0.0,79555.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,baseline-peft-bnb,2e-4,16,0.0,80801.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.935322732925415,3737.7987,0.214,0.027,438.333
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,18903.0,15860019712,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8679532146453858,480.1165,0.833,0.208,3412.505
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,12477.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8598325538635254,281.0553,1.423,0.356,2914.729
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,33327.0,26849751552,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8708646774291993,944.515,0.847,0.106,3469.294
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,20417.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8568318557739257,498.8375,1.604,0.2,3284.436
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,37321.0,36218024448,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8979199028015137,923.4329,0.433,0.108,1774.249
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,24783.0,21940224000,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8649028778076172,564.1011,0.709,0.177,1452.222
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,49847.0,47207756288,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8877867794036866,1717.1699,0.466,0.058,1908.256
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,31907.0,29336790016,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8623861598968506,952.2959,0.84,0.105,1720.474
+0.14,True,accelerated-peft-bnb,2e-4,16,0.0,71801.0,68159977472,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.999151840209961,3662.4376,0.109,0.027,447.352
+0.14,True,accelerated-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9392572689056397,1950.7659,0.205,0.051,419.938
+,True,accelerated-peft-bnb,2e-4,16,0.0,79375.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-bnb,2e-4,16,0.0,80866.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9258937835693359,3744.4001,0.214,0.027,437.56
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,19425.0,15890329088,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0217428588867188,477.2159,0.838,0.21,3433.247
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,12056.0,9690031616,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9701251029968262,278.7874,1.435,0.359,2938.44
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,33219.0,26880060928,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9569056987762451,941.1761,0.85,0.106,3481.601
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,19530.0,16000624128,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9303163433074951,494.3287,1.618,0.202,3314.394
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19065.0,13631990784,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9736110210418701,411.3906,0.972,0.243,3982.589
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,11506.0,9174099456,2405399552,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0141907215118409,248.8178,1.608,0.402,3292.368
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,32721.0,22390647808,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9668986797332764,809.2016,0.989,0.124,4049.424
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,18635.0,15282316800,2405596160,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.942121753692627,444.2322,1.801,0.225,3688.162
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,36435.0,35528093184,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9004004192352295,879.8344,0.455,0.114,1862.169
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,22962.5,20697435648,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8698519325256348,537.8597,0.744,0.186,1523.074
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,48941.0,46517825024,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8974114608764648,1669.3163,0.479,0.06,1962.959
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,29756.0,27484941824,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8667408466339112,924.2282,0.866,0.108,1772.722
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,36613.0,33671981056,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9003233146667481,814.7613,0.491,0.123,2010.896
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,22421.0,20108989952,12191160320,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.867002067565918,506.3203,0.79,0.198,1617.948
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49691.0,42742948864,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.897435302734375,1534.4874,0.521,0.065,2135.436
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,28865.0,26629788672,12191300608,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.866525583267212,877.2087,0.912,0.114,1867.742
+0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,71177.0,65895347200,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.99012770652771,3600.8607,0.111,0.028,455.002
+0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,49455.0,44873390592,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9539268207550049,1890.9021,0.212,0.053,433.232
+,True,accelerated-peft-autogptq,2e-4,16,0.0,79265.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-autogptq,2e-4,16,0.0,79283.0,70143285760,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9549467945098877,3679.8651,0.217,0.027,445.234
+0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,71223.0,65086305280,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9903428840637207,3295.1075,0.121,0.03,497.222
+0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,46207.0,41579411968,15105330176,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9634347057342529,1740.6214,0.23,0.057,470.637
+,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,80949.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,74507.0,66445605376,15105526784,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9590920734405518,3441.8985,0.232,0.029,476.016
+0.15,,none,2e-5,,,76679.0,72971724288,44004763136,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9002080440521241,558.4193,0.716,0.179,2933.996
+0.15,,none,2e-5,,,43695.0,36762859520,29521119232,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8854282188415528,302.5551,1.322,0.331,2707.606
+0.29,,none,2e-5,,,73761.0,72972117504,44005156352,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.0202219200134277,1085.5804,0.737,0.092,3018.478
+0.29,,none,2e-5,,,52923.0,36763056128,29521315840,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.8920887660980225,561.8731,1.424,0.178,2915.961
+,,none,2e-5,,,79961.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,,
+,,none,2e-5,,,80925.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,,
+,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,float16,,,,,
+,,none,2e-5,,,80703.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,,
+,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,,
+,,none,2e-5,,,80922.0,0,0,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,,
+,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,,
+,,none,2e-5,,,80782.0,0,0,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,,
+0.15,,none,2e-4,16,0.0,28703.0,26108963328,15119590912,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8848505210876465,456.0676,0.877,0.219,3592.45
+0.15,,none,2e-4,16,0.0,17655.0,15123161088,7850391552,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8546714687347412,267.0472,1.498,0.374,3067.623
+0.29,,none,2e-4,16,0.0,42167.0,37098695168,15119984128,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,1.0078722095489503,909.6399,0.879,0.11,3602.305
+0.29,,none,2e-4,16,0.0,25207.0,21433753600,7850588160,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8803257846832275,477.2486,1.676,0.21,3433.012
+,,none,2e-4,16,0.0,78871.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.15,,none,2e-4,16,0.0,61532.0,57531527168,47311452160,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8628986740112304,545.0419,0.734,0.183,1503.004
+,,none,2e-4,16,0.0,80991.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.29,,none,2e-4,16,0.0,68811.0,64348470272,47311648768,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8795901584625244,919.9512,0.87,0.109,1780.964
+,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,,none,2e-4,16,0.0,80760.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,,none,2e-4,16,0.0,80987.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19257.0,13636909056,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8704845142364502,417.5391,0.958,0.239,3923.944
+,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5527.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,32209.0,22430791680,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8942180156707764,818.5228,0.977,0.122,4003.309
+,True,accelerated-peft-bnb-foak,2e-4,16,0.0,5675.0,0,0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,37301.0,35622334464,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.887912654876709,861.4969,0.464,0.116,1901.806
+0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,49955.0,46024318976,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8887538051605225,1590.7501,0.503,0.063,2059.909
+0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,71995.0,67350935552,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0002326488494873,3357.4377,0.119,0.03,487.991
+,True,accelerated-peft-bnb-foak,2e-4,16,0.0,80303.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,True,accelerated-peft-bnb-foak,2e-4,16,0.0,21095.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
diff --git a/scripts/benchmarks/refs/l40_40gb.csv b/scripts/benchmarks/refs/l40_40gb.csv
deleted file mode 100644
index 2158c782..00000000
--- a/scripts/benchmarks/refs/l40_40gb.csv
+++ /dev/null
@@ -1,49 +0,0 @@
-acceleration_framework_config_file,epoch,error_messages,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,output_dir,peft_method,per_device_train_batch_size,r,target_modules,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second,training_data_path
-,,,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,0.03,,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,2,,,0.9020393848419189,102.4493,0.781,0.195,1599.23,benchmark_outputs/data/cache.json
-,,,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json
-,0.06,,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,4,,,0.936076545715332,170.7722,0.937,0.117,1918.814,benchmark_outputs/data/cache.json
-,,,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,2,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,2,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,8,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,0.03,,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9326287746429444,120.2794,0.665,0.166,2724.324,benchmark_outputs/data/cache.json
-,0.03,,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9157441139221192,78.5825,1.018,0.255,2084.943,benchmark_outputs/data/cache.json
-,0.06,,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.0113807678222657,241.3246,0.663,0.083,2715.679,benchmark_outputs/data/cache.json
-,0.06,,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9433841228485107,133.2158,1.201,0.15,2459.768,benchmark_outputs/data/cache.json
-,,,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,36,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.6183419704437256,137.2634,0.583,0.146,2387.235,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,37,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.7251328945159912,73.906,1.082,0.271,2216.871,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,38,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.5904263019561768,272.1958,0.588,0.073,2407.679,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,39,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.515465259552002,138.6152,1.154,0.144,2363.954,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,40,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.012540912628174,227.0536,0.352,0.088,1443.183,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,41,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.0235525131225587,121.7118,0.657,0.164,1346.13,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,42,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,43,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.0152217864990234,229.6679,0.697,0.087,1426.756,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,44,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,45,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,46,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,47,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,0,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9979345798492432,130.1845,0.615,0.154,2517.044,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,1,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.942676591873169,69.8209,1.146,0.286,2346.575,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,2,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.9919514656066895,259.8776,0.616,0.077,2521.802,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,3,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.933735466003418,133.6157,1.197,0.15,2452.406,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,4,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.015654945373535,218.3215,0.366,0.092,1500.906,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,5,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9546889305114746,173.2373,0.462,0.115,945.755,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,6,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,7,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9585415840148925,273.4507,0.585,0.073,1198.315,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,8,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,9,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,10,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,11,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index c935ac31..42f7c753 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -52,6 +52,7 @@ scenarios:
- name: accelerated-peft-bnb
framework_config:
- accelerated-peft-bnb
+ - accelerated-peft-bnb-foak
arguments:
fp16: True
learning_rate: 2e-4
@@ -82,4 +83,4 @@ scenarios:
model_name_or_path:
- 'TheBloke/Mistral-7B-v0.1-GPTQ'
- 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- - 'TheBloke/Llama-2-70B-GPTQ'
\ No newline at end of file
+ - 'TheBloke/Llama-2-70B-GPTQ'
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index fd51d965..b3485e3c 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -143,6 +143,7 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4 = "bnb-nf4"
KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4"
KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak"
+KEY_BNB_NF4_FOAK = "bnb-nf4-foak"
CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -153,14 +154,18 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4_BASELINE: (
"plugins/accelerated-peft/configs/bnb.yaml",
[
- ("peft.quantization.bitsandbytes.quant_type", "nf4"),
- ("peft.quantization.bitsandbytes.no_peft_model", True),
+ ("peft.quantization.bitsandbytes.quant_type", "nf4"),
+ ("peft.quantization.bitsandbytes.no_peft_model", True),
],
),
KEY_AUTO_GPTQ_FOAK: (
"plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
[("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")],
),
+ KEY_BNB_NF4_FOAK: (
+ "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
+ [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")],
+ ),
}
# list of (tag, combi) tuples
@@ -173,8 +178,10 @@ def read_configuration(path: str) -> Dict:
("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)),
("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)),
("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
+ ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
]
+
# TODO: throw error if merge conflicts
def merge_configs(config_contents: List[Dict]):
"helper function to merge configuration contents."
@@ -183,10 +190,10 @@ def merge_configs(config_contents: List[Dict]):
def _merge(result: Dict, new_contents: Dict):
for k, v in new_contents.items():
if k not in result:
- # if k is not in result, it means v does not
+ # if k is not in result, it means v does not
# exist as a subtree under result, so we just do
                 # an assignment
- result[k] = v
+ result[k] = v
else:
# otherwise we call the merge
_merge(result[k], v)
diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh
index 798138bf..8f8a1f9b 100644
--- a/scripts/run_benchmarks.sh
+++ b/scripts/run_benchmarks.sh
@@ -58,10 +58,10 @@ if [ -n "$RESULT_DIR" ]; then
echo "Results dir $RESULT_DIR is not empty, but NO_OVERWRITE=true"
echo "If intending to overwrite please delete the folder manually"
echo "or do not set NO_OVERWRITE"
- exit 1
+ else
+ echo "Deleting $RESULT_DIR"
+ rm -rf $RESULT_DIR
fi
- echo "Deleting $RESULT_DIR"
- rm -rf $RESULT_DIR
fi
# tag on the directories
@@ -98,9 +98,11 @@ elif [ "$MEMORY_LOGGING" = "all" ]; then
fi
# dump out the environment
-echo "Creating $RESULT_DIR"
-mkdir -p $RESULT_DIR
-pip freeze > $PIP_REQUIREMENTS_FILE
+if [ ! "$NO_OVERWRITE" = "true" ]; then
+ echo "Creating $RESULT_DIR"
+ mkdir -p $RESULT_DIR
+ pip freeze > $PIP_REQUIREMENTS_FILE
+fi
# run the bench
python $WORKING_DIR/benchmark.py \
@@ -116,8 +118,10 @@ python $WORKING_DIR/benchmark.py \
# this will write to the BENCH_RESULT_FILE
# Remove the columns with values already represented by other metrics in the summary report
PYTHONPATH=. \
- python $WORKING_DIR/display_bench_results.py benchmark_outputs \
+ python $WORKING_DIR/display_bench_results.py $RESULT_DIR \
--result_file $BENCH_RESULT_FILE \
+ --keep_columns \
+ 'torch_dtype' \
--remove_columns \
'before_init_mem_cpu' \
'before_init_mem_gpu' \
@@ -129,5 +133,7 @@ PYTHONPATH=. \
'train_mem_cpu_peaked_delta' \
'train_mem_gpu_alloc_delta' \
'train_mem_gpu_peaked_delta' \
+ 'training_data_path' \
+ 'error_messages' \
'acceleration_framework_config_file'