diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
index 2ec2bbca..f0bab9d6 100644
--- a/.github/workflows/format.yml
+++ b/.github/workflows/format.yml
@@ -27,7 +27,8 @@ jobs:
matrix:
plugin_name:
- "framework"
- # - "accelerated-peft" # enable later
+ - "accelerated-peft"
+ - "fused-ops-and-kernels"
steps:
- uses: actions/checkout@v4
diff --git a/README.md b/README.md
index c068f023..707c8662 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,6 @@ For example:
- GPTQ-LoRA: 22-44% token throughput increase on 1 GPU as compared to using Hugging Face BNB QLoRA
- GPTQ-LoRA: Straightforward integration with multiple GPUs as compared to using Hugging Face BNB QLoRA
-*Huggingface BNB QLoRA numbers taken with legacy approaches, but we are aware of [this issue](https://github.com/foundation-model-stack/fms-acceleration/issues/10) and will update our benches*.
*The above includes numbers using fusedOps-and-kernels; the actual implementation is coming soon, see below*.
**This package is in BETA and is under development. Expect breaking changes!**
@@ -32,7 +31,7 @@ Plugin | Description | Depends | License | Status
--|--|--|--|--
[framework](./plugins/framework/README.md) | This acceleration framework for integration with huggingface trainers | | | Beta
[accelerated-peft](./plugins/accelerated-peft/README.md) | For PEFT-training, e.g., 4bit QLoRA. | Huggingface<br>AutoGPTQ | Apache 2.0<br>MIT | Beta
- fusedOps-and-kernels | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 with exclusions. | Coming Soon
+[fused-ops-and-kernels](./plugins/fused-ops-and-kernels/README.md) | Fused LoRA and triton kernels (e.g., fast cross-entropy, rms, rope) | -- | Apache 2.0 [(contains extracted code)](./plugins/fused-ops-and-kernels/README.md#code-extracted-from-unsloth) | Beta
MOE-training-acceleration | [MegaBlocks](https://github.com/databricks/megablocks) inspired triton kernels and accelerations for Mixture-of-Experts models | | Apache 2.0 | Coming Soon
## Usage with FMS HF Tuning
@@ -175,7 +174,6 @@ The benchmarks can be reproduced [with the provided scripts](./scripts/benchmark
See below CSV files for various results:
- [A100-80GB](./scripts/benchmarks/refs/a100_80gb.csv).
-- [L40-40GB](./scripts/benchmarks/refs/l40_40gb.csv).
### Code Architecture
diff --git a/plugins/accelerated-peft/.pylintrc b/plugins/accelerated-peft/.pylintrc
new file mode 100644
index 00000000..45da4212
--- /dev/null
+++ b/plugins/accelerated-peft/.pylintrc
@@ -0,0 +1,649 @@
+[MAIN]
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Clear in-memory caches upon conclusion of linting. Useful if running pylint
+# in a server-like mode.
+clear-cache-post-run=no
+
+# Load and enable all available extensions. Use --list-extensions to see a list
+# all available extensions.
+#enable-all-extensions=
+
+# In error mode, messages with a category besides ERROR or FATAL are
+# suppressed, and no reports are done by default. Error mode is compatible with
+# disabling specific errors.
+#errors-only=
+
+# Always return a 0 (non-error) status code, even if lint errors are found.
+# This is primarily useful in continuous integration scripts.
+#exit-zero=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-allow-list=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
+# for backward compatibility.)
+extension-pkg-whitelist=
+
+# Return non-zero exit code if any of these messages/categories are detected,
+# even if score is above --fail-under value. Syntax same as enable. Messages
+# specified are enabled, while categories only check already-enabled messages.
+fail-on=
+
+# Specify a score threshold under which the program will exit with error.
+fail-under=10
+
+# Interpret the stdin as a python script, whose filename needs to be passed as
+# the module_or_package argument.
+#from-stdin=
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=CVS,protobufs
+
+# Add files or directories matching the regular expressions patterns to the
+# ignore-list. The regex matches against paths and can be in Posix or Windows
+# format. Because '\\' represents the directory delimiter on Windows systems,
+# it can't be used as an escape character.
+ignore-paths=
+
+# Files or directories matching the regular expression patterns are skipped.
+# The regex matches against base names, not paths. The default value ignores
+# Emacs file locks
+ignore-patterns=^\.#
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use, and will cap the count on Windows to
+# avoid hangs.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Minimum Python version to use for version dependent checks. Will default to
+# the version used to run pylint.
+py-version=3.9
+
+# Discover python modules and packages in the file system subtree.
+recursive=no
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# In verbose mode, extra non-checker-related info will be displayed.
+#verbose=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style. If left empty, argument names will be checked with the set
+# naming style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style. If left empty, attribute names will be checked with the set naming
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style. If left empty, class attribute names will be checked
+# with the set naming style.
+#class-attribute-rgx=
+
+# Naming style matching correct class constant names.
+class-const-naming-style=UPPER_CASE
+
+# Regular expression matching correct class constant names. Overrides class-
+# const-naming-style. If left empty, class constant names will be checked with
+# the set naming style.
+#class-const-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style. If left empty, class names will be checked with the set naming style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style. If left empty, constant names will be checked with the set naming
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style. If left empty, function names will be checked with the set
+# naming style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ ex,
+ Run,
+ _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style. If left empty, inline iteration names will be checked
+# with the set naming style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style. If left empty, method names will be checked with the set naming style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style. If left empty, module names will be checked with the set naming style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Regular expression matching correct type variable names. If left empty, type
+# variable names will be checked with the set naming style.
+#typevar-rgx=
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style. If left empty, variable names will be checked with the set
+# naming style.
+#variable-rgx=
+
+
+[CLASSES]
+
+# Warn about protected attribute access inside special methods
+check-protected-access-in-special-methods=no
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# List of regular expressions of class ancestor names to ignore when counting
+# public methods (see R0903)
+exclude-too-few-public-methods=
+
+# List of qualified class names to ignore when counting class parents (see
+# R0901)
+ignored-parents=
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when caught.
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
+# tab).
+indent-string='    '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1100
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow explicit reexports by alias from a package __init__.
+allow-reexport-from-package=no
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=
+
+# Output a graph (.gv or any supported image format) of external dependencies
+# to the given file (report RP0402 must not be disabled).
+ext-import-graph=
+
+# Output a graph (.gv or any supported image format) of all (i.e. internal and
+# external) dependencies to the given file (report RP0402 must not be
+# disabled).
+import-graph=
+
+# Output a graph (.gv or any supported image format) of internal dependencies
+# to the given file (report RP0402 must not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# UNDEFINED.
+confidence=HIGH,
+ CONTROL_FLOW,
+ INFERENCE,
+ INFERENCE_FAILURE,
+ UNDEFINED
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ # Added messages
+ use-symbolic-message-instead,
+ invalid-name,
+ missing-class-docstring,
+ missing-module-docstring,
+ missing-function-docstring,
+ consider-using-f-string,
+ inconsistent-return-statements,
+ no-member,
+ too-many-arguments,
+ too-many-locals,
+ too-many-branches,
+ too-many-statements,
+ cyclic-import,
+ too-few-public-methods,
+ protected-access,
+ fixme,
+ logging-format-interpolation,
+ logging-too-many-args,
+ attribute-defined-outside-init,
+ abstract-method,
+ pointless-statement,
+ wrong-import-order,
+ duplicate-code,
+ unbalanced-tuple-unpacking,
+ unused-argument
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[METHOD_ARGS]
+
+# List of qualified names (i.e., library.method) which require a timeout
+# parameter e.g. 'requests.api.get,requests.api.post'
+timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+notes-rgx=
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit,argparse.parse_error
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'fatal', 'error', 'warning', 'refactor',
+# 'convention', and 'info' which contain the number of messages in each
+# category, as well as 'statement' which is the total number of statements
+# analyzed. This score is used by the global evaluation report (RP0004).
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=yes
+
+# Activate the evaluation score.
+score=yes
+
+
+[SIMILARITIES]
+
+# Comments are removed from the similarity computation
+ignore-comments=yes
+
+# Docstrings are removed from the similarity computation
+ignore-docstrings=yes
+
+# Imports are removed from the similarity computation
+ignore-imports=yes
+
+# Signatures are removed from the similarity computation
+ignore-signatures=yes
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the 'python-enchant' package.
+spelling-dict=
+
+# List of comma separated words that should be considered directives if they
+# appear at the beginning of a comment and should not be checked.
+spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of symbolic message names to ignore for Mixin members.
+ignored-checks-for-mixins=no-member,
+ not-async-context-manager,
+ not-context-manager,
+ attribute-defined-outside-init
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# Regex pattern to define which classes are considered mixins.
+mixin-class-rgx=.*[Mm]ixin
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of names allowed to shadow builtins
+allowed-redefined-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
index e3b2dc6d..913a6b7e 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/autogptq_utils.py
@@ -15,12 +15,81 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/
+# Standard
+from typing import Any, Callable, List
+import importlib
+
# Third Party
from peft import LoraConfig
from peft.tuners.lora.gptq import QuantLinear as LoraLinearGPTQ
-from typing import List, Callable
import torch
+# these parameters are to be patched for triton v2
+# consider making a map if patching more kernels
+PATCH_FOR_FSDP_TRITON_V2 = ["qweight", "qzeros"]
+
+
+# This function may be moved after merging
+# https://github.com/foundation-model-stack/fms-acceleration/pull/25
+def _patch_target_module(
+ to_patch: str,
+ replace_with: Any,
+ target_module: str = None,
+):
+ to_patch = to_patch.split(".")
+ assert len(to_patch) > 1, "must have an object to patch"
+
+ to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
+ to_patch = ".".join(to_patch)
+ source = importlib.import_module(to_patch)
+ original_obj = getattr(source, obj_name_to_patch)
+ setattr(source, obj_name_to_patch, replace_with)
+
+ if target_module is not None:
+ # reload and this should get the patched object
+ target_module = importlib.import_module(target_module)
+ importlib.reload(target_module)
+
+ # replace it
+        setattr(source, obj_name_to_patch, original_obj)
+
+
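+# Intended as a drop-in replacement for
+# auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device: same arguments, but it
+# skips modules whose bias is None, which the upstream version does not handle.
+# It is swapped in via the helper above; a sketch mirroring the FSDP branch in
+# framework_plugin_autogptq.py:
+#
+#   _patch_target_module(
+#       to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
+#       replace_with=make_sure_no_tensor_in_meta_device,
+#       target_module="auto_gptq.modeling._base",
+#   )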
+def make_sure_no_tensor_in_meta_device(
+ model,
+ use_triton: bool,
+ desc_act: bool,
+ group_size: int,
+ bits: int,
+ disable_exllama: bool,
+ disable_exllamav2: bool,
+ use_marlin: bool = False,
+ use_tritonv2: bool = False,
+):
+ # Third Party
+ # guarded import
+ from auto_gptq.utils.import_utils import ( # pylint: disable=import-outside-toplevel,import-error
+ dynamically_import_QuantLinear,
+ )
+
+ QuantLinear = dynamically_import_QuantLinear(
+ use_triton,
+ desc_act,
+ group_size,
+ bits=bits,
+ disable_exllama=disable_exllama,
+ disable_exllamav2=disable_exllamav2,
+ use_marlin=use_marlin,
+ use_tritonv2=use_tritonv2,
+ )
+ for _, m in model.named_modules():
+ bias = getattr(m, "bias", None)
+        if bias is not None:
+ if isinstance(m, QuantLinear) and bias.device == torch.device("meta"):
+ m.register_buffer(
+ "bias",
+ torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"),
+ )
+
def replace_module_peft(self, parent_module, child_name, new_module, old_module):
@@ -55,31 +124,48 @@ def create_new_module_peft(
# if module cannot be found, return None which results in a raise in the call-stack
return new_module
+
# consider to move this somewhere more general
def patch_forward_to_view_attributes_before_call(
old_forward: Callable,
- attribute_names: List[str], torch_dtype,
+ attribute_names: List[str],
+ torch_dtype: torch.dtype,
+ submodule_names: str = None,
+ is_method_forward: bool = True,
):
# patch old_forward to view attributes to torch_dtype
# before call
+ if submodule_names is None:
+ submodule_names = ""
+ if isinstance(submodule_names, str):
+ submodule_names = [submodule_names]
+
def _forward(self, *args, **kwargs):
- # perform a view on all these attributes
- for attr_name in attribute_names:
-
- # the view should be a passthrough
- # if attr.dtype == torch_dtype
- attr = getattr(self, attr_name)
-
- # perform view
- attr = attr.view(torch_dtype)
-
- try:
- setattr(self, attr_name, attr)
- except TypeError:
- # this means already have attr_name as a parameter, then
- # just assign this way
- self.__dict__[attr_name] = attr
-
- return old_forward(*args, **kwargs)
+
+ for sub_name in submodule_names:
+ mod = self.get_submodule(sub_name)
+
+ # perform a view on all these attributes
+ for attr_name in attribute_names:
+
+ # the view should be a passthrough
+ # if attr.dtype == torch_dtype
+ attr = getattr(mod, attr_name)
+
+ # perform view
+ attr = attr.view(torch_dtype)
+
+ try:
+ setattr(mod, attr_name, attr)
+ except TypeError:
+                # this means attr_name is already registered as a parameter,
+                # so assign directly via __dict__ instead
+ mod.__dict__[attr_name] = attr
+
+ if is_method_forward:
+ # in this case, the self is already bound
+ return old_forward(*args, **kwargs)
+ return old_forward(self, *args, **kwargs)
+
return _forward
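+
+# Usage sketch (mirrors the FSDP branch in framework_plugin_autogptq.py, where the patched
+# forward is bound back onto each QuantLinear module so the kernels see int32 again):
+#
+#   from types import MethodType
+#   _forward = patch_forward_to_view_attributes_before_call(
+#       mod.forward,
+#       attribute_names=PATCH_FOR_FSDP_TRITON_V2,
+#       torch_dtype=torch.int32,
+#   )
+#   mod.forward = MethodType(_forward, mod)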
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
index fa6082ab..7928d9a9 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_autogptq.py
@@ -20,15 +20,16 @@
from functools import partial
from types import MethodType
from typing import Dict, Tuple
+import os
# Third Party
from fms_acceleration import AccelerationPlugin
from peft import LoraConfig, prepare_model_for_kbit_training
from peft.tuners.lora.model import LoraModel
-import torch.distributed
from transformers import AutoModelForCausalLM, TrainingArguments
+from transformers.modeling_utils import is_fsdp_enabled
import torch
-import os
+import torch.distributed
class AutoGPTQAccelerationPlugin(AccelerationPlugin):
@@ -48,12 +49,21 @@ def __init__(self, configurations: Dict[str, Dict]):
)
def model_loader(self, model_name: str, **kwargs):
-
# guarded imports
# Third Party
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
- from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear, QuantLinearFunction
- from .autogptq_utils import patch_forward_to_view_attributes_before_call
+ from auto_gptq import ( # pylint: disable=import-outside-toplevel,import-error
+ AutoGPTQForCausalLM,
+ BaseQuantizeConfig,
+ )
+ from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+ QuantLinear,
+ )
+
+ # Local
+ from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+ PATCH_FOR_FSDP_TRITON_V2,
+ patch_forward_to_view_attributes_before_call,
+ )
# Currently we allow only a quantized checkpoint to be loaded, we do not
# implement the quantization process here.
@@ -61,16 +71,18 @@ def model_loader(self, model_name: str, **kwargs):
# The quantization process is used to convert a non-quantized checkpoint
# (provided in model_name) into a quantized one. This entails
# 1. providing a BaseQuantizeConfig with the appropriate quantization settings
- # 2. calling BaseGPTQForCausalLM.quantize to run the quantization algorithm (may take time, e.g. hours)
+ # 2. calling BaseGPTQForCausalLM.quantize to run the quantization algorithm
+ # (may take time, e.g. hours)
# 3. calling BaseGPTQForCausalLM.save_pretrained to save a quantized checkpoint
#
# The reasons for not implementing the flow at this point are.
# 1. The quantization can take very long for large models. As such, it is more appropriate
- # to run it once outside of training, and save the checkpoint to be used for multiple runs.
+ # to run it once outside of training, and save the checkpoint to be used for multiple runs.
# 2. Requires some API changes to point to where the quantized checkpoint should be saved.
# Can be confusing to the user since it will be different from model_name
# NOTE: there will be a warning that can be ignored
- # "WARNING - QuantLinear with the exllama backend not does support the trainable mode yet, switching to cuda/cuda_old/triton backend."
+ # "WARNING - QuantLinear with the exllama backend not does support the trainable mode yet,
+ # switching to cuda/cuda_old/triton backend."
# assume model_name points to a quantized checkpoint. Thus we load the quantization
# config directly from the checkpoint.
quantize_config = BaseQuantizeConfig.from_pretrained(model_name)
@@ -80,35 +92,49 @@ def model_loader(self, model_name: str, **kwargs):
low_cpu_mem_usage = kwargs.get("low_cpu_mem_usage")
attn_implementation = kwargs.get("attn_implementation")
- if low_cpu_mem_usage:
- # Note that low_cpu_mem_usage is typically set to transformers.modeling_utils.is_fsdp_enabled.
- # e.g., https://github.com/huggingface/transformers/blob/a98c41798cf6ed99e1ff17e3792d6e06a2ff2ff3/src/transformers/modeling_utils.py#L2989-L2990
- # but not doing that now as AutoGPTQ will call make_sure_no_tensor_in_meta_device
- # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_base.py#L982C17-L982C51
- # which does not properly check if a QuantLayer has a bias set or not,
- # https://github.com/AutoGPTQ/AutoGPTQ/blob/ea829c7bbe83561c2b1de26795b6592992373ef7/auto_gptq/modeling/_utils.py#L514
- raise ValueError(
- "low_cpu_mem_usage set to True. This may raise error if model has no bias, "
- "due to AutoGPTQ bug. Not supporting at the moment."
- )
-
# there are some kwargs that we wont be passed to AutoModel, so we need
# to patch them in
_old_from_config = AutoModelForCausalLM.from_config
- # Standard
- from functools import partial
_from_config = partial(
_old_from_config, attn_implementation=attn_implementation
)
AutoModelForCausalLM.from_config = _from_config # patch
+ # this is a HF method that checks if the low_cpu_mem mode is enabled
+ # via HF accelerate
+ if is_fsdp_enabled():
+ # Local
+ from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+ _patch_target_module,
+ make_sure_no_tensor_in_meta_device,
+ )
+
+ # We patch `make_sure_no_tensor_in_meta_device`
+ # from autogptq to avoid errors on models without bias
+ _patch_target_module(
+ to_patch="auto_gptq.modeling._utils.make_sure_no_tensor_in_meta_device",
+ replace_with=make_sure_no_tensor_in_meta_device,
+ target_module="auto_gptq.modeling._base",
+ )
+ low_cpu_mem_usage = True
+
# NOTE: need to set the device map as below as we want to use AutoGPTQ for training.
- # device_map is for inference only https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
- # Thus we set it as below to effectively disable it.
- device_map = (
- {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
- )
+ # device_map is for inference only
+ # https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference
+ # For low_cpu_mem_usage = True, we have to set the device map to load checkpoints to "cpu"
+        # to avoid GPU memory consumption before training.
+        # This approach diverts the consumption to CPU memory; a better approach would be to
+        # load the checkpoints onto the meta device.
+        # QLoRA currently uses the former approach and will encounter the same issue.
+ # see https://github.com/huggingface/transformers/pull/25107#issuecomment-2134833262
+ device_map = {
+ "": (
+ (torch.cuda.current_device() if not low_cpu_mem_usage else "cpu")
+ if torch.cuda.is_available()
+ else None
+ )
+ }
# currently only enable triton_v2, because the triton kernels are the only ones
# that have backwards
@@ -119,14 +145,14 @@ def model_loader(self, model_name: str, **kwargs):
low_cpu_mem_usage=low_cpu_mem_usage,
use_marlin=False, # disable, cannot be used for training (no forward+backward)
disable_exllama=True, # disable, cannot be used for training (no backward)
- warmup_triton=False, # disable for now, because it will try to run the warmup while on CPU
+ warmup_triton=False, # disable for now as it will try to run the warmup while on CPU
use_tritonv2=True,
trainable=True, # only support trainable mode
device_map=device_map,
)
# https://github.com/foundation-model-stack/fms-acceleration/pull/15
- # if FSDP distributed need to convert the AutoGPTQ model's
+    # if FSDP distributed, we need to convert the AutoGPTQ model's
# parameters (in tensors) to parameters. Also need to
# store the int32 tensors in a float type
@@ -139,9 +165,6 @@ def model_loader(self, model_name: str, **kwargs):
world_size > 1
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
):
- # these parameters are to be patched for triton v2
- # consider making a map if patching more kernels
- PATCH_FOR_FSDP_TRITON_V2 = ['qweight', 'qzeros']
# patch all the QuantLinear base layers
for mod in model.modules():
@@ -151,14 +174,17 @@ def model_loader(self, model_name: str, **kwargs):
# so FSDP can shard them
for attr_name in PATCH_FOR_FSDP_TRITON_V2:
attr = getattr(mod, attr_name)
- attr = torch.nn.Parameter(attr.view(torch_dtype), requires_grad=False)
+ attr = torch.nn.Parameter(
+ attr.view(torch_dtype), requires_grad=False
+ )
setattr(mod, attr_name, attr)
- # this patches the forward to convert them back to original
+ # this patches the forward to convert them back to original
# type (i.e. int32) before the function call into the kernels
_forward = patch_forward_to_view_attributes_before_call(
- mod.forward, attribute_names=PATCH_FOR_FSDP_TRITON_V2,
- torch_dtype=torch.int32, # patch it back to
+ mod.forward,
+ attribute_names=PATCH_FOR_FSDP_TRITON_V2,
+ torch_dtype=torch.int32, # patch it back to
)
mod.forward = MethodType(_forward, mod)
@@ -193,11 +219,19 @@ def augmentation(
):
# guarded imports
# Third Party
- from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
- from auto_gptq.utils.peft_utils import GPTQLoraModel, get_gptq_peft_model
+ from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import ( # pylint: disable=import-outside-toplevel,import-error
+ QuantLinear,
+ )
+ from auto_gptq.utils.peft_utils import ( # pylint: disable=import-outside-toplevel,import-error
+ GPTQLoraModel,
+ get_gptq_peft_model,
+ )
# Local
- from .autogptq_utils import create_new_module_peft, replace_module_peft
+ from .autogptq_utils import ( # pylint: disable=import-outside-toplevel
+ create_new_module_peft,
+ replace_module_peft,
+ )
(peft_config,) = modifiable_args # unpack modifiable args
diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
index dfd5fbc8..6e71d11a 100644
--- a/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
+++ b/plugins/accelerated-peft/src/fms_acceleration_peft/framework_plugin_bnb.py
@@ -23,7 +23,7 @@
# Third Party
from fms_acceleration import AccelerationPlugin
-from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
+from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
import torch
@@ -41,7 +41,7 @@ def _prepare_model_for_kbit_training(
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
- for name, param in model.named_parameters():
+ for _, param in model.named_parameters():
# freeze base model's layers
param.requires_grad = False
@@ -56,22 +56,24 @@ def _prepare_model_for_kbit_training(
model.enable_input_require_grads()
else:
- def make_inputs_require_grad(module, input, output):
+ def make_inputs_require_grad(_module, _input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(
make_inputs_require_grad
)
- # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
+ # To support older transformers versions,
+ # check if the model supports gradient_checkpointing_kwargs
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(model.gradient_checkpointing_enable).parameters
)
if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
warnings.warn(
- "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
- " if you want to use that feature, please upgrade to the latest version of transformers.",
+ "gradient_checkpointing_kwargs is not supported in this version of transformers.",
+ "The passed kwargs will be ignored. if you want to use that feature,",
+ "please upgrade to the latest version of transformers.",
FutureWarning,
)
@@ -124,16 +126,14 @@ def model_loader(self, model_name: str, **kwargs):
"If running in FSDP, this is probably because accelerate is not used. "
"This will most probably result in error."
)
- elif (
- world_size == 1
- and self._no_peft_model == True
- ):
+ elif world_size == 1 and self._no_peft_model is True:
warnings.warn(
"""Running on single device and setting plugin config `no_peft_model` as `True`
- PEFT preparation will be managed by SFTTrainer and will cause a slowdown in training speed
- due to extraneous dtype casting when SFTTrainer prepares the model using
+ PEFT preparation will be managed by SFTTrainer and
+ will cause a slowdown in training speed due to
+ extraneous dtype casting when SFTTrainer prepares the model using
https://github.com/huggingface/trl/blob/e90e8d91d2265e484f229c45a5eb8982f94a2936/trl/trainer/sft_trainer.py#L210"""
- )
+ )
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
diff --git a/plugins/accelerated-peft/tests/__init__.py b/plugins/accelerated-peft/tests/__init__.py
new file mode 100644
index 00000000..38a9531e
--- /dev/null
+++ b/plugins/accelerated-peft/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/accelerated-peft/tests/test_peft_plugins.py b/plugins/accelerated-peft/tests/test_peft_plugins.py
index 42404ddc..bb0621e5 100644
--- a/plugins/accelerated-peft/tests/test_peft_plugins.py
+++ b/plugins/accelerated-peft/tests/test_peft_plugins.py
@@ -134,7 +134,7 @@ def test_configure_bnb_plugin():
require_packages_check=False,
):
# check flags and callbacks
- assert (not correct_value)==framework.requires_agumentation
+ assert (not correct_value) == framework.requires_agumentation
# attempt to activate plugin with configuration pointing to wrong path
# - raise with message that no plugins can be configured
diff --git a/plugins/accelerated-peft/tox.ini b/plugins/accelerated-peft/tox.ini
index b79d0691..eb53996e 100644
--- a/plugins/accelerated-peft/tox.ini
+++ b/plugins/accelerated-peft/tox.ini
@@ -4,23 +4,27 @@ envlist = py, lint
[testenv]
deps =
pytest>=7
-
# for the tests, we need to install the deps ourselves
# as the package will install the github version
-e {toxinidir}/../framework
-skip_install = true
+# skip package installation: tox would install the package (pyproject.toml) before its deps, which throws an error because AutoGPTQ needs torch
+skip_install = true
commands =
-
# install the current package
pip install --no-deps {toxinidir}
-
pytest {posargs:tests}
-[testenv:lint]
+[testenv:lint]
description = run linters
deps =
+ -e {toxinidir}/../framework
+ pytest>=7
pylint>=2.16.2,<=3.1.0
-commands = pylint src tests
+commands =
+ # installs package without autogptq dep to lint without CUDA,
+ # autogptq pylint import-errors are disabled inline
+ pip install --no-deps {toxinidir}
+ pylint src tests
allowlist_externals = pylint
[testenv:fmt]
diff --git a/plugins/framework/README.md b/plugins/framework/README.md
index 2fe9cba0..2794f4fb 100644
--- a/plugins/framework/README.md
+++ b/plugins/framework/README.md
@@ -88,3 +88,53 @@ Each [package](#packages) in this monorepo:
- When instantiating `fms_acceleration.AccelerationFramework`, it internally parses through the configuration stanzas. For plugins that are installed, it will instantiate them; for those that are not, it will simply *passthrough*.
- `AccelerationFramework` will manage plugins transparently for user. User only needs to call `AccelerationFramework.model_loader` and `AccelerationFramework.augmentation`.
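+
+As a rough usage sketch (the method names are those referenced above; the configuration path,
+model name, keyword arguments and return values shown here are assumptions and may differ from
+the actual API):
+
+```python
+# Third Party
+from peft import LoraConfig
+from transformers import TrainingArguments
+import torch
+
+# First Party
+from fms_acceleration import AccelerationFramework
+
+# instantiate with an acceleration configuration (e.g., a generated sample configuration)
+framework = AccelerationFramework("sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml")
+
+train_args = TrainingArguments(output_dir="./output")
+peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
+
+# a plugin that implements model_loader (e.g., accelerated-peft) takes over model loading
+model = framework.model_loader("TheBloke/Mistral-7B-v0.1-GPTQ", torch_dtype=torch.float16)
+
+# plugins may modify the model and the pass-through arguments (e.g., the LoRA config)
+model, (peft_config,) = framework.augmentation(
+    model, train_args, modifiable_args=(peft_config,)
+)
+
+# collect any trainer callbacks the plugins require before training starts
+callbacks = framework.get_callbacks_and_ready_for_train(model)
+```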
+
+## Adding New Plugins
+
+To add new plugins:
+
+1. Create an appropriately `pip`-packaged plugin in `plugins`; the package needs to be named like `fms-acceleration-<postfix>`.
+2. For `framework` to properly load and manage the plugin, add the package `<postfix>` to [constants.py](./src/fms_acceleration/constants.py):
+
+ ```python
+ PLUGINS = [
+ "peft",
+ "unsloth",
+ "",
+ ]
+ ```
+3. Create a sample template YAML file inside the plugin's `configs` directory to demonstrate how to configure the plugin. As an example, reference the [sample config for accelerated peft](../accelerated-peft/configs/autogptq.yaml).
+4. Update [generate_sample_configurations.py](../../scripts/generate_sample_configurations.py) and run `tox -e gen-configs` on the top level directory to generate the sample configurations.
+
+ ```python
+ KEY_AUTO_GPTQ = "auto_gptq"
+ KEY_BNB_NF4 = "bnb-nf4"
+ PLUGIN_A = ""
+
+ CONFIGURATIONS = {
+ KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
+ KEY_BNB_NF4: (
+ "plugins/accelerated-peft/configs/bnb.yaml",
+ [("peft.quantization.bitsandbytes.quant_type", "nf4")],
+ ),
+ PLUGIN_A: (
+ "plugins//configs/plugin_config.yaml",
+ [
+                (<1st field in plugin_config.yaml>, <value of 1st field>),
+                (<2nd field in plugin_config.yaml>, <value of 2nd field>),
+ ]
+ )
+ }
+
+ # Passing a tuple of configuration keys will combine the templates together
+ COMBINATIONS = [
+ ("accelerated-peft-autogptq", (KEY_AUTO_GPTQ,)),
+ ("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)),
+ (<"combined name with your plugin">), (KEY_AUTO_GPTQ, PLUGIN_A)
+ (<"combined name with your plugin">), (KEY_BNB_NF4, PLUGIN_A)
+ ]
+ ```
+5. After the sample configurations are generated by `tox -e gen-configs`, update [CONTENTS.yaml](../../sample-configurations/CONTENTS.yaml) with the shortname and the full path of each new configuration.
+6. Update [scenarios YAML](../../scripts/benchmarks/scenarios.yaml) to configure benchmark test scenarios that will be triggered when running `tox -e run-benches` on the top level directory.
+7. Update the [top-level tox.ini](../../tox.ini) to install the plugin for the `run-benches` environment.
+
diff --git a/plugins/framework/src/fms_acceleration/constants.py b/plugins/framework/src/fms_acceleration/constants.py
index 7fe2688a..9b5fa9cc 100644
--- a/plugins/framework/src/fms_acceleration/constants.py
+++ b/plugins/framework/src/fms_acceleration/constants.py
@@ -21,4 +21,5 @@
PLUGINS = [
"peft",
+ "foak"
]
diff --git a/plugins/framework/src/fms_acceleration/framework.py b/plugins/framework/src/fms_acceleration/framework.py
index 6d545ac7..ff83dd0c 100644
--- a/plugins/framework/src/fms_acceleration/framework.py
+++ b/plugins/framework/src/fms_acceleration/framework.py
@@ -179,11 +179,12 @@ def get_callbacks_and_ready_for_train(
self, model: torch.nn.Module = None, accelerator: Accelerator = None
):
# show the initialized message
- log_initialization_message(
- {x for x, _ in self.active_plugins},
- PLUGIN_REGISTRATIONS,
- logging_func=logger.info,
- )
+ if accelerator is not None and accelerator.is_main_process:
+ log_initialization_message(
+ {x for x, _ in self.active_plugins},
+ PLUGIN_REGISTRATIONS,
+ logging_func=logger.info,
+ )
cbks = []
for _, plugin in self.active_plugins:
diff --git a/plugins/fused-ops-and-kernels/.isort.cfg b/plugins/fused-ops-and-kernels/.isort.cfg
new file mode 100644
index 00000000..4aa62fac
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/.isort.cfg
@@ -0,0 +1,13 @@
+[settings]
+profile=black
+from_first=true
+import_heading_future=Future
+import_heading_stdlib=Standard
+import_heading_thirdparty=Third Party
+import_heading_firstparty=First Party
+import_heading_localfolder=Local
+known_firstparty=
+known_localfolder=tuning
+
+# skip code imported from unsloth
+skip_glob=**/unsloth*/**
diff --git a/plugins/fused-ops-and-kernels/.pylintrc b/plugins/fused-ops-and-kernels/.pylintrc
new file mode 100644
index 00000000..31cb902c
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/.pylintrc
@@ -0,0 +1,650 @@
+[MAIN]
+
+# Analyse import fallback blocks. This can be used to support both Python 2 and
+# 3 compatible code, which means that the block might have code that exists
+# only in one or another interpreter, leading to false positives when analysed.
+analyse-fallback-blocks=no
+
+# Clear in-memory caches upon conclusion of linting. Useful if running pylint
+# in a server-like mode.
+clear-cache-post-run=no
+
+# Load and enable all available extensions. Use --list-extensions to see a list
+# all available extensions.
+#enable-all-extensions=
+
+# In error mode, messages with a category besides ERROR or FATAL are
+# suppressed, and no reports are done by default. Error mode is compatible with
+# disabling specific errors.
+#errors-only=
+
+# Always return a 0 (non-error) status code, even if lint errors are found.
+# This is primarily useful in continuous integration scripts.
+#exit-zero=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code.
+extension-pkg-allow-list=
+
+# A comma-separated list of package or module names from where C extensions may
+# be loaded. Extensions are loading into the active Python interpreter and may
+# run arbitrary code. (This is an alternative name to extension-pkg-allow-list
+# for backward compatibility.)
+extension-pkg-whitelist=
+
+# Return non-zero exit code if any of these messages/categories are detected,
+# even if score is above --fail-under value. Syntax same as enable. Messages
+# specified are enabled, while categories only check already-enabled messages.
+fail-on=
+
+# Specify a score threshold under which the program will exit with error.
+fail-under=10
+
+# Interpret the stdin as a python script, whose filename needs to be passed as
+# the module_or_package argument.
+#from-stdin=
+
+# Files or directories to be skipped. They should be base names, not paths.
+ignore=CVS,protobufs
+
+# Add files or directories matching the regular expressions patterns to the
+# ignore-list. The regex matches against paths and can be in Posix or Windows
+# format. Because '\\' represents the directory delimiter on Windows systems,
+# it can't be used as an escape character.
+# NOTE: do not lint code imported from unsloth
+ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth*
+
+# Files or directories matching the regular expression patterns are skipped.
+# The regex matches against base names, not paths. The default value ignores
+# Emacs file locks
+ignore-patterns=^\.#
+
+# List of module names for which member attributes should not be checked
+# (useful for modules/projects where namespaces are manipulated during runtime
+# and thus existing member attributes cannot be deduced by static analysis). It
+# supports qualified module names, as well as Unix pattern matching.
+ignored-modules=
+
+# Python code to execute, usually for sys.path manipulation such as
+# pygtk.require().
+#init-hook=
+
+# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
+# number of processors available to use, and will cap the count on Windows to
+# avoid hangs.
+jobs=1
+
+# Control the amount of potential inferred values when inferring a single
+# object. This can help the performance when dealing with large functions or
+# complex, nested conditions.
+limit-inference-results=100
+
+# List of plugins (as comma separated values of python module names) to load,
+# usually to register additional checkers.
+load-plugins=
+
+# Pickle collected data for later comparisons.
+persistent=yes
+
+# Minimum Python version to use for version dependent checks. Will default to
+# the version used to run pylint.
+py-version=3.9
+
+# Discover python modules and packages in the file system subtree.
+recursive=no
+
+# When enabled, pylint would attempt to guess common misconfiguration and emit
+# user-friendly hints instead of false-positive error messages.
+suggestion-mode=yes
+
+# Allow loading of arbitrary C extensions. Extensions are imported into the
+# active Python interpreter and may run arbitrary code.
+unsafe-load-any-extension=no
+
+# In verbose mode, extra non-checker-related info will be displayed.
+#verbose=
+
+
+[BASIC]
+
+# Naming style matching correct argument names.
+argument-naming-style=snake_case
+
+# Regular expression matching correct argument names. Overrides argument-
+# naming-style. If left empty, argument names will be checked with the set
+# naming style.
+#argument-rgx=
+
+# Naming style matching correct attribute names.
+attr-naming-style=snake_case
+
+# Regular expression matching correct attribute names. Overrides attr-naming-
+# style. If left empty, attribute names will be checked with the set naming
+# style.
+#attr-rgx=
+
+# Bad variable names which should always be refused, separated by a comma.
+bad-names=foo,
+ bar,
+ baz,
+ toto,
+ tutu,
+ tata
+
+# Bad variable names regexes, separated by a comma. If names match any regex,
+# they will always be refused
+bad-names-rgxs=
+
+# Naming style matching correct class attribute names.
+class-attribute-naming-style=any
+
+# Regular expression matching correct class attribute names. Overrides class-
+# attribute-naming-style. If left empty, class attribute names will be checked
+# with the set naming style.
+#class-attribute-rgx=
+
+# Naming style matching correct class constant names.
+class-const-naming-style=UPPER_CASE
+
+# Regular expression matching correct class constant names. Overrides class-
+# const-naming-style. If left empty, class constant names will be checked with
+# the set naming style.
+#class-const-rgx=
+
+# Naming style matching correct class names.
+class-naming-style=PascalCase
+
+# Regular expression matching correct class names. Overrides class-naming-
+# style. If left empty, class names will be checked with the set naming style.
+#class-rgx=
+
+# Naming style matching correct constant names.
+const-naming-style=UPPER_CASE
+
+# Regular expression matching correct constant names. Overrides const-naming-
+# style. If left empty, constant names will be checked with the set naming
+# style.
+#const-rgx=
+
+# Minimum line length for functions/classes that require docstrings, shorter
+# ones are exempt.
+docstring-min-length=-1
+
+# Naming style matching correct function names.
+function-naming-style=snake_case
+
+# Regular expression matching correct function names. Overrides function-
+# naming-style. If left empty, function names will be checked with the set
+# naming style.
+#function-rgx=
+
+# Good variable names which should always be accepted, separated by a comma.
+good-names=i,
+ j,
+ k,
+ ex,
+ Run,
+ _
+
+# Good variable names regexes, separated by a comma. If names match any regex,
+# they will always be accepted
+good-names-rgxs=
+
+# Include a hint for the correct naming format with invalid-name.
+include-naming-hint=no
+
+# Naming style matching correct inline iteration names.
+inlinevar-naming-style=any
+
+# Regular expression matching correct inline iteration names. Overrides
+# inlinevar-naming-style. If left empty, inline iteration names will be checked
+# with the set naming style.
+#inlinevar-rgx=
+
+# Naming style matching correct method names.
+method-naming-style=snake_case
+
+# Regular expression matching correct method names. Overrides method-naming-
+# style. If left empty, method names will be checked with the set naming style.
+#method-rgx=
+
+# Naming style matching correct module names.
+module-naming-style=snake_case
+
+# Regular expression matching correct module names. Overrides module-naming-
+# style. If left empty, module names will be checked with the set naming style.
+#module-rgx=
+
+# Colon-delimited sets of names that determine each other's naming style when
+# the name regexes allow several styles.
+name-group=
+
+# Regular expression which should only match function or class names that do
+# not require a docstring.
+no-docstring-rgx=^_
+
+# List of decorators that produce properties, such as abc.abstractproperty. Add
+# to this list to register other decorators that produce valid properties.
+# These decorators are taken in consideration only for invalid-name.
+property-classes=abc.abstractproperty
+
+# Regular expression matching correct type variable names. If left empty, type
+# variable names will be checked with the set naming style.
+#typevar-rgx=
+
+# Naming style matching correct variable names.
+variable-naming-style=snake_case
+
+# Regular expression matching correct variable names. Overrides variable-
+# naming-style. If left empty, variable names will be checked with the set
+# naming style.
+#variable-rgx=
+
+
+[CLASSES]
+
+# Warn about protected attribute access inside special methods
+check-protected-access-in-special-methods=no
+
+# List of method names used to declare (i.e. assign) instance attributes.
+defining-attr-methods=__init__,
+ __new__,
+ setUp,
+ __post_init__
+
+# List of member names, which should be excluded from the protected access
+# warning.
+exclude-protected=_asdict,
+ _fields,
+ _replace,
+ _source,
+ _make
+
+# List of valid names for the first argument in a class method.
+valid-classmethod-first-arg=cls
+
+# List of valid names for the first argument in a metaclass class method.
+valid-metaclass-classmethod-first-arg=mcs
+
+
+[DESIGN]
+
+# List of regular expressions of class ancestor names to ignore when counting
+# public methods (see R0903)
+exclude-too-few-public-methods=
+
+# List of qualified class names to ignore when counting class parents (see
+# R0901)
+ignored-parents=
+
+# Maximum number of arguments for function / method.
+max-args=5
+
+# Maximum number of attributes for a class (see R0902).
+max-attributes=7
+
+# Maximum number of boolean expressions in an if statement (see R0916).
+max-bool-expr=5
+
+# Maximum number of branch for function / method body.
+max-branches=12
+
+# Maximum number of locals for function / method body.
+max-locals=15
+
+# Maximum number of parents for a class (see R0901).
+max-parents=7
+
+# Maximum number of public methods for a class (see R0904).
+max-public-methods=20
+
+# Maximum number of return / yield for function / method body.
+max-returns=6
+
+# Maximum number of statements in function / method body.
+max-statements=50
+
+# Minimum number of public methods for a class (see R0903).
+min-public-methods=2
+
+
+[EXCEPTIONS]
+
+# Exceptions that will emit a warning when caught.
+overgeneral-exceptions=builtins.BaseException,builtins.Exception
+
+
+[FORMAT]
+
+# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
+expected-line-ending-format=
+
+# Regexp for a line that is allowed to be longer than the limit.
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$
+
+# Number of spaces of indent required inside a hanging or continued line.
+indent-after-paren=4
+
+# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
+# tab).
+indent-string=' '
+
+# Maximum number of characters on a single line.
+max-line-length=100
+
+# Maximum number of lines in a module.
+max-module-lines=1100
+
+# Allow the body of a class to be on the same line as the declaration if body
+# contains single statement.
+single-line-class-stmt=no
+
+# Allow the body of an if to be on the same line as the test if there is no
+# else.
+single-line-if-stmt=no
+
+
+[IMPORTS]
+
+# List of modules that can be imported at any level, not just the top level
+# one.
+allow-any-import-level=
+
+# Allow explicit reexports by alias from a package __init__.
+allow-reexport-from-package=no
+
+# Allow wildcard imports from modules that define __all__.
+allow-wildcard-with-all=no
+
+# Deprecated modules which should not be used, separated by a comma.
+deprecated-modules=
+
+# Output a graph (.gv or any supported image format) of external dependencies
+# to the given file (report RP0402 must not be disabled).
+ext-import-graph=
+
+# Output a graph (.gv or any supported image format) of all (i.e. internal and
+# external) dependencies to the given file (report RP0402 must not be
+# disabled).
+import-graph=
+
+# Output a graph (.gv or any supported image format) of internal dependencies
+# to the given file (report RP0402 must not be disabled).
+int-import-graph=
+
+# Force import order to recognize a module as part of the standard
+# compatibility libraries.
+known-standard-library=
+
+# Force import order to recognize a module as part of a third party library.
+known-third-party=enchant
+
+# Couples of modules and preferred modules, separated by a comma.
+preferred-modules=
+
+
+[LOGGING]
+
+# The type of string formatting that logging methods do. `old` means using %
+# formatting, `new` is for `{}` formatting.
+logging-format-style=old
+
+# Logging modules to check that the string format arguments are in logging
+# function parameter format.
+logging-modules=logging
+
+
+[MESSAGES CONTROL]
+
+# Only show warnings with the listed confidence levels. Leave empty to show
+# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE,
+# UNDEFINED.
+confidence=HIGH,
+ CONTROL_FLOW,
+ INFERENCE,
+ INFERENCE_FAILURE,
+ UNDEFINED
+
+# Disable the message, report, category or checker with the given id(s). You
+# can either give multiple identifiers separated by comma (,) or put this
+# option multiple times (only on the command line, not in the configuration
+# file where it should appear only once). You can also use "--disable=all" to
+# disable everything first and then re-enable specific checks. For example, if
+# you want to run only the similarities checker, you can use "--disable=all
+# --enable=similarities". If you want to run only the classes checker, but have
+# no Warning level messages displayed, use "--disable=all --enable=classes
+# --disable=W".
+disable=raw-checker-failed,
+ bad-inline-option,
+ locally-disabled,
+ file-ignored,
+ suppressed-message,
+ useless-suppression,
+ deprecated-pragma,
+ # Added messages
+ use-symbolic-message-instead,
+ invalid-name,
+ missing-class-docstring,
+ missing-module-docstring,
+ missing-function-docstring,
+ consider-using-f-string,
+ inconsistent-return-statements,
+ no-member,
+ too-many-arguments,
+ too-many-locals,
+ too-many-branches,
+ too-many-statements,
+ cyclic-import,
+ too-few-public-methods,
+ protected-access,
+ fixme,
+ logging-format-interpolation,
+ logging-too-many-args,
+ attribute-defined-outside-init,
+ abstract-method,
+ pointless-statement,
+ wrong-import-order,
+ duplicate-code,
+ unbalanced-tuple-unpacking,
+ unused-argument
+
+# Enable the message, report, category or checker with the given id(s). You can
+# either give multiple identifier separated by comma (,) or put this option
+# multiple time (only on the command line, not in the configuration file where
+# it should appear only once). See also the "--disable" option for examples.
+enable=c-extension-no-member
+
+
+[METHOD_ARGS]
+
+# List of qualified names (i.e., library.method) which require a timeout
+# parameter e.g. 'requests.api.get,requests.api.post'
+timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request
+
+
+[MISCELLANEOUS]
+
+# List of note tags to take in consideration, separated by a comma.
+notes=FIXME,
+ XXX,
+ TODO
+
+# Regular expression of note tags to take in consideration.
+notes-rgx=
+
+
+[REFACTORING]
+
+# Maximum number of nested blocks for function / method body
+max-nested-blocks=5
+
+# Complete name of functions that never returns. When checking for
+# inconsistent-return-statements if a never returning function is called then
+# it will be considered as an explicit return statement and no message will be
+# printed.
+never-returning-functions=sys.exit,argparse.parse_error
+
+
+[REPORTS]
+
+# Python expression which should return a score less than or equal to 10. You
+# have access to the variables 'fatal', 'error', 'warning', 'refactor',
+# 'convention', and 'info' which contain the number of messages in each
+# category, as well as 'statement' which is the total number of statements
+# analyzed. This score is used by the global evaluation report (RP0004).
+evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
+
+# Template used to display messages. This is a python new-style format string
+# used to format the message information. See doc for all details.
+msg-template=
+
+# Set the output format. Available formats are text, parseable, colorized, json
+# and msvs (visual studio). You can also give a reporter class, e.g.
+# mypackage.mymodule.MyReporterClass.
+output-format=text
+
+# Tells whether to display a full report or only the messages.
+reports=yes
+
+# Activate the evaluation score.
+score=yes
+
+
+[SIMILARITIES]
+
+# Comments are removed from the similarity computation
+ignore-comments=yes
+
+# Docstrings are removed from the similarity computation
+ignore-docstrings=yes
+
+# Imports are removed from the similarity computation
+ignore-imports=yes
+
+# Signatures are removed from the similarity computation
+ignore-signatures=yes
+
+# Minimum lines number of a similarity.
+min-similarity-lines=4
+
+
+[SPELLING]
+
+# Limits count of emitted suggestions for spelling mistakes.
+max-spelling-suggestions=4
+
+# Spelling dictionary name. Available dictionaries: none. To make it work,
+# install the 'python-enchant' package.
+spelling-dict=
+
+# List of comma separated words that should be considered directives if they
+# appear at the beginning of a comment and should not be checked.
+spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:
+
+# List of comma separated words that should not be checked.
+spelling-ignore-words=
+
+# A path to a file that contains the private dictionary; one word per line.
+spelling-private-dict-file=
+
+# Tells whether to store unknown words to the private dictionary (see the
+# --spelling-private-dict-file option) instead of raising a message.
+spelling-store-unknown-words=no
+
+
+[STRING]
+
+# This flag controls whether inconsistent-quotes generates a warning when the
+# character used as a quote delimiter is used inconsistently within a module.
+check-quote-consistency=no
+
+# This flag controls whether the implicit-str-concat should generate a warning
+# on implicit string concatenation in sequences defined over several lines.
+check-str-concat-over-line-jumps=no
+
+
+[TYPECHECK]
+
+# List of decorators that produce context managers, such as
+# contextlib.contextmanager. Add to this list to register other decorators that
+# produce valid context managers.
+contextmanager-decorators=contextlib.contextmanager
+
+# List of members which are set dynamically and missed by pylint inference
+# system, and so shouldn't trigger E1101 when accessed. Python regular
+# expressions are accepted.
+generated-members=
+
+# Tells whether to warn about missing members when the owner of the attribute
+# is inferred to be None.
+ignore-none=yes
+
+# This flag controls whether pylint should warn about no-member and similar
+# checks whenever an opaque object is returned when inferring. The inference
+# can return multiple potential results while evaluating a Python object, but
+# some branches might not be evaluated, which results in partial inference. In
+# that case, it might be useful to still emit no-member and other checks for
+# the rest of the inferred objects.
+ignore-on-opaque-inference=yes
+
+# List of symbolic message names to ignore for Mixin members.
+ignored-checks-for-mixins=no-member,
+ not-async-context-manager,
+ not-context-manager,
+ attribute-defined-outside-init
+
+# List of class names for which member attributes should not be checked (useful
+# for classes with dynamically set attributes). This supports the use of
+# qualified names.
+ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace
+
+# Show a hint with possible names when a member name was not found. The aspect
+# of finding the hint is based on edit distance.
+missing-member-hint=yes
+
+# The minimum edit distance a name should have in order to be considered a
+# similar match for a missing member name.
+missing-member-hint-distance=1
+
+# The total number of similar names that should be taken in consideration when
+# showing a hint for a missing member.
+missing-member-max-choices=1
+
+# Regex pattern to define which classes are considered mixins.
+mixin-class-rgx=.*[Mm]ixin
+
+# List of decorators that change the signature of a decorated function.
+signature-mutators=
+
+
+[VARIABLES]
+
+# List of additional names supposed to be defined in builtins. Remember that
+# you should avoid defining new builtins when possible.
+additional-builtins=
+
+# Tells whether unused global variables should be treated as a violation.
+allow-global-unused-variables=yes
+
+# List of names allowed to shadow builtins
+allowed-redefined-builtins=
+
+# List of strings which can identify a callback function by name. A callback
+# name must start or end with one of those strings.
+callbacks=cb_,
+ _cb
+
+# A regular expression matching the name of dummy variables (i.e. expected to
+# not be used).
+dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
+
+# Argument names that match this expression will be ignored.
+ignored-argument-names=_.*|^ignored_|^unused_
+
+# Tells whether we should check for unused import in __init__ files.
+init-import=no
+
+# List of qualified module names which can have objects that can redefine
+# builtins.
+redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md
new file mode 100644
index 00000000..a1777671
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/README.md
@@ -0,0 +1,46 @@
+# FMS Acceleration for Fused Operations and Kernels
+
+This library contains fused operations and custom kernels, to be expanded over time. Currently it contains the following:
+
+
+1. Fused operations and kernels extracted from [unsloth](#code-extracted-from-unsloth):
+    - Low-Rank Adapter Fused Operations (a minimal sketch of the underlying LoRA matmul follows this list)
+ - Fast RoPE Triton Kernels
+ - Fast RMS LayerNorm Triton Kernels
+ - Fast Cross Entropy Triton Kernels
+
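+The low-rank adapter fused operations all build on the same LoRA matmul, i.e. the frozen (dequantized) base matmul plus a scaled low-rank update. Below is a minimal PyTorch sketch of just that math, assuming `W` has already been dequantized and mirroring `matmul_lora` in the extracted code; it is illustrative only, not one of the kernels.
+
+```python
+import torch
+
+def lora_matmul_reference(X, W, A, B, s):
+    # reference for what the fused ops compute: base matmul plus the scaled
+    # low-rank update. Shapes follow peft's LoraLayer: A is (r, in), B is
+    # (out, r), s is the lora scaling. The extracted ops additionally handle
+    # dequantization of the quantized base weight and fuse the MLP/QKV
+    # forward and backward passes.
+    return X @ W + (X @ A.t()) @ (s * B.t())
+```
+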
+## Plugins
+
+Plugin | Description | Depends | Loading | Augmentation | Callbacks
+--|--|--|--|--|--
+[fast_quantized_peft](./src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py) | LoRA fused ops, fast cross-entropy, fast rms, fast RoPE | Contains extracted code | | ✅ |
+
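+For orientation, the plugin's `augmentation` step boils down to a single model-patching call (mirroring `framework_plugin_fast_quantized_peft.py` in this plugin). The sketch below is illustrative only; the wrapper function name is not part of the package:
+
+```python
+from fms_acceleration_foak.models.model_patcher import patch_model
+
+def apply_foak_patches(model, base_layer: str = "auto_gptq"):
+    # `model` is an already PEFT-wrapped model loaded in torch.float16;
+    # base_layer may be "auto_gptq" or "bitsandbytes", matching the
+    # plugin's base_layer configuration
+    return patch_model(model, base_type=base_layer)
+```
+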
+### Code Extracted from Unsloth
+
+
+Notes on the extraction of code from [unsloth](https://github.com/unslothai/unsloth):
+- While unsloth is [released under Apache 2.0](https://github.com/unslothai/unsloth/blob/main/LICENSE), there are comments indicating some exceptions strewn throughout the code base, see [an example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1140-L1143).
+ ```
+ it would require a commercial license if used to run on more than 4 GPUs ...
+ ```
+- These exceptions appear to be located around the trainer improvements, see [another example here](https://github.com/unslothai/unsloth/blob/ec19e61c854dcf9104386fa63fc6c4f2944d4f35/unsloth/models/llama.py#L1177-L1183).
+- These exceptions appear around [Feb 2024 Release](https://github.com/unslothai/unsloth/commit/3e4c5a323c16bbda2c92212b790073c4e99c2a55); any code that appears in any file where such exceptions occur **is not extracted**.
+- In its place, we adopt a model-patching approach, as opposed to unsloth's approach of rewriting the model; our patching code is novel and **completely written from scratch**.
+- All extracted code appears before the Feb 2024 Release.
+- In the table below we record what was extracted, and the exact commit from which it was taken.
+
+Path | Description | Extracted From | Modifications | Date
+--|--|--|--|--
+[fused_ops/unsloth_lora](./src/fms_acceleration_foak/fused_ops/unsloth_lora) | QLoRA fast dequant, activation kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | | 28 Jan 2024
+[fused_ops/unsloth_lora/bnb](./src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb) | BNB fast lora | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `fast_lora.py` | 28 Jan 2024
+[fused_ops/unsloth_lora/gptq](./src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq) | GPTQ fast dequant (triton_v2) | `jeromeku/main` @ [2839d39](https://github.com/jeromeku/unsloth/commit/2839d390ef3bb318904289bfb9a7751a782c4e44) | `fast_lora.py`<br>`triton/layers.py` | 6 Feb 2024
+[kernels/unsloth](./src/fms_acceleration_foak/kernels/unsloth) | Fast RMS, RoPE, CrossEnt kernels | `unsloth/main` @ [1ecc0185](https://github.com/unslothai/unsloth/commit/1ecc0185a5759c7a0c95dfc96aceea5023cebdfc) | `cross_entropy_loss.py`<br>`rms_layernorm.py` | 28 Jan 2024
+
+## Known Issues
+
+- Mixed precision `--fp16` should be used with `fast_lora`. Also consider loading the model in `torch.float16`. A minimal precision-setup sketch is shown after this list.
+- `fast_lora` has issues with the `peft` style of FSDP wrapping.
+    * This is because the adapter's forward functions are bypassed in the fused ops.
+    * For AutoGPTQ this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
+    * However, for QLoRA this is not yet done, see [issue #3](https://github.com/foundation-model-stack/fms-acceleration/issues/3).
+- `fast_rope_embeddings` does not work with `position_ids`; currently `position_ids` are ignored, which could give wrong results.
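+
+Below is a minimal sketch of the fp16 setup that `fast_lora` expects (the model id and output directory are placeholders):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, TrainingArguments
+
+# load the base model weights in fp16 ...
+model = AutoModelForCausalLM.from_pretrained(
+    "org/model-name",           # placeholder model id
+    torch_dtype=torch.float16,
+)
+
+# ... and train with fp16 mixed precision, matching the check performed in the
+# plugin's augmentation() step
+train_args = TrainingArguments(output_dir="output", fp16=True)
+```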
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml
new file mode 100644
index 00000000..2151beb3
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml
@@ -0,0 +1,27 @@
+# PEFT-related acceleration
+peft:
+
+  # quantization-related acceleration
+ # e.g., kernels for quantized base weights
+ quantization:
+
+ fused_ops_and_kernels:
+
+ # load unsloth optimizations for these 4bit base layer weights.
+    # currently only supports "auto_gptq" and "bitsandbytes"
+ base_layer: auto_gptq
+
+ # activate various unsloth optimizations
+ # NOTE: currently supports only all-or-nothing.
+
+ # fused kernels for lora linear layers
+ fused_lora: True
+
+ # fast loss triton kernels
+ fast_loss: True
+
+ # fast rms norm triton kernels
+ fast_rsm_layernorm: True
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: True
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/pyproject.toml b/plugins/fused-ops-and-kernels/pyproject.toml
new file mode 100644
index 00000000..2b2aef78
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/pyproject.toml
@@ -0,0 +1,31 @@
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[project]
+name = "fms-acceleration-foak"
+version = '0.0.1'
+description = "FMS Acceleration using Fused Operations and Kernels"
+authors = [
+ {name = "Fabian Lim", email = "flim@sg.ibm.com"},
+ {name = "Aaron Chew", email = "aaron.chew1@ibm.com"},
+]
+license = {text = "Apache-2.0"}
+readme = "README.md"
+requires-python = "~=3.9"
+keywords = ['fms-hf-tuning', 'acceleration', 'fused-ops', 'triton']
+classifiers=[
+ "License :: OSI Approved :: Apache Software License",
+ "Development Status :: 4 - Beta",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+]
+dependencies = ['pandas']
+
+[tool.hatch.build.targets.wheel]
+only-include = ["src/fms_acceleration_foak"]
+
+[tool.hatch.build.targets.wheel.sources]
+"src" = ""
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py
new file mode 100644
index 00000000..edf3f23d
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .framework_plugin_fast_quantized_peft import FastQuantizedPeftAccelerationPlugin
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
new file mode 100644
index 00000000..01a5b4b7
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_quantized_peft.py
@@ -0,0 +1,181 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from typing import Callable, Dict, Tuple
+
+# Third Party
+from accelerate.utils import set_module_tensor_to_device
+from fms_acceleration import AccelerationPlugin
+from peft import LoraConfig
+from peft.tuners.lora.layer import LoraLayer
+from transformers import TrainingArguments
+from transformers.utils import logging
+import torch
+import torch.distributed as dist
+
+# want to use the transformers logger, but it needs its level and handler set up manually
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+logger.setLevel(logging._get_default_logging_level())
+logger.addHandler(logging._default_handler)
+
+
+def log_patch_summary(
+ logging_func: Callable = None,
+):
+ if logging_func is None:
+ logging_func = print
+
+ # this is a guarded import, because the model rule registration
+ # does not need to be loaded unless patch_model is required
+ # Local
+ from .models.model_patcher import ( # pylint: disable=import-outside-toplevel
+ patch_model_summary,
+ )
+
+ for line in patch_model_summary().split("\n"):
+ logging_func(line)
+
+
+# consider moving this somewhere else later
+def lora_adapters_switch_ddp_from_fsdp(modules, fsdp_plugin):
+ """
+    This function exempts the lora adapter modules from FSDP sharding,
+    installs gradient hooks on the adapter parameters that all-reduce
+    (average) their accumulated gradients across devices, and moves the
+    adapters to GPU if they are not already there.
+ """
+
+ # NOTE: assuming lora has no bias
+ fsdp_plugin.ignored_modules = []
+ for mod in modules:
+ fsdp_plugin.ignored_modules.append(mod.lora_A)
+ fsdp_plugin.ignored_modules.append(mod.lora_B)
+
+ def _all_reduce_hook(grad):
+ if grad is not None:
+ grad = grad.contiguous()
+ dist.all_reduce(grad, op=dist.ReduceOp.AVG, group=None)
+ return grad
+
+ for mod in modules:
+ A = mod.lora_A.default
+ B = mod.lora_B.default
+
+ # install hooks on the adapters
+ A.weight.register_hook(_all_reduce_hook)
+ B.weight.register_hook(_all_reduce_hook)
+
+        # because these are ignored by FSDP (see ignored_modules above), we need
+        # to manually move them to gpu if they are not already there
+ if not A.weight.is_cuda:
+ set_module_tensor_to_device(A, "weight", "cuda")
+ if not B.weight.is_cuda:
+ set_module_tensor_to_device(B, "weight", "cuda")
+
+
+class FastQuantizedPeftAccelerationPlugin(AccelerationPlugin):
+
+ # NOTE: may remove this when we have generic model rules
+ restricted_model_archs = [
+ "MixtralForCausalLM",
+ "LlamaForCausalLM",
+ "MistralForCausalLM",
+ ]
+
+ def __init__(self, configurations: Dict[str, Dict]):
+ super().__init__(configurations)
+
+ self._base_layer = self._check_config_and_maybe_check_values(
+ key="peft.quantization.fused_ops_and_kernels.base_layer",
+ values=["auto_gptq", "bitsandbytes"],
+ )
+
+ # only support these at the moment
+ self._check_config_equal(
+ key="peft.quantization.fused_ops_and_kernels.fused_lora", value=True
+ )
+ self._check_config_equal(
+ key="peft.quantization.fused_ops_and_kernels.fast_loss", value=True
+ )
+ self._check_config_equal(
+ key="peft.quantization.fused_ops_and_kernels.fast_rsm_layernorm",
+ value=True,
+ )
+ self._check_config_equal(
+ key="peft.quantization.fused_ops_and_kernels.fast_rope_embeddings",
+ value=True,
+ )
+
+ @property
+ def requires_agumentation(self):
+ return True
+
+ def augmentation(
+ self,
+ model,
+ train_args: TrainingArguments,
+ modifiable_args: Tuple[LoraConfig],
+ ):
+ # NOTE: how do I check this now that the modifiable args are missing
+ # assert peft_config.lora_dropout == 0, \
+ # "Fused Attention requires lora_dropout argument to be set to 0"
+
+ # need to check why this is needed
+ assert (
+ model.dtype == torch.float16 and train_args.fp16
+ ), "need to run in fp16 mixed precision or load model in fp16"
+
+ # this is a guarded import, because the model rule registration
+ # does not need to be loaded unless patch_model is required
+ # Local
+ from .models.model_patcher import ( # pylint: disable=import-outside-toplevel
+ patch_model,
+ )
+
+ model = patch_model(model, base_type=self._base_layer)
+ return model, modifiable_args
+
+ def get_callbacks_and_ready_for_train(
+ self, model: torch.nn.Module = None, accelerator=None
+ ):
+
+        # if this is moved to framework, it can be handled the same way as
+ # log_initialization_message
+ # log the patch summary
+ if accelerator is not None and accelerator.is_main_process:
+ log_patch_summary(logging_func=logger.info)
+
+ callbacks = []
+ if (
+ accelerator is not None
+ and getattr(accelerator.state, "fsdp_plugin", None) is not None
+ ):
+            # This installs grad reduction hooks on adapters if FSDP is
+            # detected. Because of the incompatibility between FSDP and
+            # fused modules, adapters are not sharded - instead, the
+            # accumulated adapter gradients on each device are reduced
+            # (averaged) in these grad reduce hooks.
+            # This may be removed in the future if the incompatibility
+            # is resolved.
+ lora_adapters_switch_ddp_from_fsdp(
+ [mod for mod in model.modules() if isinstance(mod, LoraLayer)],
+ accelerator.state.fsdp_plugin,
+ )
+ return callbacks
+
+
+# register
+AccelerationPlugin.register_plugin(
+ FastQuantizedPeftAccelerationPlugin,
+ configuration_and_paths=["peft.quantization.fused_ops_and_kernels"],
+)
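+
+# NOTE: `configuration_and_paths` ties this plugin to the
+# `peft.quantization.fused_ops_and_kernels` section of the acceleration config;
+# a matching sample configuration is provided in configs/fast_quantized_peft.yaml.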
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py
new file mode 100644
index 00000000..b994759e
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py
new file mode 100644
index 00000000..a35f05f9
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
+from .geglu import (
+ geglu_exact_forward_kernel,
+ geglu_exact_backward_kernel,
+ geglu_approx_forward_kernel,
+ geglu_approx_backward_kernel,
+)
+from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py
new file mode 100644
index 00000000..a5c556b4
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+
+from .fast_lora import (
+ get_lora_parameters,
+ apply_lora_mlp_swiglu,
+ apply_lora_mlp_geglu_exact,
+ apply_lora_mlp_geglu_approx,
+ apply_lora_qkv,
+ apply_lora_o,
+)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
new file mode 100644
index 00000000..71d7070c
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
@@ -0,0 +1,403 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from ..utils import fast_dequantize, QUANT_STATE, get_lora_parameters, matmul_lora
+
+
+class LoRA_MLP(torch.autograd.Function):
+ """
+ ### LoRA weights
+ G = G + Ag @ Bg
+ U = U + Au @ Bu
+ W = W + Aw @ Bw
+
+ ### SwiGLU(X)
+ e = X @ G
+ f = e * sigmoid(e)
+ g = X @ U
+ h = f * g
+ i = h @ W
+
+ ### Backpropagation chain rule
+ See our blog post for more details
+
+ df = sigmoid(e) * (1 - f) + f
+ dC/dW = h.T @ dY
+ dC/dU = X.T @ (D @ W.T * f)
+ dC/dG = X.T @ (D @ W.T * df * g)
+
+ ### Down projection LoRA weights
+ dC/dAw = dC/dW @ B.T
+ dC/dBw = A.T @ dC/dW
+ dC/dAw = h.T @ dY @ B.T
+ dC/dBw = A.T @ h.T @ dY
+
+ ### Up projection LoRA weights
+ dC/dAu = X.T @ (D @ W.T * f) @ B.T
+ dC/dBu = A.T @ X.T @ (D @ W.T * f)
+
+ ### Gate projection LoRA weights
+ dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
+ dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
+
+ Don't forget to see our blog post for more details!
+ """
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(ctx, X : torch.Tensor,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ _forward_function, _backward_function,):
+ dtype = X.dtype
+
+ e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
+ g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
+ h = _forward_function(e, g)
+ i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
+
+ ctx.custom_saved_tensors = (
+ gateW, gateW_quant, gateS,
+ upW, upW_quant, upS,
+ downW, downW_quant, downS,
+ _backward_function,
+ )
+ ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
+ X, e, g)
+ return i
+ pass
+
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, dY : torch.Tensor):
+ gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
+ _backward_function = ctx.custom_saved_tensors
+ gateA, gateB, upA, upB, downA, downB, \
+ X, e, g = ctx.saved_tensors
+
+ gateA, gateB, upA, upB, downA, downB = \
+ gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
+
+ batch, seq_len, hd = X.shape
+ dY = dY.view(-1, dY.shape[-1])
+ X = X .view(-1, X .shape[-1])
+ e = e .view(-1, e .shape[-1])
+ g = g .view(-1, g .shape[-1])
+ dtype = X.dtype
+
+ DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
+ DW, e, g = _backward_function(DW, e, g)
+ h, df, de = DW, e, g
+
+ # Down projection LoRA weights
+ d_downA = h.t() @ (dY @ downB.t())
+ d_downB = (downA.t() @ h.t()) @ dY
+ d_downA *= downS
+ d_downB *= downS
+
+ # Up projection LoRA weights
+ d_upA = X.t() @ (df @ upB.t())
+ d_upB = (upA.t() @ X.t()) @ df
+ d_upA *= upS
+ d_upB *= upS
+
+ # Gate projection LoRA weights
+ d_gateA = X.t() @ (de @ gateB.t())
+ d_gateB = (gateA.t() @ X.t()) @ de
+ d_gateA *= gateS
+ d_gateB *= gateS
+
+ # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
+ # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
+ upW = fast_dequantize(upW.t(), upW_quant)
+ dX = torch.matmul(df, upW.t(), out = X)
+ del upW
+ dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
+
+ gateW = fast_dequantize(gateW.t(), gateW_quant)
+ dX += de @ gateW.t()
+ del gateW
+ dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
+
+ # gateW, gateW_quant, gateA, gateB, gateS,
+ # upW, upW_quant, upA, upB, upS,
+ # downW, downW_quant, downA, downB, downS,
+ return dX.view(batch, seq_len, hd), \
+ None, None, d_gateA.t(), d_gateB.t(), None, \
+ None, None, d_upA.t(), d_upB.t(), None, \
+ None, None, d_downA.t(), d_downB.t(), None, \
+ None, None, # _backward and _forward
+ pass
+pass
+
+
+from ..swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
+def apply_lora_mlp_swiglu(self, X):
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+ out = LoRA_MLP.apply(X,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,)
+ return out
+pass
+
+
+from ..geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
+def apply_lora_mlp_geglu_exact(self, X):
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+ out = LoRA_MLP.apply(X,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ geglu_exact_forward_kernel, geglu_exact_backward_kernel,)
+ return out
+pass
+
+
+from ..geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
+def apply_lora_mlp_geglu_approx(self, X):
+ gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+ upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
+ downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
+ out = LoRA_MLP.apply(X,
+ gateW, gateW_quant, gateA, gateB, gateS,
+ upW, upW_quant, upA, upB, upS,
+ downW, downW_quant, downA, downB, downS,
+ geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
+ return out
+pass
+
+
+class LoRA_QKV(torch.autograd.Function):
+ """
+ ### LoRA weights
+ Wq = Wq + Aq @ Bq
+ Wk = Wk + Ak @ Bk
+ Wv = Wv + Av @ Bv
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
+
+ ### Backpropagation chain rule
+ See our blogpost for more details.
+
+ dC/dWq = X.T @ D(Wq)
+ dC/dWk = X.T @ D(Wk)
+ dC/dWv = X.T @ D(Wv)
+ We then sum them all find dC/dX
+
+ ### Q projection LoRA weights
+ dC/dAq = X.T @ D(Wq) @ B.T
+ dC/dBq = A.T @ X.T @ D(Wq)
+
+ ### K projection LoRA weights
+ dC/dAk = X.T @ D(Wk) @ B.T
+ dC/dBk = A.T @ X.T @ D(Wk)
+
+ ### V projection LoRA weights
+ dC/dAv = X.T @ D(Wv) @ B.T
+ dC/dBv = A.T @ X.T @ D(Wv)
+ """
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(ctx, X : torch.Tensor,
+ QW, QW_quant, QA, QB, QS,
+ KW, KW_quant, KA, KB, KS,
+ VW, VW_quant, VA, VB, VS,):
+ dtype = X.dtype
+
+ Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
+ K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
+ V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
+
+ ctx.custom_saved_tensors = (
+ QW, QW_quant, QS,
+ KW, KW_quant, KS,
+ VW, VW_quant, VS,
+ )
+ ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
+ return Q, K, V
+ pass
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, dQ, dK, dV):
+ QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
+ ctx.custom_saved_tensors
+ X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
+
+ QA, QB, KA, KB, VA, VB = \
+ QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
+
+ batch, seq_len, hd = X.shape
+ dQ = dQ.view(-1, dQ.shape[-1])
+ dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
+ dV = dV.view(-1, dV.shape[-1])
+ X = X .view(-1, X .shape[-1])
+ dtype = X.dtype
+
+ ### Weight projection LoRA weights
+ # See our blogpost for more details.
+
+ # Q Projection
+ d_QA = X.t() @ (dQ @ QB.t())
+ d_QB = (QA.t() @ X.t()) @ dQ
+ d_QA *= QS
+ d_QB *= QS
+
+ # K Projection
+ d_KA = X.t() @ (dK @ KB.t())
+ d_KB = (KA.t() @ X.t()) @ dK
+ d_KA *= KS
+ d_KB *= KS
+
+ # V Projection
+ d_VA = X.t() @ (dV @ VB.t())
+ d_VB = (VA.t() @ X.t()) @ dV
+ d_VA *= VS
+ d_VB *= VS
+
+ # Combine derivatives to find dX
+ # dQ
+ QW = fast_dequantize(QW.t(), QW_quant)
+ dX = torch.matmul(dQ, QW.t(), out = X)
+ del QW
+ dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
+
+ # dK
+ KW = fast_dequantize(KW.t(), KW_quant)
+ dX += dK @ KW.t()
+ del KW
+ dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
+
+ # dV
+ VW = fast_dequantize(VW.t(), VW_quant)
+ dX += dV @ VW.t()
+ del VW
+ dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
+
+ # QW, QW_quant, QA, QB, QS,
+ # KW, KW_quant, KA, KB, KS,
+ # VW, VW_quant, VA, VB, VS,
+ return dX.view(batch, seq_len, hd), \
+ None, None, d_QA.t(), d_QB.t(), None, \
+ None, None, d_KA.t(), d_KB.t(), None, \
+ None, None, d_VA.t(), d_VB.t(), None
+ pass
+pass
+
+
+def apply_lora_qkv(self, X):
+ QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
+ KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
+ VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
+ Q, K, V = LoRA_QKV.apply(X,
+ QW, QW_quant, QA, QB, QS,
+ KW, KW_quant, KA, KB, KS,
+ VW, VW_quant, VA, VB, VS,
+ )
+ return Q, K, V
+pass
+
+
+class LoRA_W(torch.autograd.Function):
+ """
+ ### LoRA weights
+ Wq = Wq + Aq @ Bq
+ Wk = Wk + Ak @ Bk
+ Wv = Wv + Av @ Bv
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
+
+ ### Backpropagation chain rule
+ dC/dWq = X.T @ D(Wq)
+ dC/dWk = X.T @ D(Wk)
+ dC/dWv = X.T @ D(Wv)
+
+ ### Q projection LoRA weights
+ dC/dAq = X.T @ D(Wq) @ B.T
+ dC/dBq = A.T @ X.T @ D(Wq)
+
+ ### K projection LoRA weights
+ dC/dAk = X.T @ D(Wk) @ B.T
+ dC/dBk = A.T @ X.T @ D(Wk)
+
+ ### V projection LoRA weights
+ dC/dAv = X.T @ D(Wv) @ B.T
+ dC/dBv = A.T @ X.T @ D(Wv)
+ """
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(ctx, X : torch.Tensor,
+ W, W_quant, A, B, S):
+ dtype = X.dtype
+ XW = matmul_lora(X, W, W_quant, A, B, S)
+ ctx.custom_saved_tensors = (W, W_quant, S,)
+ ctx.save_for_backward(A, B, X)
+ return XW
+ pass
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, dY : torch.Tensor):
+ W, W_quant, S = ctx.custom_saved_tensors
+ A, B, X = ctx.saved_tensors
+
+ A, B = A.t(), B.t()
+
+ batch, seq_len, hd = X.shape
+ dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
+ X = X .reshape(-1, X .shape[-1]) # Must be reshape
+ dtype = X.dtype
+
+ ### Weight projection LoRA weights
+ # Weight projection
+ d_A = X.t() @ (dY @ B.t())
+ d_B = (A.t() @ X.t()) @ dY
+ d_A *= S
+ d_B *= S
+
+ # Get derivative for dX
+ W = fast_dequantize(W.t(), W_quant)
+ dX = dY @ W.t()
+ del W
+ dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
+
+ # W, W_quant, A, B, S
+ return dX.view(batch, seq_len, hd), \
+ None, None, d_A.t(), d_B.t(), None
+ pass
+pass
+
+
+def apply_lora_o(self, X):
+ OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
+ O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
+ return O
+pass
+
+# added by flim@sg.ibm.com
+# unlike apply_lora_o, this receives the lora linear module itself, so it can be patched directly onto that module
+def apply_lora_o_v2(self, X):
+ OW, OW_quant, OA, OB, OS = get_lora_parameters(self)
+ O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
+ return O
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py
new file mode 100644
index 00000000..3441c59d
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/geglu.py
@@ -0,0 +1,202 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+import torch
+
+
+@triton.jit
+def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
+ block_idx = tl.program_id(0)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
+ # h = f * up
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
+ h_row = f_row * g_row
+
+ # Store h
+ tl.store(h + offsets, h_row, mask = mask)
+pass
+
+
+def geglu_exact_forward_kernel(gate, up):
+ batch, seq_len, hd = gate.shape
+ n_elements = gate.numel()
+ out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
+ return out
+pass
+
+
+@triton.jit
+def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
+ """
+ f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
+ h = f * up
+
+ df/de (with help of Wolfram :)
+ df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
+
+ Reuse via
+ f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
+ """
+ block_idx = tl.program_id(0)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ # Break e_row away for re-use
+ # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
+ f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
+ f_row = f_partial_row * e_row
+
+ f_row = f_row.to(DW_row.dtype)
+ # h = f * g
+ h_row = f_row * g_row
+ # df = DW * f
+ df_row = DW_row * f_row
+ # dg = DW * g
+ dg_row = DW_row * g_row
+
+ # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
+ t = 0.3989422804014327 # 1/sqrt(2*pi)
+ df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
+
+ de_row = dg_row.to(tl.float32) * df_de
+ de_row = de_row.to(DW_row.dtype)
+
+ # Store derivatives in buffers
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
+pass
+
+
+def geglu_exact_backward_kernel(DW, e, g):
+ batch_seq_len, hd = e.shape
+ n_elements = e.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+ return DW, e, g
+pass
+
+
+@triton.jit
+def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
+ block_idx = tl.program_id(0)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
+ # h = f * up
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
+
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ f_row = 0.5 * e_row * (
+ tl.math.tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
+ + 1.0
+ )
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
+ h_row = f_row * g_row
+
+ # Store h
+ tl.store(h + offsets, h_row, mask = mask)
+pass
+
+
+def geglu_approx_forward_kernel(gate, up):
+ batch, seq_len, hd = gate.shape
+ n_elements = gate.numel()
+ out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda")
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
+ return out
+pass
+
+
+@triton.jit
+def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
+ """
+ f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
+ h = f * up
+
+ df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
+ df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
+ 1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
+ ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
+
+ Notice sech^2(x) = 1 - tanh^2(x)
+ So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
+
+ See https://www.desmos.com/calculator/nqprfoni6x
+ """
+ block_idx = tl.program_id(0)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ # See https://www.desmos.com/calculator/nqprfoni6x
+ s = 0.7978845608028654 # math.sqrt(2 / math.pi)
+ a = s * e_row # a = sqrt(2 / pi) * x
+ b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
+ T = 1.0 + tl.math.tanh(a + b)
+ T2 = 0.5 * T
+ # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
+ Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
+ df_de = T2 + Q2 # 1/2 * (T + Q)
+
+ # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
+ f_row = T2 * e_row
+ f_row = f_row.to(DW_row.dtype)
+ # h = f * g
+ h_row = f_row * g_row
+ # df = DW * f
+ df_row = DW_row * f_row
+ # dg = DW * g
+ dg_row = DW_row * g_row
+
+ de_row = dg_row.to(tl.float32) * df_de
+ de_row = de_row.to(DW_row.dtype)
+
+ # Store derivatives in buffers
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
+pass
+
+
+def geglu_approx_backward_kernel(DW, e, g):
+ batch_seq_len, hd = e.shape
+ n_elements = e.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+ return DW, e, g
+pass
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py
new file mode 100644
index 00000000..b9b793a0
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/__init__.py
@@ -0,0 +1,3 @@
+# taken from
+# https://github.com/jeromeku/unsloth/commit/
+# 2839d390ef3bb318904289bfb9a7751a782c4e44
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
new file mode 100644
index 00000000..ee5055ed
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
@@ -0,0 +1,744 @@
+# taken from
+# https://github.com/jeromeku/unsloth/commit/
+# 2839d390ef3bb318904289bfb9a7751a782c4e44
+
+import math
+from dataclasses import dataclass
+from logging import getLogger
+from typing import Optional
+
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+from .triton.kernels import dequant248
+from ..swiglu import swiglu_DWf_DW_dfg_kernel, swiglu_fg_kernel
+
+logger = getLogger(__name__)
+
+
+@dataclass
+class GPTQuantState:
+ """
+ Stores params for GPTQ linear layer quantization
+ """
+
+ infeatures: int
+ outfeatures: int
+
+ bits: int
+ group_size: int
+ maxq: int
+ qweight: torch.Tensor
+ qzeros: torch.Tensor
+ scales: torch.Tensor
+ g_idx: torch.Tensor
+
+ # cuda_kernel params (not used currently)
+ kernel_switch_threshold: int
+ autogptq_cuda_available: bool = False
+ autogptq_cuda: bool = False
+
+ wf: Optional[torch.Tensor] = None
+ use_cuda_fp16: bool = False
+
+ bias: Optional[torch.Tensor] = None
+ trainable: bool = True
+
+
+def unpack_gptqstate(qstate):
+ qweight, scales, qzeros, g_idx, bits = (
+ qstate.qweight,
+ qstate.scales,
+ qstate.qzeros,
+ qstate.g_idx,
+ qstate.bits,
+ )
+ return qweight, scales, qzeros, g_idx, bits
+
+
+def extract_gptq_state(qmodule):
+ if hasattr(qmodule, "base_layer"):
+ qmodule = qmodule.base_layer
+
+ def check_bias(qmodule):
+ if hasattr(qmodule, "bias") and qmodule.bias is not None:
+ if qmodule.bias.count_nonzero() > 0:
+ return qmodule.bias
+ return None
+
+ return GPTQuantState(
+ infeatures=qmodule.infeatures,
+ outfeatures=qmodule.outfeatures,
+ bits=qmodule.bits,
+ group_size=qmodule.group_size,
+ maxq=qmodule.maxq,
+ qweight=qmodule.qweight.cuda(),
+ qzeros=qmodule.qzeros.cuda(),
+ scales=qmodule.scales.cuda(),
+ g_idx=qmodule.g_idx.cuda(),
+ bias=check_bias(qmodule),
+ wf=qmodule.wf.cuda() if hasattr(qmodule, "wf") else None,
+ kernel_switch_threshold=(
+ qmodule.kernel_switch_threshold
+ if hasattr(qmodule, "kernel_switch_threshold")
+ else None
+ ),
+ autogptq_cuda_available=( # fixed by @aaron.chew1@sg.ibm.com
+ qmodule.autogptq_cuda_available
+ if hasattr(qmodule, "autogptq_cuda_available") else False
+ ),
+ # use_cuda_fp16=qmodule.use_cuda_fp16,
+ )
+
+
+def get_lora_parameters(proj):
+ # For DPO or disabled adapters
+ base_layer = proj.base_layer if hasattr(proj, "base_layer") else proj
+ qstate = extract_gptq_state(base_layer)
+
+ if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+ return qstate, None, None, None
+
+ active_adapter = (
+ proj.active_adapters[0]
+ if hasattr(proj, "active_adapters")
+ else proj.active_adapter
+ )
+ A = proj.lora_A[active_adapter].weight
+ B = proj.lora_B[active_adapter].weight
+ s = proj.scaling[active_adapter]
+ return qstate, A, B, s
+
+
+def matmul_lora_canonicalized(X, W, A, B, s):
+ """
+ X: rank-2 tensor (batch, seq_len) x (din)
+ W: rank-2 tensor (din, dout)
+ out: rank-2 tensor (batch, seq_len) x (dout)
+ din = X.shape[1]
+ dout = W.shape[1]
+ """
+
+ out = torch.matmul(X, W)
+
+ A, B = A.t(), B.t()
+ out += (X @ A) @ (s * B)
+
+ return out
+
+
+def matmul_lora(X, W, A, B, s, out=None):
+ dtype = X.dtype
+
+ if X.dim() == 3:
+ batch, seq_len, d = X.shape
+ X = X.view(-1, X.shape[-1])
+ reshape = True
+ else:
+ reshape = False
+
+ out = torch.matmul(X, W, out=out)
+
+ if A is not None:
+ # LoRA is enabled
+ A, B = A.t(), B.t()
+ out += (X @ A.to(dtype)) @ (s * B.to(dtype))
+
+ return out.view(batch, seq_len, -1) if reshape else out
+
+
+class LoRA_MLP(torch.autograd.Function):
+ """
+ ### LoRA weights
+ G = G + Ag @ Bg
+ U = U + Au @ Bu
+ W = W + Aw @ Bw
+
+ ### SwiGLU(X)
+ e = X @ G
+ f = e * sigmoid(e)
+ g = X @ U
+ h = f * g
+ i = h @ W
+
+ ### Backpropagation chain rule
+ See our blog post for more details
+
+ df = sigmoid(e) * (1 - f) + f
+ dC/dW = h.T @ dY
+ dC/dU = X.T @ (D @ W.T * f)
+ dC/dG = X.T @ (D @ W.T * df * g)
+
+ ### Down projection LoRA weights
+ dC/dAw = dC/dW @ B.T
+ dC/dBw = A.T @ dC/dW
+ dC/dAw = h.T @ dY @ B.T
+ dC/dBw = A.T @ h.T @ dY
+
+ ### Up projection LoRA weights
+ dC/dAu = X.T @ (D @ W.T * f) @ B.T
+ dC/dBu = A.T @ X.T @ (D @ W.T * f)
+
+ ### Gate projection LoRA weights
+ dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
+ dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
+
+ Don't forget to see our blog post for more details!
+ """
+
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(
+ ctx,
+ X: torch.Tensor,
+ gate_qweight,
+ gate_scales,
+ gate_qzeros,
+ gate_g_idx,
+ gate_bits,
+ gateA,
+ gateB,
+ gateS,
+ up_qweight,
+ up_scales,
+ up_qzeros,
+ up_g_idx,
+ up_bits,
+ upA,
+ upB,
+ upS,
+ down_qweight,
+ down_scales,
+ down_qzeros,
+ down_g_idx,
+ down_bits,
+ downA,
+ downB,
+ downS,
+ ):
+ dtype = X.dtype
+
+ # Separate dequant248 from matmul
+ gateW = dequant248(
+ gate_qweight, gate_scales, gate_qzeros, gate_g_idx, gate_bits
+ )
+ e = matmul_lora(X, gateW, gateA, gateB, gateS)
+ upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
+ g = matmul_lora(X, upW, upA, upB, upS)
+ # f = torch.nn.functional.silu(e)
+ # h = f * g
+ h = swiglu_fg_kernel(e, g)
+
+ downW = dequant248(
+ down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
+ )
+ i = matmul_lora(h, downW, downA, downB, downS)
+
+ ctx.custom_saved_tensors = (
+ gate_qweight,
+ gate_scales,
+ gate_qzeros,
+ gate_g_idx,
+ gate_bits,
+ gateS,
+ up_qweight,
+ up_scales,
+ up_qzeros,
+ up_g_idx,
+ up_bits,
+ upS,
+ down_qweight,
+ down_scales,
+ down_qzeros,
+ down_g_idx,
+ down_bits,
+ downS,
+ )
+ ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB, X, e, g)
+ return i
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, dY: torch.Tensor):
+ (
+ gate_qweight,
+ gate_scales,
+ gate_qzeros,
+ gate_g_idx,
+ gate_bits,
+ gateS,
+ up_qweight,
+ up_scales,
+ up_qzeros,
+ up_g_idx,
+ up_bits,
+ upS,
+ down_qweight,
+ down_scales,
+ down_qzeros,
+ down_g_idx,
+ down_bits,
+ downS,
+ ) = ctx.custom_saved_tensors
+ gateA, gateB, upA, upB, downA, downB, X, e, g = ctx.saved_tensors
+
+ gateA, gateB, upA, upB, downA, downB = (
+ gateA.t(),
+ gateB.t(),
+ upA.t(),
+ upB.t(),
+ downA.t(),
+ downB.t(),
+ )
+
+ batch, seq_len, hd = X.shape
+ dY = dY.view(-1, dY.shape[-1])
+ X = X.view(-1, X.shape[-1])
+ e = e.view(-1, e.shape[-1])
+ g = g.view(-1, g.shape[-1])
+ dtype = X.dtype
+
+ downW = dequant248(
+ down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
+ )
+ DW = matmul_lora(dY, downW.t(), downB, downA, downS)
+ # e = e.float()
+ # se = 1.0 / (1.0 + torch.exp(-e))
+ # f = (se * e).to(dtype)
+ # h = f * g
+ # df = DW * f
+ # dg = DW * g
+ # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
+ DW, e, g = swiglu_DWf_DW_dfg_kernel(DW, e, g)
+ h, df, de = DW, e, g
+
+ # Down projection LoRA weights
+ d_downA = h.t() @ (dY @ downB.t())
+ d_downB = (downA.t() @ h.t()) @ dY
+ d_downA *= downS
+ d_downB *= downS
+
+ # Up projection LoRA weights
+ d_upA = X.t() @ (df @ upB.t())
+ d_upB = (upA.t() @ X.t()) @ df
+ d_upA *= upS
+ d_upB *= upS
+
+ # Gate projection LoRA weights
+ d_gateA = X.t() @ (de @ gateB.t())
+ d_gateB = (gateA.t() @ X.t()) @ de
+ d_gateA *= gateS
+ d_gateB *= gateS
+
+ # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
+ # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
+ upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
+ dX = torch.matmul(df, upW.t()) # , out=X)
+ del upW
+ dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
+
+ gateW = dequant248(
+ gate_qweight, gate_scales, gate_qzeros, gate_g_idx, gate_bits
+ )
+ dX += de @ gateW.t()
+ del gateW
+ dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
+
+ # qweight, scales, qzeros, g_idx, bits
+ # upW, upW_quant, upA, upB, upS,
+ # downW, downW_quant, downA, downB, downS,
+ return (
+ dX.view(batch, seq_len, hd),
+ None, # qweight
+ None, # scales
+ None, # qzeros
+ None, # g_idx
+ None, # bits
+ d_gateA.t(),
+ d_gateB.t(),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ d_upA.t(),
+ d_upB.t(),
+ None, # dS
+ None,
+ None,
+ None,
+ None,
+ None,
+ d_downA.t(),
+ d_downB.t(),
+ None,
+ )
+
+
+def apply_lora_mlp(self, X):
+ gateQstate, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
+ upQState, upA, upB, upS = get_lora_parameters(self.up_proj)
+ downQState, downA, downB, downS = get_lora_parameters(self.down_proj)
+ out = LoRA_MLP.apply(
+ X,
+ *unpack_gptqstate(gateQstate),
+ gateA,
+ gateB,
+ gateS,
+ *unpack_gptqstate(upQState),
+ upA,
+ upB,
+ upS,
+ *unpack_gptqstate(downQState),
+ downA,
+ downB,
+ downS,
+ )
+ return out
+
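+# A minimal usage sketch (hypothetical module/attribute names): the fused MLP path is
+# meant to be bound as the forward of a GPTQ-quantized LoRA MLP module that exposes
+# `gate_proj`, `up_proj` and `down_proj`, e.g.
+#   import types
+#   mlp_module.forward = types.MethodType(apply_lora_mlp, mlp_module)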
+
+class LoRA_QKV(torch.autograd.Function):
+ """
+ ### LoRA weights
+ Wq = Wq + Aq @ Bq
+ Wk = Wk + Ak @ Bk
+ Wv = Wv + Av @ Bv
+ Q = X @ Wq = X @ Wq + X @ Aq @ Bq
+ K = X @ Wk = X @ Wk + X @ Ak @ Bk
+ V = X @ Wv = X @ Wv + X @ Av @ Bv
+
+ ### Backpropagation chain rule
+ See our blogpost for more details.
+
+ dC/dWq = X.T @ D(Wq)
+ dC/dWk = X.T @ D(Wk)
+ dC/dWv = X.T @ D(Wv)
+    We then sum them all to find dC/dX
+
+ ### Q projection LoRA weights
+ dC/dAq = X.T @ D(Wq) @ B.T
+ dC/dBq = A.T @ X.T @ D(Wq)
+
+ ### K projection LoRA weights
+ dC/dAk = X.T @ D(Wk) @ B.T
+ dC/dBk = A.T @ X.T @ D(Wk)
+
+ ### V projection LoRA weights
+ dC/dAv = X.T @ D(Wv) @ B.T
+ dC/dBv = A.T @ X.T @ D(Wv)
+ """
+
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(
+ ctx,
+ X: torch.Tensor,
+ Q_qweight,
+ Q_scales,
+ Q_qzeros,
+ Q_g_idx,
+ Q_bits,
+ QA,
+ QB,
+ QS,
+ K_qweight,
+ K_scales,
+ K_qzeros,
+ K_g_idx,
+ K_bits,
+ KA,
+ KB,
+ KS,
+ V_qweight,
+ V_scales,
+ V_qzeros,
+ V_g_idx,
+ V_bits,
+ VA,
+ VB,
+ VS,
+ ):
+ dtype = X.dtype
+
+ QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)
+ KW = dequant248(K_qweight, K_scales, K_qzeros, K_g_idx, K_bits)
+ VW = dequant248(V_qweight, V_scales, V_qzeros, V_g_idx, V_bits)
+ Q = matmul_lora(X, QW, QA, QB, QS)
+ K = matmul_lora(X, KW, KA, KB, KS)
+ V = matmul_lora(X, VW, VA, VB, VS)
+
+ ctx.custom_saved_tensors = (
+ Q_qweight,
+ Q_scales,
+ Q_qzeros,
+ Q_g_idx,
+ Q_bits,
+ QS,
+ K_qweight,
+ K_scales,
+ K_qzeros,
+ K_g_idx,
+ K_bits,
+ KS,
+ V_qweight,
+ V_scales,
+ V_qzeros,
+ V_g_idx,
+ V_bits,
+ VS,
+ )
+ ctx.save_for_backward(
+ X,
+ QA,
+ QB,
+ KA,
+ KB,
+ VA,
+ VB,
+ )
+ return Q, K, V
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, dQ, dK, dV):
+ (
+ Q_qweight,
+ Q_scales,
+ Q_qzeros,
+ Q_g_idx,
+ Q_bits,
+ QS,
+ K_qweight,
+ K_scales,
+ K_qzeros,
+ K_g_idx,
+ K_bits,
+ KS,
+ V_qweight,
+ V_scales,
+ V_qzeros,
+ V_g_idx,
+ V_bits,
+ VS,
+ ) = ctx.custom_saved_tensors
+ (
+ X,
+ QA,
+ QB,
+ KA,
+ KB,
+ VA,
+ VB,
+ ) = ctx.saved_tensors
+
+ QA, QB, KA, KB, VA, VB = QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
+
+ batch, seq_len, hd = X.shape
+ dQ = dQ.view(-1, dQ.shape[-1])
+ dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
+ dV = dV.view(-1, dV.shape[-1])
+ X = X.view(-1, X.shape[-1])
+ dtype = X.dtype
+
+ ### Weight projection LoRA weights
+ # See our blogpost for more details.
+
+ # Q Projection
+ d_QA = X.t() @ (dQ @ QB.t())
+ d_QB = (QA.t() @ X.t()) @ dQ
+ d_QA *= QS
+ d_QB *= QS
+
+ # K Projection
+ d_KA = X.t() @ (dK @ KB.t())
+ d_KB = (KA.t() @ X.t()) @ dK
+ d_KA *= KS
+ d_KB *= KS
+
+ # V Projection
+ d_VA = X.t() @ (dV @ VB.t())
+ d_VB = (VA.t() @ X.t()) @ dV
+ d_VA *= VS
+ d_VB *= VS
+
+ # Combine derivatives to find dX
+ # dQ
+ QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)
+ dX = torch.matmul(dQ, QW.t()) # , out=X)
+ del QW
+ dX += dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())
+
+ # dK
+ KW = dequant248(K_qweight, K_scales, K_qzeros, K_g_idx, K_bits)
+ dX += dK @ KW.t()
+ del KW
+ dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
+
+ # dV
+ VW = dequant248(V_qweight, V_scales, V_qzeros, V_g_idx, V_bits)
+ dX += dV @ VW.t()
+ del VW
+ dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
+
+ # Q_qweight, Q_scales, Q_qzeros, Q_wf, Q_g_idx, Q_bits, QA, QB, QS,
+ # K_qweight, K_scales, K_qzeros, K_wf, K_g_idx, K_bits, KA, KB, KS,
+ # V_qweight, V_scales, V_qzeros, V_wf, V_g_idx, V_bits, VA, VB, VS,
+ return (
+ dX.view(batch, seq_len, hd),
+ None,
+ None,
+ None,
+ None,
+ None,
+ d_QA.t(),
+ d_QB.t(),
+ None, # d_QS.t(),
+ None,
+ None,
+ None,
+ None,
+ None,
+ d_KA.t(),
+ d_KB.t(),
+ None, # d_KS.t(),
+ None,
+ None,
+ None,
+ None,
+ None,
+ d_VA.t(),
+ d_VB.t(),
+ None,
+ )
+
+
+def apply_lora_qkv(self, X):
+ Qqstate, QA, QB, QS = get_lora_parameters(self.q_proj)
+ Kqstate, KA, KB, KS = get_lora_parameters(self.k_proj)
+ Vqstate, VA, VB, VS = get_lora_parameters(self.v_proj)
+ Q, K, V = LoRA_QKV.apply(
+ X,
+ *unpack_gptqstate(Qqstate),
+ QA,
+ QB,
+ QS,
+ *unpack_gptqstate(Kqstate),
+ KA,
+ KB,
+ KS,
+ *unpack_gptqstate(Vqstate),
+ VA,
+ VB,
+ VS,
+ )
+ return Q, K, V
+
+
+class LoRA_W(torch.autograd.Function):
+ """
+    ### LoRA weights (single projection)
+    W = W + A @ B
+    Y = X @ W = X @ W + X @ A @ B
+
+    ### Backpropagation chain rule
+    dC/dW = X.T @ D(W)
+
+    ### Projection LoRA weights
+    dC/dA = X.T @ D(W) @ B.T
+    dC/dB = A.T @ X.T @ D(W)
+ """
+
+ @staticmethod
+ @torch.cuda.amp.custom_fwd
+ def forward(
+ ctx,
+ X: torch.Tensor,
+ O_qweight,
+ O_scales,
+ O_qzeros,
+ O_g_idx,
+ O_bits,
+ A,
+ B,
+ S,
+ ):
+ W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
+ XW = matmul_lora(X, W, A, B, S)
+ del W
+ ctx.custom_saved_tensors = (
+ O_qweight,
+ O_scales,
+ O_qzeros,
+ O_g_idx,
+ O_bits,
+ S,
+ )
+ ctx.save_for_backward(A, B, X)
+ return XW
+
+ @staticmethod
+ @torch.cuda.amp.custom_bwd
+ def backward(ctx, dY: torch.Tensor):
+ O_qweight, O_scales, O_qzeros, O_g_idx, O_bits, S = ctx.custom_saved_tensors
+ A, B, X = ctx.saved_tensors
+
+ A, B = A.t(), B.t()
+
+ batch, seq_len, hd = X.shape
+ dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
+ X = X.reshape(-1, X.shape[-1]) # Must be reshape
+ dtype = X.dtype
+
+ ### Weight projection LoRA weights
+ # Weight projection
+ d_A = X.t() @ (dY @ B.t())
+ d_B = (A.t() @ X.t()) @ dY
+ d_A *= S
+ d_B *= S
+
+ # Get derivative for dX
+ W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
+ dX = dY @ W.t()
+ del W
+ dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
+
+ # O_qweight, O_scales, O_qzeros, O_wf, O_g_idx, O_bits, A, B, S
+ return (
+ dX.view(batch, seq_len, hd),
+ None,
+ None,
+ None,
+ None,
+ None,
+ d_A.t(),
+ d_B.t(),
+ None,
+ )
+
+
+def apply_lora_o(self, X):
+ Oqstate, OA, OB, OS = get_lora_parameters(self.o_proj)
+ O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
+ return O
+
+# added by flim@sg.ibm.com
+# this version can be directly patched on the output linear
+def apply_lora_o_v2(self, X):
+ Oqstate, OA, OB, OS = get_lora_parameters(self)
+ O = LoRA_W.apply(X, *unpack_gptqstate(Oqstate), OA, OB, OS)
+ return O
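+
+# Usage sketch (illustrative names only): since this variant reads the LoRA parameters from
+# `self`, it can be bound directly onto the quantized output projection, e.g.
+#   import types
+#   attn.o_proj.forward = types.MethodType(apply_lora_o_v2, attn.o_proj)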
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py
new file mode 100644
index 00000000..b9b793a0
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/__init__.py
@@ -0,0 +1,3 @@
+# taken from
+# https://github.com/jeromeku/unsloth/commit/
+# 2839d390ef3bb318904289bfb9a7751a782c4e44
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py
new file mode 100644
index 00000000..c252d26d
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/kernels.py
@@ -0,0 +1,149 @@
+# taken from
+# https://github.com/jeromeku/unsloth/commit/
+# 2839d390ef3bb318904289bfb9a7751a782c4e44
+
+import itertools
+from logging import getLogger
+
+import torch
+import triton
+import triton.language as tl
+
+logger = getLogger(__name__)
+
+
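+# dequant_ref is a plain PyTorch (unfused) GPTQ dequantization path; it can be used as a
+# reference to sanity-check the Triton dequant_kernel_248 output on small inputs. It assumes
+# `qstate` carries the packed tensors (qweight, scales, qzeros, wf, g_idx, bits).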
+def dequant_ref(qstate):
+ # assert bits == 4, "Only 4-bit quantization is supported"
+ qweight, scales, qzeros, wf, g_idx, bits = (
+ qstate.qweight,
+ qstate.scales,
+ qstate.qzeros,
+ qstate.wf,
+ qstate.g_idx,
+ qstate.bits,
+ )
+
+ zeros = torch.bitwise_right_shift(
+ torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits), wf.unsqueeze(0)
+ ).to(torch.int16 if bits == 8 else torch.int8)
+ zeros = torch.bitwise_and(zeros, (2**bits) - 1)
+
+ zeros = zeros + 1
+ zeros = zeros.reshape(scales.shape)
+
+ weights = torch.bitwise_right_shift(
+ torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1), wf.unsqueeze(-1)
+ ).to(torch.int16 if bits == 8 else torch.int8)
+ weights = torch.bitwise_and(weights, (2**bits) - 1)
+ weights = weights.reshape(weights.shape[0] * weights.shape[1], weights.shape[2])
+ weights = scales[g_idx] * (weights - zeros[g_idx])
+ return weights
+
+
+def make_dequant_configs(block_sizes, num_warps):
+ configs = []
+ for bs, ws in itertools.product(block_sizes, num_warps):
+ configs.append(triton.Config({"X_BLOCK": bs}, num_warps=ws))
+ return configs
+
+
+DEFAULT_DEQUANT_CONFIGS = make_dequant_configs([128, 256, 512, 1024], [4, 8])
+
+
+@triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"])
+@triton.jit
+def dequant_kernel_248(
+ g_idx_ptr,
+ scales_ptr,
+ qweight_ptr,
+ qzeros_ptr,
+ out_ptr,
+ numels,
+ maxq: tl.constexpr,
+ bits: tl.constexpr,
+ outfeatures: tl.constexpr,
+ num_groups: tl.constexpr,
+ X_BLOCK: tl.constexpr = 1024,
+):
+ # Block indexing
+ xoffset = tl.program_id(0) * X_BLOCK
+ x_index = xoffset + tl.arange(0, X_BLOCK)
+ xmask = x_index < numels
+ row_idx = x_index // outfeatures
+ col_idx = x_index % outfeatures
+
+ elements_per_feature: tl.constexpr = 32 // bits
+
+ # Load parameters
+ g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last")
+ qweights = tl.load(
+ qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),
+ None,
+ )
+
+ wf_weights = (row_idx % elements_per_feature) * bits
+
+ wf_zeros = (col_idx % elements_per_feature) * bits
+
+ tmp1 = g_idx + num_groups
+ tmp2 = g_idx < 0
+    tl.device_assert(g_idx >= 0, "g_idx contains negative (out-of-bounds) indices")
+ groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx
+
+ scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(
+ tl.float32
+ )
+
+ # Unpack weights
+ weights = qweights >> wf_weights # bit shift qweight
+
+ weights = weights & maxq
+
+ # Unpack zeros
+ qzero_ncols: tl.constexpr = outfeatures // elements_per_feature
+ qzeros = tl.load(
+ qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),
+ None,
+ eviction_policy="evict_last",
+ )
+ zeros = qzeros >> wf_zeros
+ zeros = zeros & maxq
+
+ # Dequantize
+ zeros = zeros + 1
+ weights = weights - zeros
+ weights = weights.to(tl.float32)
+ weights = scales * weights
+
+ tl.store(out_ptr + (x_index), weights, mask=xmask)
+
+
+def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):
+ """Launcher for triton dequant kernel
+ Only valid for bits = 2, 4, 8
+
+ """
+
+ assert bits in [2, 4, 8], "Only 2, 4, 8-bit GPTQ quantization is supported"
+ num_groups = scales.shape[0]
+ outfeatures = scales.shape[1]
+ infeatures = g_idx.shape[0]
+
+ out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16)
+ numels = out.numel()
+ maxq = 2**bits - 1 if maxq is None else maxq
+ grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)
+
+ dequant_kernel_248[grid](
+ g_idx,
+ scales,
+ qweight,
+ qzeros,
+ out,
+ numels,
+ maxq=maxq,
+ bits=bits,
+ outfeatures=outfeatures,
+ num_groups=num_groups,
+ )
+ return out
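+
+# Minimal usage sketch (assumes `qlinear` is an AutoGPTQ-style QuantLinear already on CUDA
+# with the packed layout expected above):
+#   W = dequant248(qlinear.qweight, qlinear.scales, qlinear.qzeros, qlinear.g_idx, qlinear.bits)
+#   Y = X.half() @ W  # W is materialized as float16 with shape (infeatures, outfeatures)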
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py
new file mode 100644
index 00000000..d8ed096c
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/layers.py
@@ -0,0 +1,170 @@
+# taken from
+# https://github.com/jeromeku/unsloth/commit/
+# 2839d390ef3bb318904289bfb9a7751a782c4e44
+
+import logging
+
+import torch
+import torch.nn as nn
+from auto_gptq.nn_modules.qlinear.qlinear_triton import (
+ QuantLinearInferenceOnlyFunction,
+ quant_matmul_inference_only_248,
+ transpose_quant_matmul_248,
+)
+# fixed by aaron.chew1@sg.ibm.com
+from auto_gptq.nn_modules.qlinear.qlinear_tritonv2 import (
+ QuantLinearFunction, quant_matmul_248
+)
+
+logger = logging.getLogger(__name__)
+import math
+
+"""
+For testing only -- replaces HuggingFace default GPTQ QLinear layer (`cuda / torch` -> `triton`)
+"""
+
+
+# Adapted from https://github.com/AutoGPTQ/AutoGPTQ/blob/main/auto_gptq/nn_modules/qlinear/__init__.py
+class GPTQuantLinear(nn.Linear):
+ def __init__(self, quant_linear_module, trainable=True):
+ if hasattr(quant_linear_module, "base_layer"):
+ quant_linear_module = quant_linear_module.base_layer
+
+        # guard against modules whose `bias` attribute exists but is None
+        bias = bool(
+            getattr(quant_linear_module, "bias", None) is not None
+            and quant_linear_module.bias.count_nonzero() > 0
+        )
+
+ super().__init__(
+ in_features=quant_linear_module.infeatures,
+ out_features=quant_linear_module.outfeatures,
+ bias=bias,
+ )
+
+ self.infeatures = quant_linear_module.infeatures
+ self.outfeatures = quant_linear_module.outfeatures
+ self.bits = quant_linear_module.bits
+ self.group_size = quant_linear_module.group_size
+ self.maxq = quant_linear_module.maxq
+
+ self.weight.requires_grad = False
+
+ self.weight.data = quant_linear_module.qweight
+ self.register_buffer("qweight", quant_linear_module.qweight)
+ if bias:
+ self.bias.data = quant_linear_module.bias
+ self.bias.requires_grad = False
+
+ self.qweight.requires_grad = False
+
+ self.register_buffer("qzeros", quant_linear_module.qzeros)
+ self.register_buffer("scales", quant_linear_module.scales)
+ self.register_buffer("g_idx", quant_linear_module.g_idx)
+
+ if hasattr(quant_linear_module, "wf"):
+ self.wf = quant_linear_module.wf
+ if hasattr(quant_linear_module, "kernel_switch_threshold"):
+ self.kernel_switch_threshold = quant_linear_module.kernel_switch_threshold
+ if hasattr(quant_linear_module, "autogptq_cuda_available"):
+ self.autogptq_cuda_available = quant_linear_module.autogptq_cuda_available
+
+ self.trainable = trainable
+ self.QUANT_TYPE = "triton"
+
+ def forward(self, x):
+ out_shape = x.shape[:-1] + (self.outfeatures,)
+ quant_linear_fn = (
+ QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
+ )
+ out = quant_linear_fn.apply(
+ x.reshape(-1, x.shape[-1]),
+ self.qweight,
+ self.scales,
+ self.qzeros,
+ self.g_idx,
+ self.bits,
+ self.maxq,
+ )
+ out = out.half().reshape(out_shape)
+ out = out + self.bias if self.bias is not None else out
+
+ return out
+
+ @classmethod
+ def warmup(cls, model, transpose=True, seqlen=2048):
+ """
+ Pre-tunes the quantized kernel
+ """
+ from tqdm import tqdm
+
+ assert cls.QUANT_TYPE == "triton"
+
+ kn_values = {}
+
+ for _, m in model.named_modules():
+ if not isinstance(m, cls):
+ continue
+
+ k = m.infeatures
+ n = m.outfeatures
+
+ if (k, n) not in kn_values:
+ kn_values[(k, n)] = (
+ m.qweight,
+ m.scales,
+ m.qzeros,
+ m.g_idx,
+ m.bits,
+ m.maxq,
+ )
+
+ logger.info(f"Found {len(kn_values)} unique KN Linear values.")
+ logger.info("Warming up autotune cache ...")
+ with torch.no_grad():
+ for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
+ m = 2**m
+ for (k, n), (
+ qweight,
+ scales,
+ qzeros,
+ g_idx,
+ bits,
+ maxq,
+ ) in kn_values.items():
+ if transpose:
+ a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+ quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
+ a = torch.randn(m, n, dtype=torch.float16, device=model.device)
+ transpose_quant_matmul_248(
+ a, qweight, scales, qzeros, g_idx, bits, maxq
+ )
+ else:
+ a = torch.randn(m, k, dtype=torch.float16, device=model.device)
+ quant_matmul_inference_only_248(
+ a, qweight, scales, qzeros, g_idx, bits, maxq
+ )
+ del kn_values
+
+ @classmethod
+ def inject_to_model(cls, model, target_module_type, **kwargs):
+ count = 0
+ for name, m in model.named_modules():
+ if not isinstance(m, target_module_type):
+ continue
+ new_m = cls(m, **kwargs)
+ if "." in name:
+ parent_name = name.rsplit(".", 1)[0]
+ child_name = name[len(parent_name) + 1 :]
+ parent = model.get_submodule(parent_name)
+ else:
+ parent_name = ""
+ parent = model
+ child_name = name
+
+ setattr(parent, child_name, new_m)
+ count += 1
+ logger.warning_once(
+ f"Injected {count} triton qlinear layers in place of {target_module_type} layers."
+ )
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py
new file mode 100644
index 00000000..9c68bd61
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/triton/tuner.py
@@ -0,0 +1,425 @@
+# taken from
+# https://github.com/jeromeku/unsloth/commit/
+# 2839d390ef3bb318904289bfb9a7751a782c4e44
+
+import builtins
+import heapq
+import math
+import time
+from typing import Dict
+
+import triton
+
+# code based on https://github.com/fpgaminer/GPTQ-triton
+"""
+Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
+"""
+
+
+def matmul248_kernel_config_pruner(configs, nargs):
+ """
+ The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
+ """
+ m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
+ n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
+ k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
+
+ used = set()
+ for config in configs:
+ block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
+ block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
+ block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
+ group_size_m = config.kwargs["GROUP_SIZE_M"]
+
+ if (
+ block_size_m,
+ block_size_n,
+ block_size_k,
+ group_size_m,
+ config.num_stages,
+ config.num_warps,
+ ) in used:
+ continue
+
+ used.add(
+ (
+ block_size_m,
+ block_size_n,
+ block_size_k,
+ group_size_m,
+ config.num_stages,
+ config.num_warps,
+ )
+ )
+ yield triton.Config(
+ {
+ "BLOCK_SIZE_M": block_size_m,
+ "BLOCK_SIZE_N": block_size_n,
+ "BLOCK_SIZE_K": block_size_k,
+ "GROUP_SIZE_M": group_size_m,
+ },
+ num_stages=config.num_stages,
+ num_warps=config.num_warps,
+ )
+
+
+CUSTOM_MATMUL_AUTOTUNE_CONFIGS = dict(
+ configs=[
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 256,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ), # 3090
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ), # 3090
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=2,
+ num_warps=4,
+ ), # 3090
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 16,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ), # 3090
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ), # 3090
+ ],
+ key=["M", "N", "K"],
+ nearest_power_of_two=True,
+ prune_configs_by={
+ "early_config_prune": matmul248_kernel_config_pruner,
+ "perf_model": None,
+ "top_k": None,
+ },
+ warmup=25,
+ rep=40,
+)
+
+CUSTOM_MATMUL_TRANSPOSE_AUTOTUNE_CONFIGS = dict(
+ configs=[
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 256,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=2,
+ num_warps=8,
+ ),
+ ],
+ key=["M", "N", "K"],
+ nearest_power_of_two=True,
+ warmup=25,
+ rep=40,
+)
+
+
+class CustomizedTritonAutoTuner(triton.KernelInterface):
+ def __init__(
+ self,
+ fn,
+ arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ prune_configs_by: Dict = None,
+ nearest_power_of_two: bool = False,
+ warmup=25,
+ rep=40,
+ ):
+ if not configs:
+ self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
+ else:
+ self.configs = configs
+ self.key_idx = [arg_names.index(k) for k in key]
+ self.nearest_power_of_two = nearest_power_of_two
+ self.cache = {}
+ # hook to reset all required tensor to zeros before relaunching a kernel
+ self.hook = lambda args: 0
+ if reset_to_zero is not None:
+ self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
+
+ def _hook(args):
+ for i in self.reset_idx:
+ args[i].zero_()
+
+ self.hook = _hook
+ self.arg_names = arg_names
+ # prune configs
+ if prune_configs_by:
+ perf_model, top_k = (
+ prune_configs_by["perf_model"],
+ prune_configs_by["top_k"],
+ )
+ if "early_config_prune" in prune_configs_by:
+ early_config_prune = prune_configs_by["early_config_prune"]
+ else:
+ perf_model, top_k, early_config_prune = None, None, None
+ self.perf_model, self.configs_top_k = perf_model, top_k
+ self.early_config_prune = early_config_prune
+ self.fn = fn
+ self.warmup = warmup
+ self.rep = rep
+
+ def _bench(self, *args, config, **meta):
+ # check for conflicts, i.e. meta-parameters both provided
+ # as kwargs and by the autotuner
+ conflicts = meta.keys() & config.kwargs.keys()
+ if conflicts:
+ raise ValueError(
+ f"Conflicting meta-parameters: {', '.join(conflicts)}."
+ " Make sure that you don't re-define auto-tuned symbols."
+ )
+ # augment meta-parameters with tunable ones
+ current = dict(meta, **config.kwargs)
+
+ def kernel_call():
+ if config.pre_hook:
+ config.pre_hook(self.nargs)
+ self.hook(args)
+ self.fn.run(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **current,
+ )
+
+ try:
+            # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
+ # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
+ return triton.testing.do_bench(
+ kernel_call, quantiles=(0.5, 0.2, 0.8), rep=self.rep, warmup=self.warmup
+ )
+ except triton.OutOfResources:
+ return (float("inf"), float("inf"), float("inf"))
+
+ def run(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ if len(self.configs) > 1:
+ key = tuple(args[i] for i in self.key_idx)
+
+ # This reduces the amount of autotuning by rounding the keys to the nearest power of two
+ # In my testing this gives decent results, and greatly reduces the amount of tuning required
+ if self.nearest_power_of_two:
+ key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
+
+ if key not in self.cache:
+ # prune configs
+ pruned_configs = self.prune_configs(kwargs)
+ bench_start = time.time()
+ timings = {
+ config: self._bench(*args, config=config, **kwargs)
+ for config in pruned_configs
+ }
+ bench_end = time.time()
+ self.bench_time = bench_end - bench_start
+ self.cache[key] = builtins.min(timings, key=timings.get)
+ self.hook(args)
+ self.configs_timings = timings
+ config = self.cache[key]
+ else:
+ config = self.configs[0]
+ self.best_config = config
+ if config.pre_hook is not None:
+ config.pre_hook(self.nargs)
+ return self.fn.run(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs,
+ )
+
+ def prune_configs(self, kwargs):
+ pruned_configs = self.configs
+ if self.early_config_prune:
+ pruned_configs = self.early_config_prune(self.configs, self.nargs)
+ if self.perf_model:
+ top_k = self.configs_top_k
+ if isinstance(top_k, float) and top_k <= 1.0:
+ top_k = int(len(self.configs) * top_k)
+ if len(pruned_configs) > top_k:
+ est_timing = {
+ config: self.perf_model(
+ **self.nargs,
+ **kwargs,
+ **config.kwargs,
+ num_stages=config.num_stages,
+ num_warps=config.num_warps,
+ )
+ for config in pruned_configs
+ }
+ pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
+ :top_k
+ ]
+ return pruned_configs
+
+ def warmup(self, *args, **kwargs):
+ self.nargs = dict(zip(self.arg_names, args))
+ for config in self.prune_configs(kwargs):
+ self.fn.warmup(
+ *args,
+ num_warps=config.num_warps,
+ num_stages=config.num_stages,
+ **kwargs,
+ **config.kwargs,
+ )
+ self.nargs = None
+
+
+def custom_autotune(
+ configs,
+ key,
+ prune_configs_by=None,
+ reset_to_zero=None,
+ nearest_power_of_two=False,
+ warmup=25,
+ rep=40,
+):
+ def decorator(fn):
+ return CustomizedTritonAutoTuner(
+ fn,
+ fn.arg_names,
+ configs,
+ key,
+ reset_to_zero,
+ prune_configs_by,
+ nearest_power_of_two,
+ warmup=warmup,
+ rep=rep,
+ )
+
+ return decorator
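+
+# Usage sketch (hypothetical kernel name): apply the decorator to a Triton JIT'd quantized
+# matmul kernel so tuning uses the shortened warmup/rep schedule defined above, e.g.
+#
+#   @custom_autotune(**CUSTOM_MATMUL_AUTOTUNE_CONFIGS)
+#   @triton.jit
+#   def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_idx_ptr,
+#                         M, N, K, bits, maxq, ...):
+#       ...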
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py
new file mode 100644
index 00000000..fca96782
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/swiglu.py
@@ -0,0 +1,98 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+import torch
+
+
+@triton.jit
+def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
+ block_idx = tl.program_id(0)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ # f = e * sigmoid(e)
+ f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
+ f_row = f_row.to(g_row.dtype) # Exact copy from HF
+ # h = f * g
+ h_row = f_row * g_row
+
+ # Store h
+ tl.store(h + offsets, h_row, mask = mask)
+pass
+
+
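+# swiglu_fg_kernel computes h = silu(e) * g element-wise, where e and g are the
+# (batch, seq_len, hidden) outputs of the gate and up projections; h is a new tensor.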
+def swiglu_fg_kernel(e, g):
+ batch, seq_len, hd = e.shape
+ n_elements = e.numel()
+ h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda")
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
+ return h
+pass
+
+
+@triton.jit
+def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
+ """
+ e = e.float()
+ se = 1.0 / (1.0 + torch.exp(-e))
+ f = (se * e).to(dtype)
+ h = f * g
+ df = DW * f
+ dg = DW * g
+ de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
+ """
+ block_idx = tl.program_id(0)
+ offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+
+ DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
+ e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
+ g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ # e = e.float()
+ # se = 1.0 / (1.0 + torch.exp(-e))
+ se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
+ # f = (se * e).to(dtype)
+ f_row = se_row * e_row
+ f_row = f_row.to(DW_row.dtype)
+ # h = f * g
+ h_row = f_row * g_row
+ # df = DW * f
+ df_row = DW_row * f_row
+ # dg = DW * g
+ dg_row = DW_row * g_row
+ # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
+ de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
+ de_row = de_row.to(DW_row.dtype)
+
+ # Store derivatives in buffers
+ tl.store(DW + offsets, h_row, mask = mask) # h = f * g
+ tl.store(e + offsets, df_row, mask = mask) # df = DW * f
+ tl.store(g + offsets, de_row, mask = mask) # de
+pass
+
+
+def swiglu_DWf_DW_dfg_kernel(DW, e, g):
+ batch_seq_len, hd = e.shape
+ n_elements = e.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+ _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
+ return DW, e, g
+pass
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
new file mode 100644
index 00000000..5354670f
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
@@ -0,0 +1,255 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import triton
+MAX_FUSED_SIZE = 65536
+next_power_of_2 = triton.next_power_of_2
+
+def calculate_settings(n):
+ BLOCK_SIZE = next_power_of_2(n)
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
+ num_warps = 4
+ if BLOCK_SIZE >= 32768: num_warps = 32
+ elif BLOCK_SIZE >= 8192: num_warps = 16
+ elif BLOCK_SIZE >= 2048: num_warps = 8
+ return BLOCK_SIZE, num_warps
+pass
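+# e.g. calculate_settings(4096) -> (4096, 8): BLOCK_SIZE is the next power of two
+# and num_warps grows with the block size.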
+
+# import guard added by flim@sg.ibm.com
+from transformers.utils.import_utils import _bitsandbytes_available
+if _bitsandbytes_available:
+ import bitsandbytes as bnb
+ get_ptr = bnb.functional.get_ptr
+    import ctypes
+ cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+ cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
+ cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
+ cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
+ cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
+
+# modified by flim@sg.ibm.com
+def QUANT_STATE(W, base_layer):
+
+ # if the weights has quant_state just take it from there
+ if hasattr(W, 'quant_state'):
+ return W.quant_state
+
+ # otherwise fall back to checking if it is on the base layer
+ # This is needed when FSDP shards the parameters, and destroys the original
+ # weight matrix, so we can get the quant state back
+ return getattr(base_layer, 'quant_state', None)
+pass
+
+# modified by flim@sg.ibm.com
+def get_lora_parameters(proj):
+ # For DPO or disabled adapters
+ base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
+ W = base_layer.weight
+
+ if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
+ return W, QUANT_STATE(W, base_layer), None, None, None
+ pass
+
+ active_adapter = proj.active_adapters[0] if \
+ hasattr(proj, "active_adapters") else proj.active_adapter
+ A = proj.lora_A [active_adapter].weight
+ B = proj.lora_B [active_adapter].weight
+ s = proj.scaling[active_adapter]
+ return W, QUANT_STATE(W, base_layer), A, B, s
+pass
+
+
+def fast_dequantize(W, quant_state = None, out = None):
+ if quant_state is None: return W
+ if type(quant_state) is not list:
+ # New quant_state as a class
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ dtype = quant_state.dtype
+ blocksize = quant_state.blocksize
+ offset = quant_state.offset
+ state2 = quant_state.state2
+ absmax2 = state2.absmax
+ code2 = state2.code
+ blocksize2 = state2.blocksize
+ else:
+ # Old quant_state as a list of lists
+ absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
+ offset, state2 = compressed_stats
+ absmax2, code2, blocksize2, _, _, _, _ = state2
+ pass
+
+ # Create weight matrix
+ if out is None:
+ out = torch.empty(shape, dtype = dtype, device = "cuda")
+ else:
+ assert(out.shape == shape)
+ assert(out.dtype == dtype)
+
+ # NF4 dequantization of statistics
+ n_elements_absmax = absmax.numel()
+ out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda")
+
+ # Do dequantization
+ ptr_out_absmax = get_ptr(out_absmax)
+ cdequantize_blockwise_fp32(
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
+ ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax)
+ )
+ out_absmax += offset
+
+ fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
+ cdequantize_blockwise_bf16_nf4
+ fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
+ ctypes.c_int(blocksize), ctypes.c_int(out.numel()))
+
+ # Careful returning transposed data
+    is_transposed = (W.shape[0] == 1)
+ return out.t() if is_transposed else out
+pass
+
+
+def fast_gemv(X, W, quant_state, out = None):
+ if quant_state is None: return torch.matmul(X, W, out = out)
+ # For fast X @ W where seq_len == 1
+ # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
+ _, q_len, hd = X.shape
+ # assert(q_len == 1)
+
+ if type(quant_state) is not list:
+ # https://github.com/TimDettmers/bitsandbytes/pull/763/files
+ absmax = quant_state.absmax
+ shape = quant_state.shape
+ dtype = quant_state.dtype
+ blocksize = quant_state.blocksize
+ stats = quant_state.code
+ offset = quant_state.offset
+ state2 = quant_state.state2
+ absmax2 = state2.absmax
+ code2 = state2.code
+ blocksize2 = state2.blocksize
+ else:
+ absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
+ offset, state2 = compressed_stats
+ absmax2, code2, blocksize2, _, _, _, _ = state2
+ pass
+ # assert(dtype == X.dtype)
+ bout = shape[0]
+
+ if out is None:
+ out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda")
+ # else:
+ # assert(out.shape == (1, 1, bout,))
+ # pass
+
+ n = 1
+ m = shape[0]
+ k = shape[1]
+ lda = shape[0]
+ ldc = shape[0]
+ ldb = (hd+1)//2
+ m = ctypes.c_int32(m)
+ n = ctypes.c_int32(n)
+ k = ctypes.c_int32(k)
+ lda = ctypes.c_int32(lda)
+ ldb = ctypes.c_int32(ldb)
+ ldc = ctypes.c_int32(ldc)
+
+ df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda")
+ cdequantize_blockwise_fp32(
+ get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
+ ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
+ )
+ df += offset
+ absmax = df
+
+ fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
+ cgemm_4bit_inference_naive_bf16
+
+ blocksize = ctypes.c_int32(blocksize)
+ fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
+ lda, ldb, ldc, blocksize)
+
+ return out
+pass
+
+
+def fast_linear_forward(proj, X, temp_lora = None, out = None):
+
+ W, W_quant, lora_A, lora_B, lora_S = get_lora_parameters(proj)
+ bsz, q_len, in_dim = X.shape
+ if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
+
+ if W_quant is None:
+ out = torch.matmul(X, W.t(), out = out)
+ elif bsz == 1 and q_len == 1:
+ out = fast_gemv(X, W, W_quant, out = out)
+ else:
+ W = fast_dequantize(W.t(), W_quant)
+ out = torch.matmul(X, W, out = out)
+ pass
+
+ # Add in LoRA weights
+ if lora_A is not None:
+ out_dim = out.shape[2]
+ dtype = X.dtype
+
+ if not hasattr(lora_A, "_fast_lora"):
+ lora_A._fast_lora = lora_A.to(dtype)
+ lora_B._fast_lora = lora_B.to(dtype)
+ pass
+
+ if bsz == 1:
+ out = out.view(out_dim)
+ temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
+ out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
+ else:
+ out = out.view(bsz, out_dim)
+ temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
+ out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
+ pass
+ out = out.view(bsz, 1, out_dim)
+ pass
+
+ return out
+pass
+
+
+def matmul_lora(X, W, W_quant, A, B, s, out = None):
+ dtype = X.dtype
+ W = fast_dequantize(W.t(), W_quant)
+
+ if X.dim() == 3:
+ batch, seq_len, d = X.shape
+ X = X.view(-1, X.shape[-1])
+ reshape = True
+ else:
+ reshape = False
+ pass
+
+ out = torch.matmul(X, W, out = out)
+ if W_quant is not None: del W
+
+ if A is not None:
+ # LoRA is enabled
+ A, B = A.t(), B.t()
+ out += (X @ A.to(dtype)) @ (s * B.to(dtype))
+ pass
+
+ return out.view(batch, seq_len, -1) if reshape else out
+pass
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py
new file mode 100644
index 00000000..b994759e
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py
new file mode 100644
index 00000000..0c5c2706
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .cross_entropy_loss import fast_cross_entropy_loss
+from .rms_layernorm import fast_rms_layernorm
+from .rope_embedding import fast_rope_embedding
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
new file mode 100644
index 00000000..ebf6f3d0
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/cross_entropy_loss.py
@@ -0,0 +1,292 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+import torch
+from .utils import calculate_settings, MAX_FUSED_SIZE
+
+
+@triton.jit
+def _cross_entropy_forward(
+ logits_ptr, logits_row_stride,
+ loss_ptr,
+ logsumexp_ptr,
+ labels_ptr,
+ VOCAB_SIZE : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+):
+ """
+ Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
+ Pi = exp(xi) / sum(exp(xi))
+ CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
+ = -y [ x - log[sum(exp(x))] ]
+ = y * (log[sum(exp(x))] - x)
+ If y == 0: CE_i = 0
+ If y == 1: CE_i = logsumexp - x
+
+ logsumexp is also stable
+ Take y = log[sum(exp(x))]
+ exp(y) = sum(exp(x))
+ exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
+ exp(y) = exp(c)*sum(exp(x - c))
+ y = log(exp(c)*sum(exp(x - c)))
+ y = c + log[sum(exp(x - c))]
+ This means we can set c = max(x) to make sure
+ exp(x - c) always is exp(x - max(x)).
+ This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
+ """
+ row_idx = tl.program_id(0)
+ logits_ptr += row_idx * logits_row_stride.to(tl.int64)
+ loss_ptr += row_idx
+ logsumexp_ptr += row_idx
+ labels_ptr += row_idx
+
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ mask = col_offsets < VOCAB_SIZE
+
+ label_idx = tl.load(labels_ptr).to(tl.int32)
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+ c = tl.max(logits, 0)
+ logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
+
+ if label_idx != -100:
+ x = tl.load(logits_ptr + label_idx).to(tl.float32)
+ loss = logsumexp - x
+ else:
+ loss = 0.0
+ tl.store(logsumexp_ptr, logsumexp)
+ tl.store(loss_ptr, loss)
+pass
+
+
+@triton.jit
+def _chunked_cross_entropy_forward(
+ logits_ptr, logits_row_stride,
+ loss_ptr,
+ logsumexp_ptr,
+ labels_ptr,
+ VOCAB_SIZE : tl.constexpr,
+ N_CHUNKS : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+):
+ """
+ 256K vocab divided in 4 chunks
+
+ |-65536-| |-65536-| |-65536-| |-65536-|
+ |-------| |-------| |-------| |-------|
+ |-------| |-------| |-------| |-------|
+
+ If y == 0: CE_i = 0
+ If y == 1: CE_i = logsumexp - x
+
+    Notice we can do a logsumexp for each chunk, then combine the chunk results:
+    logsumexp(chunked_logsumexp) == logsumexp(x)
+
+    logsumexp(chunked_logsumexp)
+        = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
+        = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
+        = log[sum(exp(a)) + ... + sum(exp(z))]
+        = logsumexp(x)
+
+    This means we can perform a logsumexp for each chunk, then do a
+    final logsumexp reduction!
+
+    I.e. do: logsumexp(chunked_logsumexp) - x
+ """
+ row_idx = tl.program_id(0)
+ chunk_idx = tl.program_id(1)
+ logits_ptr += row_idx * logits_row_stride.to(tl.int64)
+ loss_ptr += row_idx
+ logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
+ labels_ptr += row_idx
+
+ col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = col_offsets < VOCAB_SIZE
+
+ label_idx = tl.load(labels_ptr).to(tl.int32)
+ logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+ c = tl.max(logits, 0)
+ logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
+
+ if chunk_idx == 0:
+ # logsumexp(chunked_logsumexp) - x
+ # Do the -x separately
+ if label_idx != -100:
+ x = tl.load(logits_ptr + label_idx).to(tl.float32)
+ loss = -1.0 * x
+ else:
+ loss = 0.0
+ tl.store(loss_ptr, loss)
+ pass
+ tl.store(logsumexp_ptr, logsumexp)
+pass
+
+
+@triton.jit
+def _cross_entropy_backward(
+ logits_ptr, logits_row_stride,
+ dloss_ptr, dloss_row_stride,
+ logsumexp_ptr,
+ labels_ptr,
+ VOCAB_SIZE : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+):
+ """
+ CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
+ dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
+
+ From https://en.wikipedia.org/wiki/LogSumExp
+ d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
+
+ dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
+ dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
+ dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
+
+ If y == 0: dC/dx = 0
+ If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
+ If y == 1 and x != label: dC/dx = exp[x - logsumexp]
+ """
+ row_idx = tl.program_id(0)
+ block_idx = tl.program_id(1)
+
+ logits_ptr += row_idx * logits_row_stride.to(tl.int64)
+ dloss_ptr += row_idx * dloss_row_stride
+ col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = col_offsets < VOCAB_SIZE
+ label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
+
+ if label_idx != -100:
+ dloss = tl.load(dloss_ptr)
+ else:
+ dloss = 0.0
+
+ x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
+ logsumexp = tl.load(logsumexp_ptr + row_idx)
+ y = tl.exp(x - logsumexp)
+ y = tl.where(
+ col_offsets == label_idx,
+ y - 1.0, # exp(x - logsumexp) - 1
+ y, # exp(x - logsumexp)
+ )
+
+ # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
+ tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
+pass
+
+
+MAX_FUSED_SIZE = 65536 # 2**16
+
+class Fast_CrossEntropyLoss(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, logits, labels):
+ n_rows, vocab_size = logits.shape
+
+ div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
+ n_chunks = div + (mod != 0)
+ losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+
+ if n_chunks == 1:
+            # For small vocabs <= 65536 like Llama, Mistral
+ BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
+ logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+
+ _cross_entropy_forward[(n_rows,)](
+ logits, logits.stride(0),
+ losses,
+ logsumexp,
+ labels,
+ VOCAB_SIZE = vocab_size,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = num_warps,
+ )
+ else:
+            # For large vocabs > 65536 like Gemma 256K
+ logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda")
+
+ _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
+ logits, logits.stride(0),
+ losses,
+ logsumexp,
+ labels,
+ VOCAB_SIZE = vocab_size,
+ N_CHUNKS = n_chunks,
+ BLOCK_SIZE = MAX_FUSED_SIZE,
+ num_warps = 32,
+ )
+ # logsumexp(chunked_logsumexp) - x
+ # Do the -x separately
+ logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
+ losses += logsumexp
+ losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
+ pass
+
+ ctx.save_for_backward(logits, logsumexp, labels)
+ return losses
+ pass
+
+ @staticmethod
+ def backward(ctx, dlosses):
+ logits, logsumexp, labels = ctx.saved_tensors
+ n_rows, vocab_size = logits.shape
+
+ BLOCK_SIZE = 4096
+ div, mod = divmod(vocab_size, BLOCK_SIZE)
+ n_blocks = div + (mod != 0)
+
+ _cross_entropy_backward[(n_rows, n_blocks,)](
+ logits, logits.stride(0),
+ dlosses, dlosses.stride(0),
+ logsumexp,
+ labels,
+ VOCAB_SIZE = vocab_size,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = 8,
+ )
+ return logits, None, None,
+ pass
+pass
+
+
+def fast_cross_entropy_loss(logits, labels):
+ """
+ Arguments:
+ logits: (batch, seq_len, vocab_size)
+ labels: (batch, seq_len,)
+ Returns:
+ losses: float
+ """
+ batch, seq_len, d = logits.shape
+ assert(labels.shape == (batch, seq_len))
+
+ loss = Fast_CrossEntropyLoss.apply(
+ logits.view(batch*seq_len, d),
+ labels.view(-1),
+ )
+ n_items = torch.count_nonzero(labels != -100)
+ return loss.sum() / n_items
+pass
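+
+# Usage sketch (assumes the caller has already applied the causal-LM shift and uses -100
+# for padding/ignored positions):
+#   shift_logits = logits[..., :-1, :]   # (batch, seq_len - 1, vocab_size)
+#   shift_labels = labels[..., 1:]       # (batch, seq_len - 1)
+#   loss = fast_cross_entropy_loss(shift_logits, shift_labels)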
+
+# added by flim@sg.ibm.com
+class FastCrossEntropyLoss(torch.nn.CrossEntropyLoss):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input, target):
+ loss = Fast_CrossEntropyLoss.apply(
+ input, target
+ )
+ n_items = torch.count_nonzero(target != -100)
+ return loss.sum() / n_items
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py
new file mode 100644
index 00000000..4db89b78
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rms_layernorm.py
@@ -0,0 +1,192 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+import torch
+from .utils import calculate_settings
+
+
+@triton.jit
+def _rms_layernorm_forward(
+ Y, Y_row_stride,
+ X, X_row_stride,
+ W, W_row_stride,
+ r, r_row_stride,
+ n_cols, eps,
+ BLOCK_SIZE : tl.constexpr
+):
+ """
+ Fast RMS Layernorm kernel
+ Inspiration from a Triton tutorial:
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ """
+ row_idx = tl.program_id(0)
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ mask = col_offsets < n_cols
+
+ Y += row_idx * Y_row_stride
+ X += row_idx * X_row_stride
+ r += row_idx * r_row_stride
+
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
+
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
+ inv_var = tl.math.rsqrt(row_var + eps)
+ tl.store(r, inv_var)
+ normed = X_row * inv_var
+ normed = normed.to(W_row.dtype) # Exact copy from HF
+ output = normed * W_row
+ tl.store(Y + col_offsets, output, mask = mask)
+pass
+
+
+@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
+@triton.jit
+def _rms_layernorm_backward(
+ dY, dY_row_stride,
+ X, X_row_stride,
+ W, W_row_stride,
+ r, r_row_stride,
+ dW, dW_row_stride,
+ n_cols, eps,
+ GEMMA : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+):
+ """
+ Fast RMS Layernorm kernel for the backward pass
+ Inspiration from a Triton tutorial:
+ https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+ """
+ row_idx = tl.program_id(0)
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ mask = col_offsets < n_cols
+
+ dY += row_idx * dY_row_stride
+ X += row_idx * X_row_stride
+ r += row_idx * r_row_stride
+
+ dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+
+ # Get saved row variance
+ inv_var = tl.load(r).to(tl.float32)
+ normed = X_row * inv_var
+
+ if GEMMA: dY_W = dY_row * (W_row + 1.0)
+ else: dY_W = dY_row * W_row
+
+ rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
+ output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
+ tl.store(dY + col_offsets, output, mask = mask)
+pass
+
+
+@triton.jit
+def _gemma_rms_layernorm_forward(
+ Y, Y_row_stride,
+ X, X_row_stride,
+ W, W_row_stride,
+ r, r_row_stride,
+ n_cols, eps,
+ BLOCK_SIZE : tl.constexpr,
+):
+ # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
+ # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
+ # exactly. Essentially all in float32!
+ row_idx = tl.program_id(0)
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ mask = col_offsets < n_cols
+
+ Y += row_idx * Y_row_stride
+ X += row_idx * X_row_stride
+ r += row_idx * r_row_stride
+
+ X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+ W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+
+ row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
+ inv_var = 1.0 / tl.sqrt(row_var + eps) # Must be 1/sqrt to match Deepmind's impl
+ tl.store(r, inv_var)
+ normed = X_row * inv_var
+ output = normed * (W_row + 1.0)
+
+ tl.store(Y + col_offsets, output, mask = mask)
+pass
+
+
+class Fast_RMS_Layernorm(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, X, W, eps, gemma = False):
+ shape = X.shape
+ dim = shape[-1]
+ X = X.view(-1, dim)
+ n_rows, n_cols = X.shape
+ BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+ Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda")
+ r = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
+
+ fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
+ fx[(n_rows,)](
+ Y, Y.stride(0),
+ X, X.stride(0),
+ W, W.stride(0),
+ r, r.stride(0),
+ n_cols, eps,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = num_warps,
+ )
+ ctx.eps = eps
+ ctx.BLOCK_SIZE = BLOCK_SIZE
+ ctx.num_warps = num_warps
+ ctx.GEMMA = gemma
+ ctx.save_for_backward(X, W, r)
+ return Y.view(*shape)
+ pass
+
+ @staticmethod
+ def backward(ctx, dY):
+ shape = dY.shape
+ dim = shape[-1]
+ dY = dY.view(-1, dim)
+ X, W, r = ctx.saved_tensors
+ n_rows, n_cols = dY.shape
+ dW = X
+
+ _rms_layernorm_backward[(n_rows,)](
+ dY, dY.stride(0),
+ X, X .stride(0),
+ W, W .stride(0),
+ r, r .stride(0),
+ dW, dW.stride(0),
+ n_cols, ctx.eps,
+ GEMMA = ctx.GEMMA,
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
+ num_warps = ctx.num_warps,
+ )
+ dX = dY.view(*shape)
+ return dX, None, None, None
+ pass
+pass
+
+
+def fast_rms_layernorm(layernorm, X, gemma = False):
+ W = layernorm.weight
+ eps = layernorm.variance_epsilon
+ out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
+ return out
+pass
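+
+# Usage sketch: `layernorm` can be any module exposing `.weight` and `.variance_epsilon`
+# (e.g. an RMSNorm layer from a HF Llama-style model), so a drop-in call looks like
+#   hidden_states = fast_rms_layernorm(decoder_layer.input_layernorm, hidden_states)
+# with gemma=True selecting the float32 (W + 1.0) variant above.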
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
new file mode 100644
index 00000000..3577b586
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/rope_embedding.py
@@ -0,0 +1,139 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+import triton.language as tl
+import torch
+from .utils import calculate_settings
+
+ROPE_GROUP_SIZE = 4
+
+@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
+@triton.jit
+def _rope_embedding(
+ Q, Q_row_stride,
+ cos, cos_row_stride,
+ sin, sin_row_stride,
+ seqlen,
+ head_dim : tl.constexpr,
+ n_heads : tl.constexpr,
+ BACKWARD_PASS : tl.constexpr,
+ BLOCK_SIZE : tl.constexpr,
+):
+ """
+ Calculates the RoPE Embedding quickly
+ RoPE is Q * cos + rotate_half(Q) * sin
+ See our blog post for more info
+ """
+ row_position = tl.program_id(0)
+ group_head_position = tl.program_id(1)
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ half_head_dim = head_dim // 2
+ mask = col_offsets < half_head_dim
+
+ sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
+ cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
+ half_head_dim*0 + col_offsets, mask = mask, other = 0)
+
+ if BACKWARD_PASS:
+ # See our blog post for more info.
+ sin1 = -sin1
+ pass
+
+ # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
+ head_start = group_head_position * ROPE_GROUP_SIZE
+ head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
+
+ # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
+ for k in range(head_start, head_end):
+ offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
+ offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
+
+ # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
+ Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
+ Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
+
+ tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
+ tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
+ pass
+pass
+
+
+class Fast_RoPE_Embedding(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, Q, cos, sin):
+ cos, sin = cos.squeeze(), sin.squeeze()
+ batch, seq_len, n_heads, head_dim = Q.shape
+ Q = Q.view(batch*seq_len, n_heads*head_dim)
+ n_rows, n_cols = Q.shape
+ assert(seq_len <= cos.shape[0])
+
+    # [TODO] Changing the block size to head_dim//2 seems to have
+    # some concurrency / nondeterminism issues.
+ BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
+
+ # group_size = 4 # 4 or 8, too large group_size can hurt performance.
+ div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
+ n_groups = div + (mod != 0)
+
+ _rope_embedding[(n_rows, n_groups, )](
+ Q, Q.stride(0),
+ cos, cos.stride(0),
+ sin, sin.stride(0),
+ seq_len,
+ head_dim, n_heads,
+ BACKWARD_PASS = False,
+ BLOCK_SIZE = BLOCK_SIZE,
+ num_warps = num_warps,
+ )
+ ctx.BLOCK_SIZE = BLOCK_SIZE
+ ctx.num_warps = num_warps
+ ctx.n_groups = n_groups
+ ctx.cos = cos
+ ctx.sin = sin
+ return Q.view(batch, seq_len, n_heads, head_dim)
+ pass
+
+ @staticmethod
+ def backward(ctx, dY):
+ batch, seq_len, n_heads, head_dim = dY.shape
+ dY = dY.reshape(batch*seq_len, n_heads*head_dim)
+ # Must be reshape not view
+ n_rows, n_cols = dY.shape
+
+ cos = ctx.cos
+ sin = ctx.sin
+
+ _rope_embedding[(n_rows, ctx.n_groups, )](
+ dY, dY .stride(0),
+ cos, cos.stride(0),
+ sin, sin.stride(0),
+ seq_len, head_dim, n_heads,
+ BACKWARD_PASS = True,
+ BLOCK_SIZE = ctx.BLOCK_SIZE,
+ num_warps = ctx.num_warps,
+ )
+ dY = dY.view(batch, seq_len, n_heads, head_dim)
+ return dY, None, None,
+ pass
+pass
+
+# modified by flim@sg.ibm.com
+# NOTE: fast_rope_embedding currently does not account for position ids
+def fast_rope_embedding(Q, K, cos, sin, position_ids=None):
+ Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
+ K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
+ return Q, K
+pass
\ No newline at end of file
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py
new file mode 100644
index 00000000..8d4aa881
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/unsloth/utils.py
@@ -0,0 +1,29 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import triton
+MAX_FUSED_SIZE = 65536
+next_power_of_2 = triton.next_power_of_2
+
+def calculate_settings(n):
+ BLOCK_SIZE = next_power_of_2(n)
+ if BLOCK_SIZE > MAX_FUSED_SIZE:
+ raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
+ f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
+ num_warps = 4
+ if BLOCK_SIZE >= 32768: num_warps = 32
+ elif BLOCK_SIZE >= 8192: num_warps = 16
+ elif BLOCK_SIZE >= 2048: num_warps = 8
+ return BLOCK_SIZE, num_warps
+pass
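+
+# Quick sanity examples (editor's note; values follow from the code above):
+#   calculate_settings(4096)  -> (4096, 8)    # 4096 >= 2048 -> 8 warps
+#   calculate_settings(11008) -> (16384, 16)  # rounded up to 16384 >= 8192 -> 16 warps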
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py
new file mode 100644
index 00000000..ebd49924
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/__init__.py
@@ -0,0 +1,24 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .model_patcher import ModelPatcher
+
+PATCHES = [".models.llama", ".models.mistral", ".models.mixtral"]
+PLUGIN_PREFIX = "fms_acceleration_foak"
+
+# TODO: remove the need for the prefix
+ModelPatcher.load_patches(
+ [f"{PLUGIN_PREFIX}{postfix}" for postfix in PATCHES],
+)
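+
+# NOTE (editor): importing this sub-package is sufficient to register the
+# llama / mistral / mixtral rules declared in the modules listed above; they
+# are applied separately via ModelPatcher.patch(model) (see model_patcher.py).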
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
new file mode 100644
index 00000000..290d1217
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py
@@ -0,0 +1,134 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaMLP,
+ LlamaRMSNorm,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .model_patcher import (
+ ModelPatcher,
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+# TODO: have a generic version of this rule
+# - do regex on RMSNorm class name
+# - check on the tensors required for fast_rms_layernorm
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="llama-rms",
+ trigger=ModelPatcherTrigger(check=LlamaRMSNorm),
+ forward=fast_rms_layernorm,
+ ),
+)
+
+# TODO: have a generic version of this rule
+# - do regex on Attention class name
+# - have a set of qkv / o module names and check on that
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="llama-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=LlamaAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=LlamaAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ ),
+ logic="APPEND",
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="llama-mlp",
+ trigger=ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=LlamaMLP,
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
+ )
+ ),
+ forward_builder=partial(
+ build_lora_fused_ops,
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
+ fused_op=KEY_MLP,
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+# TODO: have a generic version of this rule
+# - get the module_name and reload on that
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="llama-cross-ent",
+ import_and_maybe_reload=(
+ "torch.nn.CrossEntropyLoss",
+ FastCrossEntropyLoss,
+ "transformers.models.llama.modeling_llama",
+ ),
+ )
+)
+
+# TODO: have a generic version of this rule
+# - get the module name
+# - check if "apply_rotary_pos_emb" exists
+# - patch
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="llama-rope",
+ import_and_maybe_reload=(
+ "transformers.models.llama.modeling_llama.apply_rotary_pos_emb",
+ fast_rope_embedding,
+ None,
+ ),
+ )
+)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
new file mode 100644
index 00000000..37809fd1
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py
@@ -0,0 +1,124 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from transformers.models.mistral.modeling_mistral import (
+ MistralAttention,
+ MistralMLP,
+ MistralRMSNorm,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .model_patcher import (
+ ModelPatcher,
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+# - do regex on RMSNorm class name
+# - check on the tensors required for fast_rms_layernorm
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mistral-rms",
+ trigger=ModelPatcherTrigger(check=MistralRMSNorm),
+ forward=fast_rms_layernorm,
+ ),
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mistral-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MistralAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MistralAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ ),
+ logic="APPEND",
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mistral-mlp",
+ trigger=ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MistralMLP,
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
+ )
+ ),
+ forward_builder=partial(
+ build_lora_fused_ops,
+ submodule_names=["up_proj", "down_proj", "gate_proj"],
+ fused_op=KEY_MLP,
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mistral-cross-ent",
+ import_and_maybe_reload=(
+ "torch.nn.CrossEntropyLoss",
+ FastCrossEntropyLoss,
+ "transformers.models.mistral.modeling_mistral",
+ ),
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mistral-rope",
+ import_and_maybe_reload=(
+ "transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb",
+ fast_rope_embedding,
+ None,
+ ),
+ )
+)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
new file mode 100644
index 00000000..1522ef8d
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py
@@ -0,0 +1,104 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from transformers.models.mixtral.modeling_mixtral import (
+ MixtralAttention,
+ MixtralRMSNorm,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .model_patcher import (
+ ModelPatcher,
+ ModelPatcherRule,
+ ModelPatcherTrigger,
+ combine_functions,
+ combine_triggers,
+)
+from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops
+
+# - do regex on RMSNorm class name
+# - check on the tensors required for fast_rms_layernorm
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-rms",
+ trigger=ModelPatcherTrigger(check=MixtralRMSNorm),
+ forward=fast_rms_layernorm,
+ ),
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-qkvo",
+ trigger=combine_triggers(
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MixtralAttention,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ )
+ ),
+ ModelPatcherTrigger(
+ check=partial(
+ trigger_fused_ops,
+ attn_cls=MixtralAttention,
+ submodule_names=["o_proj"],
+ )
+ ),
+ logic="OR",
+ ),
+ forward_builder=combine_functions(
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["q_proj", "k_proj", "v_proj"],
+ fused_op=KEY_QKV,
+ ),
+ partial(
+ build_lora_fused_ops,
+ submodule_names=["o_proj"],
+ fused_op=KEY_O,
+ ),
+ logic="APPEND",
+ ),
+ forward_builder_args=["base_type"],
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-cross-ent",
+ import_and_maybe_reload=(
+ "torch.nn.CrossEntropyLoss",
+ FastCrossEntropyLoss,
+ "transformers.models.mixtral.modeling_mixtral",
+ ),
+ )
+)
+
+ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id="mixtral-rope",
+ import_and_maybe_reload=(
+ "transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb",
+ fast_rope_embedding,
+ None,
+ ),
+ )
+)
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py
new file mode 100644
index 00000000..7f803330
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/model_patcher.py
@@ -0,0 +1,495 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from dataclasses import asdict, dataclass
+from enum import Enum
+from types import MethodType
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+import importlib
+import inspect
+
+# Third Party
+import pandas as pd
+import torch
+
+# ------------------------ helpers -----------------------
+
+
+def _patch_target_module(
+ to_patch: str,
+ replace_with: Any,
+ target_module: str = None,
+):
+ to_patch = to_patch.split(".")
+ assert len(to_patch) > 1, "must have an object to patch"
+
+ to_patch, obj_name_to_patch = to_patch[:-1], to_patch[-1]
+ to_patch = ".".join(to_patch)
+ source = importlib.import_module(to_patch)
+ original_obj = getattr(source, obj_name_to_patch)
+ setattr(source, obj_name_to_patch, replace_with)
+
+ if target_module is not None:
+ # reload and this should get the patched object
+ target_module = importlib.import_module(target_module)
+ importlib.reload(target_module)
+
+        # restore the original object at the source
+ setattr(source, obj_name_to_patch, original_obj)
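+
+
+# Example (editor's sketch): this is how the "llama-cross-ent" rule in
+# models/llama.py uses import-and-reload -- the loss class is swapped at its
+# source, transformers.models.llama.modeling_llama is reloaded so it picks up
+# the patched class, and then the original is restored on torch.nn:
+#
+#   _patch_target_module(
+#       "torch.nn.CrossEntropyLoss",
+#       FastCrossEntropyLoss,
+#       "transformers.models.llama.modeling_llama",
+#   )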
+
+
+# ------------------------ classes -----------------------
+
+# Rules will trigger on either
+# - module class, which triggers on isinstance
+# - callable, which will be useful to trigger on custom checks
+# - (consider): adding a regex will will apply on the name
+# ModelPatcherTrigger = Union[
+# torch.nn.Module, # trigger on isinstance
+# Callable[[torch.nn.Module], bool] # trigger on callable
+# ]
+# NOTE: triggering on instance checks will not be robust to reloading
+
+
+class ModelPatcherTriggerType(Enum):
+ module = 1
+ callable = 2
+
+
+@dataclass
+class ModelPatcherTrigger:
+ "Holds the triggering logic for the model patcher rule."
+
+ # the trigger operation
+ check: Union[
+ torch.nn.Module, # trigger on isinstance
+ Callable[[torch.nn.Module], bool], # trigger on callable
+ ]
+
+ # holds the type of the trigger
+    # - if left as None, the type is inferred in __post_init__
+ type: ModelPatcherTriggerType = None
+
+    # if the trigger is specific to a particular module name
+ module_name: str = None
+
+ def is_triggered(
+ self,
+ module: torch.nn.Module,
+ module_name: str,
+ ):
+ "Check if trigger returns truthful."
+
+ if self.module_name is not None and module_name != self.module_name:
+ return False
+
+ if self.type == ModelPatcherTriggerType.module and isinstance(
+ module, self.check
+ ):
+ return True
+
+ try:
+ # the function call may raise
+ if self.type == ModelPatcherTriggerType.callable and self.check(module):
+ return True
+ except Exception: # pylint: disable=broad-exception-caught
+ # NOTE: not sure if its good idea to let the exception pass through
+ pass
+
+ return False
+
+ def __post_init__(self):
+
+ if self.type is None:
+ if inspect.isclass(self.check) and issubclass(self.check, torch.nn.Module):
+ self.type = ModelPatcherTriggerType.module
+ else:
+ self.type = ModelPatcherTriggerType.callable
+
+
+# type for model forward
+ModelForward = Callable
+
+
+@dataclass
+class ModelPatcherRule:
+ # id, must be unique
+ rule_id: str
+
+    # trigger
+    # - if trigger is None, then the rule patches at the module / import level
+    trigger: ModelPatcherTrigger = None
+
+    # replacement forward to be patched onto triggered modules.
+    # this is mutually exclusive with forward_builder
+    forward: ModelForward = None
+
+    # takes in the triggered torch module and builds the forward.
+    # will be helpful to
+    # - do any pre-modification on the torch module
+    # returns either
+    # - a callable, which will be patched on the triggered module
+    # - a list of trigger-forward tuples
+    forward_builder: Callable[
+        [torch.nn.Module],
+        Union[ModelForward, List[Tuple[ModelPatcherTrigger, ModelForward]]],
+    ] = None
+
+    # if specified, these will be passed on from ModelPatcher.patch
+ # (if they exist)
+ forward_builder_args: List[str] = None
+
+    # this is mutually exclusive with forward and forward_builder
+ import_and_maybe_reload: Tuple[
+ str, # path to the object to be patched (e.g., 'torch.nn.CrossEntropyLoss')
+ Type, # replacement object (e.g., FastCrossEntropyLoss)
+ Optional[
+ str
+ ], # path to module to be reloaded (e.g., transformers.models.llama.modeling_llama)
+ ] = None
+
+    def __post_init__(self):
+        if (
+            sum(
+                x is not None
+                for x in (
+                    self.forward,
+                    self.forward_builder,
+                    self.import_and_maybe_reload,
+                )
+            )
+            > 1
+        ):
+            raise ValueError(
+                f"Rule '{self.rule_id}' must have at most one of forward, "
+                "forward_builder, or import_and_maybe_reload specified."
+            )
+
+ if self.import_and_maybe_reload is not None and self.trigger is not None:
+ raise ValueError(
+ f"Rule '{self.rule_id}' has import_and_maybe_reload specified, "
+ "and trigger must be None."
+ )
+
+ if self.forward_builder_args is not None and self.forward_builder is None:
+ raise ValueError(
+ f"Rule '{self.rule_id}' has forward_builder_args but no "
+ "forward_builder."
+ )
+
+
+# helpful to keep a history of all patching that has been done
+@dataclass
+class ModelPatcherHistory:
+ # instance id of the class that was patched
+ instance: int
+
+ # class of the torch.nn.Module that was patched
+ cls: str
+
+ # parent class of the torch.nn.Module that was patched
+ parent_cls: str
+
+ # module name
+ module_name: str
+
+ # parent
+ parent_module_name: str
+
+ # name of the rule that was applied
+ rule_id: str
+
+
+# singleton class for patching models
+class ModelPatcher:
+
+ # singleton history of patches
+ history: List[ModelPatcherHistory] = []
+
+ # singleton list of rules that have been registered
+ rules: Dict[str, ModelPatcherRule] = {}
+
+ @staticmethod
+ def load_patches(module_names: List[str], reload: bool = False):
+        # each patch should live in a module that calls ModelPatcher.register;
+        # this searches for and imports all the listed modules that it can find
+
+ # reload will trigger the register in that module
+ for plugin_name in module_names:
+ if importlib.util.find_spec(plugin_name):
+ m = importlib.import_module(plugin_name)
+
+ # attempt a reload of imported patch modules if requested
+ # NOTE: but this is brittle as triggering on instance types is
+ # not robust to reloading
+ if reload:
+ try:
+ importlib.reload(m)
+ except AssertionError:
+ # this is if it was loaded already
+ pass
+
+ @staticmethod
+ def register(rule: ModelPatcherRule):
+        # raise if a rule with the same id has already been registered
+ assert (
+ rule.rule_id not in ModelPatcher.rules
+ ), f"patch rule '{rule.rule_id}' already exists"
+
+ ModelPatcher.rules[rule.rule_id] = rule
+
+ @staticmethod
+ def did_rule_trigger(module: torch.nn.Module, module_name: str):
+ for name, rule in ModelPatcher.rules.items():
+
+ # if there is no trigger
+ if rule.trigger is None:
+ continue
+
+ if rule.trigger.is_triggered(module, module_name):
+ return name, rule
+
+ return None, None
+
+ @staticmethod
+ def _import_and_reload(model: torch.nn.Module):
+ # each rule.import_and_maybe_reload is a triple
+ # - path to be patched
+ # - replacement object
+ # - path to be reloaded
+
+ # USE CASE 1:
+ # from a import A # <- want to replace A by A_patched
+ # def func():
+ # obj = A()
+
+ # USE CASE 2:
+        # in module 'a' itself:
+ # def A(): # <- want to replace A by A_patched
+ # ...
+
+ # for 1: requires a reload of the func def.
+ # - the patch of A does not need to be perm
+ # for 2: just requires a patch of a.A.
+ # - the patch of a.A needs to be perm
+ # - once a.A has been patched, 'a' cannot be reloaded
+
+ # so for simplicity:
+ # - only allow a single reload
+ # - this is to allow the reload to happen first
+ # - any forward patches that happen after / before
+ # this import and reload should not be affected
+
+ # (a more advanced version could be considered)
+ # targets that have a reload path as a prefix, then
+ # the reload path happens first
+
+ # this will be the path to the module
+ module_path = model.__module__
+
+ # activate the one time rules (i.e. those with no trigger)
+ _with_reload = []
+ _no_reload = []
+ for rule in ModelPatcher.rules.values():
+ if rule.import_and_maybe_reload is not None:
+ _target, _, _reload = rule.import_and_maybe_reload
+ if _reload and _reload.startswith(module_path):
+ _with_reload.append(rule)
+ elif _target.startswith(module_path):
+ _no_reload.append(rule)
+
+        assert len(_with_reload) <= 1, "can only have at most one rule with reload"
+
+ # handle those with reload first
+ for rule in _with_reload + _no_reload:
+ _target, _object, _reload = rule.import_and_maybe_reload
+ _patch_target_module(_target, _object, _reload)
+ ModelPatcher.history.append(
+ ModelPatcherHistory(
+ instance=id(model),
+ cls=model.__class__.__name__,
+ parent_cls="",
+ module_name="",
+ parent_module_name="",
+ rule_id=rule.rule_id,
+ )
+ )
+
+ @staticmethod
+ def _patch_forwards(
+ model: torch.nn.Module,
+ patch_kwargs: Dict = None,
+ visited: Set = None,
+ parent_prefix: str = None,
+ parent_mcn: str = None,
+ ):
+ # NOTE: should we avoid repatching of the forwards
+
+ if patch_kwargs is None:
+ patch_kwargs = {}
+
+ if visited is None:
+ visited = set()
+
+ for name, mod in model.named_modules():
+
+ # some stats
+ mod_id = id(mod)
+ mod_class_name = mod.__class__.__name__
+ name = name.split(".")
+ if len(name) > 2:
+ parent_module_name, module_name = ".".join(name[:-1]), name[-1]
+ parent_mod = model.get_submodule(parent_module_name)
+ parent_mod_class_name = parent_mod.__class__.__name__
+ else:
+ # patching on model itself
+ module_name = name[0]
+ parent_mod_class_name = parent_module_name = ""
+ if parent_prefix is not None:
+ parent_module_name = parent_prefix + "." + parent_module_name
+ if parent_mcn is not None:
+ parent_mod_class_name = parent_mcn
+
+ rule_id, rule = ModelPatcher.did_rule_trigger(mod, module_name)
+ if rule_id is None:
+ continue
+
+ # otherwise triggered
+ if rule.forward is not None:
+ forward = rule.forward
+ else:
+ fba = {}
+ if rule.forward_builder_args is not None:
+                    fba = {
+                        k: w
+                        for k, w in patch_kwargs.items()
+                        if k in rule.forward_builder_args
+                    }
+ forward = rule.forward_builder(mod, **fba)
+
+ if isinstance(forward, list):
+ # this will be list of tuples case
+
+ # will descend down but
+ # - clear old rules
+ # - replace new rules
+ old_rules = ModelPatcher.rules
+ ModelPatcher.rules = {}
+ for i, (trig, forw) in enumerate(forward):
+ ModelPatcher.register(
+ ModelPatcherRule(
+ rule_id=f"{rule_id}-{i+1}",
+ trigger=trig,
+ forward=forw,
+ )
+ )
+
+ # this is an isolated patch
+ ModelPatcher.patch(
+ mod,
+ patch_kwargs=patch_kwargs,
+ visited=visited,
+ parent_prefix=parent_module_name,
+ parent_mcn=parent_mod_class_name,
+ )
+
+                # restore the original rules
+ ModelPatcher.rules = old_rules
+
+ # done
+ continue
+
+ # otherwise
+ mod.forward = MethodType(forward, mod)
+ ModelPatcher.history.append(
+ ModelPatcherHistory(
+ instance=mod_id,
+ cls=mod_class_name,
+ parent_cls=parent_mod_class_name,
+ module_name=module_name,
+ parent_module_name=parent_module_name,
+ rule_id=rule_id,
+ )
+ )
+
+ @staticmethod
+ def patch(model: torch.nn.Module, **kwargs):
+ # NOTE: for a set of rules, this patch function should be called
+ # only once. We do not have any checks for this at the moment
+ try:
+ ModelPatcher._import_and_reload(model.get_base_model())
+ except AttributeError:
+ ModelPatcher._import_and_reload(model)
+
+ # this will patch the forwards
+ ModelPatcher._patch_forwards(model, patch_kwargs=kwargs)
+
+ @staticmethod
+ def summary(raw: bool = False):
+ df = pd.DataFrame([asdict(entry) for entry in ModelPatcher.history])
+ if raw:
+ return df
+
+ if len(df) == 0:
+ return ""
+
+ # summarize and return string
+ df = (
+ df.groupby(["rule_id", "module_name", "cls"])["instance"]
+ .count()
+ .reset_index()
+ )
+ result = []
+ result.append("***************** Module Forwards Patching *************")
+ for x in df.to_dict("records"):
+ result.append(
+ "Rule: {0:15s} Module: {1:25s} Class: {2:15s} Num: {3:2d}".format(
+ x["rule_id"], x["module_name"], x["cls"], x["instance"]
+ )
+ )
+
+ return "\n".join(result)
+
+
+# ------------------------ function -----------------------
+
+
+def patch_model(model: torch.nn.Module, **kwargs):
+ ModelPatcher.patch(model, **kwargs)
+ return model
+
+
+def patch_model_summary():
+ return ModelPatcher.summary()
+
+
+def combine_triggers(*triggers: ModelPatcherTrigger, logic: str = "OR"):
+ assert logic == "OR", "only OR logic implemented for combining triggers"
+
+    # NOTE: this can probably be simplified
+ def _or_logic(*args, **kwargs):
+ for trig in triggers:
+ if trig.check(*args, **kwargs):
+ return True
+ return False
+
+ return ModelPatcherTrigger(check=_or_logic)
+
+
+def combine_functions(*funcs: Callable, logic: str = "APPEND"):
+ assert logic == "APPEND", "only APPEND logic implemented for combining functions"
+
+ def _append(*args, **kwargs):
+ results = []
+ for f in funcs:
+ results += f(*args, **kwargs)
+ return results
+
+ return _append
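+
+
+# Usage sketch (editor's note; assumes the rule modules under .models have
+# registered themselves on import, and `model` is a loaded HF model):
+#
+#   from fms_acceleration_foak import models  # registers llama/mistral/mixtral rules
+#   model = patch_model(model, base_type="auto_gptq")
+#   print(patch_model_summary())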
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
new file mode 100644
index 00000000..10819fc0
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/utils.py
@@ -0,0 +1,198 @@
+# Standard
+from functools import partial
+from typing import Callable, List, Type
+import os
+
+# Third Party
+import torch
+
+# Local
+# NOTE: the default activation is swiglu in both cases
+from ..fused_ops.unsloth_lora.bnb.fast_lora import (
+ apply_lora_mlp_swiglu as fused_op_mlp_bnb,
+)
+from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_o_v2 as fused_op_o_bnb
+from ..fused_ops.unsloth_lora.bnb.fast_lora import apply_lora_qkv as fused_op_qkv_bnb
+from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_mlp as fused_op_mlp_gptq
+from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_o_v2 as fused_op_o_gptq
+from ..fused_ops.unsloth_lora.gptq.fast_lora import apply_lora_qkv as fused_op_qkv_gptq
+from .model_patcher import ModelPatcherTrigger
+
+KEY_QKV = "qkv"
+KEY_O = "o"
+KEY_MLP = "mlp"
+
+FUSED_OPS = {
+ "auto_gptq": {
+ KEY_QKV: fused_op_qkv_gptq,
+ KEY_O: fused_op_o_gptq,
+ KEY_MLP: fused_op_mlp_gptq,
+ },
+ "bitsandbytes": {
+ KEY_QKV: fused_op_qkv_bnb,
+ KEY_O: fused_op_o_bnb,
+ KEY_MLP: fused_op_mlp_bnb,
+ },
+}
+
+
+# simple utility function to guess if it is a lora layer
+def _is_loralayer(module: torch.nn.Module, names: List[str] = None):
+ if names is None:
+ names = ["lora_A", "lora_B", "base_layer"]
+ return all(hasattr(module, x) for x in names)
+
+
+# builds a triple of forward functions, that each can be attached
+# on a series of QKV's, where if the first one is called, will call the
+# fused op
+# NOTE: this is not thread-safe (issue warning?)
+# NOTE: the unsloth fused_operation "apply_lora_qkv" assumes that the
+# modules are called q_proj, k_proj, and v_proj, respectively.
+# the fused operation can be changed, depending on what the base layer is
+# i.e. gptq or bnb
+def _build_fused_forwards(
+ attn: torch.nn.Module,
+ fused_operation: Callable = fused_op_qkv_gptq,
+ submodule_names: List[str] = None,
+):
+    # fused ops are expected to produce one or multiple results
+    # module names must be passed in the order of the fused op's outputs
+
+ outs = {}
+
+ # the fused operation will be called on first one that passes in the
+ # input X.
+ # - populates the triple Q, K, V
+ # - subsequent calls will be a no-op until ALL Q, K, V get reset to None
+ def _fused_op(X):
+
+ # if all of the outs are not yet populated
+ if all(x not in outs for x in submodule_names):
+ fused_outs = fused_operation(attn, X)
+ try:
+ fused_outs = list(fused_outs) # not sure if this is correct
+ except TypeError:
+ # if fused_outs is not iterable
+ fused_outs = [fused_outs]
+ for n, x in zip(submodule_names, fused_outs):
+ outs[n] = x
+
+    # each of these forward functions
+    # - calls the fused op on the first invocation, populating all outputs
+    # - pops and returns the cached output registered under its own name
+
+ def _forward(self, X, name: str):
+ _fused_op(X)
+        assert (
+            name in outs
+        ), "fused op must be reset by sequential calls to each of the fused projections"
+ V = outs[name]
+ del outs[name]
+ return V
+
+ return zip(submodule_names, [partial(_forward, name=n) for n in submodule_names])
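+
+# Editor's note: a sketch of the call protocol the returned forwards obey
+# (the tensor X and the module handles are assumed) -- the first projection
+# called runs the fused op and caches all outputs, later ones pop theirs:
+#
+#   q_f, k_f, v_f = (f for _, f in _build_fused_forwards(
+#       attn, fused_op_qkv_gptq, ["q_proj", "k_proj", "v_proj"]))
+#   Q = q_f(attn.q_proj, X)   # runs fused op, caches K and V
+#   K = k_f(attn.k_proj, X)   # returns cached K
+#   V = v_f(attn.v_proj, X)   # returns cached V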
+
+
+def build_lora_fused_ops(
+ attn: torch.nn.Module,
+ base_type: str = "auto_gptq",
+ submodule_names: List[str] = None,
+ fused_op: str = KEY_QKV,
+):
+
+    if submodule_names is None:
+        submodule_names = ["q_proj", "k_proj", "v_proj"]
+
+    assert (
+        len(submodule_names) > 0
+    ), "Building lora fused ops requires at least one submodule."
+
+ # get the fused op
+ fused_operation = FUSED_OPS[base_type][fused_op]
+
+ # handle casting issues
+ if base_type == "auto_gptq":
+
+ # this is required due to this FSDP fix
+ # https://github.com/foundation-model-stack/fms-acceleration/pull/15
+ try:
+ world_size = torch.distributed.get_world_size()
+ except ValueError:
+ world_size = 1 # pg not init
+
+ if (
+ world_size > 1
+ and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+ ):
+
+ # guarded import
+ # pylint: disable=import-outside-toplevel,import-error
+ # Third Party
+ from fms_acceleration_peft.autogptq_utils import (
+ PATCH_FOR_FSDP_TRITON_V2,
+ patch_forward_to_view_attributes_before_call,
+ )
+
+ # patch each of the fused ops to view the attributes
+ # back into torch.int32
+ # - if there are multiple submodules, then we assume that
+ # 'fused_operation' will be called on module that has
+ # submodules specified in 'submodule_names'.
+ # - otherwise if there is only a single 'submodule_name', then
+ # assume that 'fused_operation' called on the submodule specified
+ # by 'submodule_name' itself
+ if len(submodule_names) > 1:
+ patched_submodule_names = [n + ".base_layer" for n in submodule_names]
+ else:
+ # otherwise assume calling on the 'submodule_name' itself
+            # so it's just the base layer.
+ patched_submodule_names = "base_layer"
+
+ fused_operation = patch_forward_to_view_attributes_before_call(
+ fused_operation,
+ PATCH_FOR_FSDP_TRITON_V2,
+ torch.int32,
+ submodule_names=patched_submodule_names,
+ is_method_forward=False,
+ )
+
+ if fused_op == KEY_QKV:
+ return [
+ (ModelPatcherTrigger(check=_is_loralayer, module_name=name), forward)
+ for name, forward in _build_fused_forwards(
+ attn,
+ fused_operation=fused_operation,
+ submodule_names=submodule_names,
+ )
+ ]
+ if fused_op == KEY_O:
+        # otherwise it's just a single op
+ submodule_names = submodule_names[0]
+ return [
+ (
+ ModelPatcherTrigger(check=_is_loralayer, module_name=submodule_names),
+ fused_operation,
+ )
+ ]
+ if fused_op == KEY_MLP:
+ # otherwise just return the fused_op that should be attached at the
+ # top MLP level
+ return fused_operation
+
+ raise NotImplementedError(f"Unknown fused op '{fused_op}'")
+
+
+# trigger if either of the conditions is met
+# 1. qkv all have LoRA adapters for a fused op
+# 2. o has a lora adapter for the fused op
+def trigger_fused_ops(
+ module: torch.nn.Module,
+ attn_cls: Type,
+ submodule_names: List[str],
+):
+
+ # trigger if the module meets the attn class and the submodules
+ # are all loralayers
+ _mods = [getattr(module, x) for x in submodule_names]
+ return isinstance(module, attn_cls) and all(_is_loralayer(x) for x in _mods)
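+
+
+# Example (editor's sketch): how the helpers above combine into a patch rule;
+# this mirrors the "llama-mlp" registration in models/llama.py:
+#
+#   ModelPatcher.register(
+#       ModelPatcherRule(
+#           rule_id="example-mlp",
+#           trigger=ModelPatcherTrigger(
+#               check=partial(
+#                   trigger_fused_ops,
+#                   attn_cls=LlamaMLP,  # assumed target class
+#                   submodule_names=["up_proj", "down_proj", "gate_proj"],
+#               )
+#           ),
+#           forward_builder=partial(
+#               build_lora_fused_ops,
+#               submodule_names=["up_proj", "down_proj", "gate_proj"],
+#               fused_op=KEY_MLP,
+#           ),
+#           forward_builder_args=["base_type"],
+#       )
+#   )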
diff --git a/plugins/fused-ops-and-kernels/tests/__init__.py b/plugins/fused-ops-and-kernels/tests/__init__.py
new file mode 100644
index 00000000..38a9531e
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py b/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
new file mode 100644
index 00000000..dd7b472d
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/tests/test_foak_plugins.py
@@ -0,0 +1,84 @@
+# Copyright The IBM Tuning Team
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-License-Identifier: Apache-2.0
+# https://spdx.dev/learn/handling-license-info/
+
+# Standard
+import os
+
+# Third Party
+from fms_acceleration import AccelerationPluginConfigError
+from fms_acceleration.utils import (
+ instantiate_framework,
+ read_configuration,
+ update_configuration_contents,
+)
+import pytest # pylint: disable=import-error
+
+# instantiate_framework will handle registering and activating AutoGPTQAccelerationPlugin
+
+# configuration
+DIRNAME = os.path.dirname(__file__)
+CONFIG_PATH_AUTO_GPTQ_FOAK = os.path.join(
+ DIRNAME, "../configs/fast_quantized_peft.yaml"
+)
+
+
+def test_configure_gptq_foak_plugin():
+ "test foak plugin loads correctly"
+
+    # test that the provided configuration correctly instantiates the plugin
+ with instantiate_framework(
+ read_configuration(CONFIG_PATH_AUTO_GPTQ_FOAK), require_packages_check=False
+ ) as framework:
+
+ # check flags and callbacks
+ assert framework.requires_custom_loading is False
+ assert framework.requires_agumentation
+ assert len(framework.get_callbacks_and_ready_for_train()) == 0
+
+ # attempt to activate plugin with configuration pointing to wrong path
+ # - raise with message that no plugins can be configured
+ with pytest.raises(ValueError) as e:
+ with instantiate_framework(
+ update_configuration_contents(
+ read_configuration(CONFIG_PATH_AUTO_GPTQ_FOAK),
+ "peft.quantization.fused_ops_and_kernels",
+ "something",
+ ),
+ ):
+ pass
+
+ e.match("No plugins could be configured")
+
+    # NOTE: currently only all-or-nothing is supported, until the generic
+    # patching rules are addressed
+    # attempt to activate plugin with unsupported settings
+ # - raise with appropriate message complaining about wrong setting
+ for key, wrong_value in [
+ ("peft.quantization.fused_ops_and_kernels.fused_lora", False),
+ ("peft.quantization.fused_ops_and_kernels.fast_loss", False),
+ ("peft.quantization.fused_ops_and_kernels.fast_rsm_layernorm", False),
+ ("peft.quantization.fused_ops_and_kernels.fast_rope_embeddings", False),
+ ]:
+ with pytest.raises(AccelerationPluginConfigError) as e:
+ with instantiate_framework(
+ update_configuration_contents(
+ read_configuration(CONFIG_PATH_AUTO_GPTQ_FOAK), key, wrong_value
+ ),
+ ):
+ pass
+
+ e.match(f"FastQuantizedPeftAccelerationPlugin: Value at '{key}'")
diff --git a/plugins/fused-ops-and-kernels/tox.ini b/plugins/fused-ops-and-kernels/tox.ini
new file mode 100644
index 00000000..8b6e5930
--- /dev/null
+++ b/plugins/fused-ops-and-kernels/tox.ini
@@ -0,0 +1,42 @@
+[tox]
+envlist = py, lint
+
+[testenv]
+deps =
+ pytest>=7
+ -e {toxinidir}/../framework
+commands = pytest {posargs:tests}
+
+[testenv:lint]
+description = run linters
+deps =
+ -e {toxinidir}/../framework
+ pylint>=2.16.2,<=3.1.0
+commands = pylint src tests
+allowlist_externals = pylint
+
+[testenv:fmt]
+description = format
+skip_install = true
+deps =
+ black>=22.12
+ isort>=5.11
+commands =
+ # exclude the code ported from unsloth
+ black --exclude .*unsloth.* src
+ black --exclude .*unsloth.* tests
+ isort .
+
+# [testenv:build]
+# description = build wheel
+# deps =
+# build
+# commands = python -m build -w
+# skip_install = True
+#
+# [testenv:twinecheck]
+# description = check wheel
+# deps =
+# twine
+# commands = twine check dist/*
+# skip_install = True
diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml
index 8d45bedf..75f7279b 100644
--- a/sample-configurations/CONTENTS.yaml
+++ b/sample-configurations/CONTENTS.yaml
@@ -19,4 +19,16 @@ framework_configs:
- shortname: baseline-peft-bnb
plugins:
- accelerated-peft
- filename: baseline-peft-bnb-nf4-sample-configuration.yaml
\ No newline at end of file
+ filename: baseline-peft-bnb-nf4-sample-configuration.yaml
+
+ - shortname: accelerated-peft-autogptq-foak
+ plugins:
+ - accelerated-peft
+ - fused-ops-and-kernels
+ filename: accelerated-peft-autogptq-foak-sample-configuration.yaml
+
+ - shortname: accelerated-peft-bnb-foak
+ plugins:
+ - accelerated-peft
+ - fused-ops-and-kernels
+ filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
\ No newline at end of file
diff --git a/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml
new file mode 100644
index 00000000..1eb38df3
--- /dev/null
+++ b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml
@@ -0,0 +1,44 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # PEFT-related acceleration
+ peft:
+
+    # quantization-related acceleration
+ # e.g., kernels for quantized base weights
+ quantization:
+
+ # AutoGPTQ quantized base weights.
+ auto_gptq:
+
+      # Kernel to be used for the GPTQ linear layer
+ # NOTE: Not all kernels are suitable for PEFT training; need to use
+ # kernels that support autograd forward / backward. The best
+ # recommendation at the moment is "triton_v2".
+ kernel: triton_v2
+
+      # If true, then an already-quantized checkpoint is expected to be
+      # passed into TrainingArguments.model_name_or_path
+ from_quantized: true
+ fused_ops_and_kernels:
+
+ # load unsloth optimizations for these 4bit base layer weights.
+ # currently only support "auto_gptq" and "bitsandbytes"
+ base_layer: auto_gptq
+
+ # activate various unsloth optimizations
+ # NOTE: currently supports only all-or-nothing.
+
+ # fused kernels for lora linear layers
+ fused_lora: true
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rsm_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
new file mode 100644
index 00000000..fcb9bb14
--- /dev/null
+++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml
@@ -0,0 +1,44 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+ # PEFT-related acceleration
+ peft:
+
+    # quantization-related acceleration
+ # e.g., kernels for quantized base weights
+ quantization:
+
+ # For loading BitsAndBytes quantized layers
+ # to serve as 4bit base-weights for LoRA PEFT-tuning.
+    # NOTE: currently AutoGPTQ is not properly integrated into huggingface /
+    # bitsandbytes, thus the recommended quant_type is either "nf4"
+    # or "fp4".
+ bitsandbytes:
+ quant_type: nf4
+
+ # If True, then no get_peft_model and prepare_model_for_kbit_training
+ # will be called.
+ no_peft_model: false
+ fused_ops_and_kernels:
+
+ # load unsloth optimizations for these 4bit base layer weights.
+ # currently only support "auto_gptq" and "bitsandbytes"
+ base_layer: bitsandbytes
+
+ # activate various unsloth optimizations
+ # NOTE: currently supports only all-or-nothing.
+
+ # fused kernels for lora linear layers
+ fused_lora: true
+
+ # fast loss triton kernels
+ fast_loss: true
+
+ # fast rms norm triton kernels
+ fast_rsm_layernorm: true
+
+ # fast RoPE embedding triton kernels
+ fast_rope_embeddings: true
diff --git a/scripts/benchmarks/README.md b/scripts/benchmarks/README.md
index 115719b7..269d3ead 100644
--- a/scripts/benchmarks/README.md
+++ b/scripts/benchmarks/README.md
@@ -164,6 +164,7 @@ We currently compute the memory values in the report by taking the largest of su
For allocated memory value
```
max([
+ stage0_mem,
stage0_mem + stage1_allocated_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta,
...
@@ -173,13 +174,13 @@ max([
For peak memory value
```
max([
+ stage0_mem,
stage0_mem + stage1_allocated_delta + stage1_peaked_delta,
stage0_mem + stage1_allocated_delta + stage2_allocated_delta + stage2_peaked_delta,
...
])
```
-Notice that we do not include `stage0_mem` alone when computing the max value. This is to avoid misleading comparisons between GPTQ-LoRA and others. GPTQ-LoRA + FSDP currently does not support low-memory mode as mentioned [here](https://github.com/foundation-model-stack/fms-acceleration/issues/18). The `stage0_mem` value of GPTQ-LoRA + FSDP will reflect a larger than expected value as it is loaded fully before the trainer is initialized and then subsequently will be sharded internally in `trainer.prepare`. This might cause some misleading comparisons when other variants are loaded in low-memory mode and have smaller `stage0_mem` memory consumption than GPTQ-LoRA + FSDP. Once low-memory mode is supported for GPTQ-LoRA, we will include `stage0_mem` back inside the max computation
We compare memory values between Nvidia-SMI and Torch in this PR - [Memory Benchmarking](https://github.com/foundation-model-stack/fms-acceleration/pull/14).
diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py
index afbf61cf..ec601c43 100644
--- a/scripts/benchmarks/benchmark.py
+++ b/scripts/benchmarks/benchmark.py
@@ -1,5 +1,6 @@
# Standard
from itertools import product
+from time import sleep
from typing import Any, Callable, Dict, List, Tuple, Union
import argparse
import json
@@ -77,7 +78,7 @@
GPU_LOG_USED_MEM_COLUMN_NAME = "memory.used [MiB]"
GPU_LOG_METRIC_SUFFIX = " MiB"
GPU_TABLE = "timestamp,name,index,memory.used"
-RESULT_FIELD_RESERVED_GPU_MEM = "nvidia_mem_reserved"
+RESULT_FIELD_RESERVED_GPU_MEM = "mem_nvidia_mem_reserved"
RESULT_FIELD_DEVICE_NAME = "gpu_device_name"
HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT = "before_init_mem_gpu"
@@ -86,8 +87,9 @@
KEYWORD_PEAKED_DELTA = "peaked_delta"
KEYWORD_ALLOC_DELTA = "alloc_delta"
HF_ARG_SKIP_MEMORY_METRIC = "--skip_memory_metrics"
-RESULT_FIELD_ALLOCATED_GPU_MEM = "torch_mem_alloc_in_bytes"
-RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "peak_torch_mem_alloc_in_bytes"
+RESULT_FIELD_ALLOCATED_GPU_MEM = "mem_torch_mem_alloc_in_bytes"
+RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM = "mem_peak_torch_mem_alloc_in_bytes"
+ERROR_MESSAGES = "error_messages"
def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]:
@@ -112,8 +114,8 @@ def extract_gpu_memory_metrics(output_metrics) -> Tuple[float]:
return 0, 0
trainer_stage_order = [
- (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, False),
- (HF_TRAINER_LOG_GPU_STAGE_INIT, False),
+ (HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT, True),
+ (HF_TRAINER_LOG_GPU_STAGE_INIT, True),
(HF_TRAINER_LOG_GPU_STAGE_TRAIN, True),
]
alloc_running_sum = 0
@@ -357,6 +359,17 @@ def __init__(
self.results_filename = os.path.join(self.save_dir, FILE_RESULTS)
self.gpu_log_filename = os.path.join(self.save_dir, FILE_MEM)
+ @property
+ def is_completed(self):
+ if not os.path.exists(self.results_filename):
+ return False
+ # otherwise open it and check for errors
+ with open(self.results_filename) as f:
+ results = json.load(f)
+
+ # return complete only if no errors
+        return ERROR_MESSAGES not in results
+
def run(
self,
run_cmd: str,
@@ -480,38 +493,6 @@ def maybe_get_experiment_error_traceback(self):
return None if len(results) == 0 else results
- def get_peak_mem_usage_by_device_id(self):
- """
- This function retrieves the raw measurements of reserved GPU memory per device across the experiment -
- computing the peak value for each gpu and then performing a simple calibration (subtracts peak values by the first reading).
- Returns:
- - pd.Series of peak memory usage per device id
- - the device name as string - e.g. "NVIDIA A100-SXM4-80GB"
-
- Example: For 2 devices with GPU Indices 0,1 - it will return the max measurement value (in MiB) of each device as a Series:
-
- - pd.Series
- index
- 0 52729.0
- 1 52783.0
- Name: memory.used [MiB], dtype: float64
- """
-
- # group the gpu readings into device ids
- gpu_logs = pd.read_csv(self.gpu_log_filename, skipinitialspace=True)
- # assume that all the devices have the same device name
- device_name = gpu_logs.name.iloc[-1]
- # extract and convert the gpu memory usage as float values
- gpu_logs[GPU_LOG_USED_MEM_COLUMN_NAME] = gpu_logs[
- GPU_LOG_USED_MEM_COLUMN_NAME
- ].apply(lambda x: float(x.replace(GPU_LOG_METRIC_SUFFIX, "")))
- mem_usage_by_device_id = gpu_logs.groupby("index")[GPU_LOG_USED_MEM_COLUMN_NAME]
- # Calibrate values by subtracting out the initial values of the GPU readings
- # to ensure no existing memory is counted in addition with the experiment
- initial_values = mem_usage_by_device_id.first()
- peak_values = mem_usage_by_device_id.max()
- return peak_values.sub(initial_values), device_name
-
def write_result(self):
"Function to write a json result file"
@@ -519,30 +500,6 @@ def write_result(self):
save_result = ConfigUtils.convert_args_to_dict(self.experiment_args_str)
save_result["num_gpus"] = self.num_gpus
- # if a gpu log file exist, process the raw nvidia logs and write to result
- if os.path.isfile(self.gpu_log_filename):
- # Add GPU info and measurements into the result saving
- peak_mem_usage_by_device_id, device_name = (
- self.get_peak_mem_usage_by_device_id()
- )
- save_result[RESULT_FIELD_DEVICE_NAME] = device_name
- # Memory usage is averaged across all devices in the final result
- save_result[RESULT_FIELD_RESERVED_GPU_MEM] = (
- peak_mem_usage_by_device_id.mean()
- )
-
- # process gpu mem from output metrics and write to result
- # check if HF_ARG_SKIP_MEMORY_METRIC is set to False in experiment arg
- # this arg is specified explicitly inside `def generate_list_of_experiments``
- argument_idx = self.experiment_arg.index(HF_ARG_SKIP_MEMORY_METRIC)
- write_memory_metric = not self.experiment_arg[argument_idx + 1]
- if write_memory_metric:
- peak_gpu_mem, gpu_allocated_mem = extract_gpu_memory_metrics(
- self.get_experiment_final_metrics()
- )
- save_result[RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM] = peak_gpu_mem
- save_result[RESULT_FIELD_ALLOCATED_GPU_MEM] = gpu_allocated_mem
-
# if there is an error we save the error message else we save the final result
maybe_error_messages = self.maybe_get_experiment_error_traceback()
if maybe_error_messages is None:
@@ -552,7 +509,7 @@ def write_result(self):
**self.get_experiment_final_metrics(),
}
else:
- other_results = {"error_messages": maybe_error_messages}
+ other_results = {ERROR_MESSAGES: maybe_error_messages}
# combine the final thing
save_result = {**save_result, **other_results}
@@ -582,7 +539,7 @@ class DryRunExperiment(Experiment):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- def run(self, run_cmd: str, environment_variables: Dict = None):
+ def run(self, run_cmd: str, environment_variables: Dict = None, **kwargs):
def _dummy(*args, **kwargs):
pass
@@ -600,6 +557,37 @@ def maybe_get_experiment_error_traceback(self):
return None
+def get_peak_mem_usage_by_device_id(gpu_logs: pd.DataFrame):
+ """
+ This function retrieves the raw measurements of reserved GPU memory per device across the experiment -
+    computing the peak value for each gpu and then performing a simple calibration (subtracts the first reading from the peak values).
+ Returns:
+ - pd.Series of peak memory usage per device id
+ - the device name as string - e.g. "NVIDIA A100-SXM4-80GB"
+
+ Example: For 2 devices with GPU Indices 0,1 - it will return the max measurement value (in MiB) of each device as a Series:
+
+ - pd.Series
+ index
+ 0 52729.0
+ 1 52783.0
+ Name: memory.used [MiB], dtype: float64
+ """
+
+ # assume that all the devices have the same device name
+ device_name = gpu_logs.name.iloc[-1]
+ # extract and convert the gpu memory usage as float values
+ gpu_logs[GPU_LOG_USED_MEM_COLUMN_NAME] = gpu_logs[
+ GPU_LOG_USED_MEM_COLUMN_NAME
+ ].apply(lambda x: float(x.replace(GPU_LOG_METRIC_SUFFIX, "")))
+ mem_usage_by_device_id = gpu_logs.groupby("index")[GPU_LOG_USED_MEM_COLUMN_NAME]
+ # Calibrate values by subtracting out the initial values of the GPU readings
+ # to ensure no existing memory is counted in addition with the experiment
+ initial_values = mem_usage_by_device_id.first()
+ peak_values = mem_usage_by_device_id.max()
+ return peak_values.sub(initial_values), device_name
+
+
def prepare_arguments(args):
defaults = ConfigUtils.read_yaml(args.defaults_config_path)
defaults["training_data_path"] = args.dataset_save_path
@@ -699,6 +687,8 @@ def _gather(rdir):
x for x in os.listdir(rdir) if x.startswith(DIR_PREFIX_EXPERIMENT)
]
for tag in exper_dirs:
+ gpu_log_filename = os.path.join(rdir, tag, FILE_MEM)
+
try:
with open(os.path.join(rdir, tag, FILE_RESULTS)) as f:
tag = tag.replace(DIR_PREFIX_EXPERIMENT + "_", "")
@@ -706,6 +696,42 @@ def _gather(rdir):
experiment_stats[tag] = json.load(f)
except FileNotFoundError:
pass
+
+ if script_args["log_nvidia_smi"]:
+ gpu_logs = pd.read_csv(gpu_log_filename, skipinitialspace=True)
+ peak_nvidia_mem_by_device_id, device_name = (
+ get_peak_mem_usage_by_device_id(gpu_logs)
+ )
+ experiment_stats[tag].update(
+ {
+ # Report the mean peak memory across all gpu device ids
+ RESULT_FIELD_RESERVED_GPU_MEM: peak_nvidia_mem_by_device_id.mean(),
+ RESULT_FIELD_DEVICE_NAME: device_name,
+ }
+ )
+
+ if script_args["log_memory_hf"] and tag in experiment_stats.keys():
+ memory_metrics_prefixes = [
+ HF_TRAINER_LOG_GPU_STAGE_BEFORE_INIT,
+ HF_TRAINER_LOG_GPU_STAGE_INIT,
+ HF_TRAINER_LOG_GPU_STAGE_TRAIN,
+ ]
+ memory_metrics = {
+ k: v
+ for k, v in experiment_stats[tag].items()
+ if any([prefix in k for prefix in memory_metrics_prefixes])
+ }
+ if len(memory_metrics.keys()) > 0:
+ peak_torch_gpu_mem, torch_gpu_mem = extract_gpu_memory_metrics(
+ memory_metrics
+ )
+ experiment_stats[tag].update(
+ {
+ RESULT_FIELD_PEAK_ALLOCATED_GPU_MEM: peak_torch_gpu_mem,
+ RESULT_FIELD_ALLOCATED_GPU_MEM: torch_gpu_mem,
+ }
+ )
+
df = pd.DataFrame.from_dict(experiment_stats, orient="index").sort_index()
try:
df["framework_config"] = df["acceleration_framework_config_file"].map(
@@ -781,6 +807,14 @@ def main(args):
log_memory_in_trainer=args.log_memory_hf,
)
):
+            # store pointer to file for future result retrieval
+ experiment_stats[experiment.tag] = experiment.results_filename
+
+ if experiment.is_completed:
+                # if already completed, don't re-run the experiment
+ sleep(0.1) # sleep a bit to allow the tqdm to update
+ continue
+
if experiment.num_gpus > 1:
prefix = COMMAND_ACCELERATE.format(
accelerate_config_path=args.accelerate_config,
@@ -806,10 +840,9 @@ def main(args):
log_nvidia_smi=args.log_nvidia_smi,
)
- # write results and store pointers to files
+ # write results
experiment.write_result()
experiment.write_shell_command()
- experiment_stats[experiment.tag] = experiment.results_filename
# 4. Consolidates the experiment results into a summary
for tag, path in experiment_stats.items():
diff --git a/scripts/benchmarks/display_bench_results.py b/scripts/benchmarks/display_bench_results.py
index b590f26c..51ba5642 100644
--- a/scripts/benchmarks/display_bench_results.py
+++ b/scripts/benchmarks/display_bench_results.py
@@ -1,18 +1,21 @@
# Standard
+from typing import List
import argparse
# First Party
# import this because of alot of internal contants
-from scripts.benchmarks.benchmark import gather_report, DIR_SAMP_CONFIGS
-from typing import List
+from scripts.benchmarks.benchmark import DIR_SAMP_CONFIGS, gather_report
-def main(*directories: str, output_filename: str = "results.csv", remove_columns: List[str] = None):
+
+def main(
+ *directories: str,
+ output_filename: str = "results.csv",
+ remove_columns: List[str] = None,
+ keep_columns: List[str] = None,
+):
"gather outputs from a list of directories and output to a csv"
- df, constant = gather_report(*directories, raw=False)
- # filter result columns to keep by the inverse of remove_columns
- if remove_columns:
- df = df[df.columns[~df.columns.isin(remove_columns)]]
+ df, constant = gather_report(directories, raw=False)
errors = []
try:
@@ -22,12 +25,25 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns
df = df.loc[df.error_messages.isna()]
except:
pass
- df = df.reset_index().drop("output_dir", axis=1)
+
+ # filter result columns to keep by the inverse of remove_columns
+ if remove_columns:
+ df = df[df.columns[~df.columns.isin(remove_columns)]]
+
+ # assume keep and remove are disjoint
+ kept = 0
+ if keep_columns:
+ for c in keep_columns:
+ if c in constant:
+ df[c] = constant[c]
+ kept += 1
+
+ df = df.reset_index(drop=True).drop("output_dir", axis=1)
df.reindex(sorted(df.columns), axis=1).to_csv(output_filename, index=False)
print("***************** Report Created ******************")
print(f"Total lines: '{len(df)}'")
print(f"Number columns included: '{len(df.columns)}'")
- print(f"Number columns excluded: '{len(constant)}'")
+ print(f"Number columns excluded: '{len(constant)-kept}'")
print(f"Excluding number of exceptions caught: '{len(errors)}'")
print(f"Written report to '{output_filename}'")
@@ -53,6 +69,16 @@ def main(*directories: str, output_filename: str = "results.csv", remove_columns
nargs="*",
help="list of columns to ignore from results.csv",
)
+ parser.add_argument(
+ "--keep_columns",
+ nargs="*",
+ help="list of columns to always include into results.csv",
+ )
args = parser.parse_args()
- main(args.bench_outputs, output_filename=args.result_file, remove_columns=args.remove_columns)
+ main(
+ *args.bench_outputs,
+ output_filename=args.result_file,
+ remove_columns=args.remove_columns,
+ keep_columns=args.keep_columns,
+ )
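To make the keep/remove column semantics of the updated `main()` concrete, here is a self-contained toy sketch; the data frame and `constant` mapping are invented for illustration (in the script, both come from `gather_report`).

```python
# Toy demonstration of the filtering in display_bench_results.main():
# remove_columns drops varying columns, keep_columns re-attaches
# otherwise-constant columns from the `constant` mapping.
import pandas as pd

df = pd.DataFrame({
    "train_runtime": [565.9, 307.7],
    "nvidia_mem_reserved": [76679.0, 43702.0],
    "output_dir": ["exp0", "exp1"],
})
constant = pd.Series({"torch_dtype": "float16", "learning_rate": "2e-5"})

remove_columns = ["nvidia_mem_reserved"]
keep_columns = ["torch_dtype"]

# drop the unwanted result columns
df = df[df.columns[~df.columns.isin(remove_columns)]]
# re-attach requested constant columns
for c in keep_columns:
    if c in constant:
        df[c] = constant[c]
df = df.reset_index(drop=True).drop("output_dir", axis=1)
print(df)
```

The resulting frame drops `nvidia_mem_reserved` and `output_dir` but gains a constant `torch_dtype` column, mirroring how `--keep_columns 'torch_dtype'` is passed in `run_benchmarks.sh` further below.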
diff --git a/scripts/benchmarks/refs/a100_80gb.csv b/scripts/benchmarks/refs/a100_80gb.csv
index 4434d864..45cdf125 100644
--- a/scripts/benchmarks/refs/a100_80gb.csv
+++ b/scripts/benchmarks/refs/a100_80gb.csv
@@ -1,61 +1,85 @@
-epoch,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,nvidia_mem_reserved,peak_torch_mem_alloc_in_bytes,peft_method,per_device_train_batch_size,r,target_modules,torch_mem_alloc_in_bytes,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
-0.04,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,77705.0,72971724288.0,,4,,,44004763136.0,0.9278398831685384,177.1092,0.678,0.169,2775.237
-0.04,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,44706.0,36762859520.0,,2,,,29521119232.0,0.8970902442932129,91.086,1.317,0.329,2698.11
-0.09,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,74383.0,72972117504.0,,8,,,44005156352.0,0.9879656155904134,322.458,0.744,0.093,3048.583
-0.09,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,53907.0,36763056128.0,,4,,,29521315840.0,0.9259945551554362,167.7727,1.431,0.179,2929.678
-,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,4,,,,,,,,
-,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79353.0,,,2,,,,,,,,
-,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,81043.0,,,8,,,,,,,,
-,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,79827.0,,,4,,,,,,,,
-,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,4,,,,,,,,
-,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,80830.0,,,2,,,,,,,,
-,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,80837.0,,,8,,,,,,,,
-,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,80834.5,,,4,,,,,,,,
-0.04,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,29731.0,26108963328.0,lora,4,16,q_proj k_proj v_proj o_proj,15119590912.0,0.9096682230631511,136.624,0.878,0.22,3597.611
-0.04,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,18697.0,15123161088.0,lora,2,16,q_proj k_proj v_proj o_proj,7850391552.0,0.8918854713439941,82.0311,1.463,0.366,2995.936
-0.09,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,43195.0,37098695168.0,lora,8,16,q_proj k_proj v_proj o_proj,15119984128.0,0.962119706471761,270.6301,0.887,0.111,3632.412
-0.09,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,26235.0,21433753600.0,lora,4,16,q_proj k_proj v_proj o_proj,7850588160.0,0.9218235015869141,143.8184,1.669,0.209,3417.643
-,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,62617.0,57540387840.0,lora,2,16,q_proj k_proj v_proj o_proj,47311452160.0,0.9361546834309896,179.3128,0.669,0.167,1370.566
-,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,80955.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-0.09,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,69848.0,64347637760.0,lora,4,16,q_proj k_proj v_proj o_proj,47311648768.0,0.9383139928181966,280.8919,0.854,0.107,1749.855
-,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80894.0,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,
-,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80917.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80979.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,baseline-peft-bnb,24,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,27023.0,22825932800.0,lora,4,16,q_proj k_proj v_proj o_proj,5368221184.0,0.9589527130126954,178.8061,0.671,0.168,2748.9
-0.04,True,baseline-peft-bnb,25,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13530.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9154380798339844,87.3652,1.374,0.343,2813.02
-0.09,True,baseline-peft-bnb,26,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,47145.0,40278956032.0,lora,8,16,q_proj k_proj v_proj o_proj,5368614400.0,0.9702634493509928,341.2286,0.703,0.088,2880.884
-0.09,True,baseline-peft-bnb,27,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21502.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.914565912882487,149.9341,1.601,0.2,3278.241
-0.04,True,baseline-peft-bnb,28,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,48313.0,46419968512.0,lora,4,16,q_proj k_proj v_proj o_proj,25726225920.0,0.9744932492574055,351.8623,0.341,0.085,1396.91
-0.04,True,baseline-peft-bnb,29,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25549.0,21922782720.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9303209940592448,171.4299,0.7,0.175,1433.589
-0.09,True,baseline-peft-bnb,30,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,69931.0,67089150464.0,lora,8,16,q_proj k_proj v_proj o_proj,25726619136.0,0.9745417594909668,629.837,0.381,0.048,1560.785
-0.09,True,baseline-peft-bnb,31,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29384115200.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9310146331787109,300.5119,0.799,0.1,1635.609
-,True,baseline-peft-bnb,32,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80893.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,baseline-peft-bnb,33,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52634.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.0399916648864747,584.3145,0.205,0.051,420.595
-,True,baseline-peft-bnb,34,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,79557.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-,True,baseline-peft-bnb,35,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80749.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,accelerated-peft-bnb,36,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,19931.0,15860019712.0,lora,4,16,q_proj k_proj v_proj o_proj,4843384320.0,0.9652111371358235,143.3569,0.837,0.209,3428.645
-0.04,True,accelerated-peft-bnb,37,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,13497.0,9974622720.0,lora,2,16,q_proj k_proj v_proj o_proj,2727018496.0,0.9277165730794271,86.4307,1.388,0.347,2843.435
-0.09,True,accelerated-peft-bnb,38,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,34355.0,26849751552.0,lora,8,16,q_proj k_proj v_proj o_proj,4843777536.0,0.9493892669677735,279.7156,0.858,0.107,3514.427
-0.09,True,accelerated-peft-bnb,39,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,21479.0,16587205120.0,lora,4,16,q_proj k_proj v_proj o_proj,2727215104.0,0.9110882759094239,149.3914,1.607,0.201,3290.15
-0.04,True,accelerated-peft-bnb,40,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,38405.0,36218024448.0,lora,4,16,q_proj k_proj v_proj o_proj,25201389056.0,0.9741149584452311,278.5888,0.431,0.108,1764.32
-0.04,True,accelerated-peft-bnb,41,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,25592.0,21906697728.0,lora,2,16,q_proj k_proj v_proj o_proj,13219233792.0,0.9300654411315918,172.7359,0.695,0.174,1422.75
-0.09,True,accelerated-peft-bnb,42,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,50875.0,47207756288.0,lora,8,16,q_proj k_proj v_proj o_proj,25201782272.0,0.9748441060384114,512.2298,0.469,0.059,1919.139
-0.09,True,accelerated-peft-bnb,43,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,32957.0,29369087488.0,lora,4,16,q_proj k_proj v_proj o_proj,13219430400.0,0.9301350593566895,287.6381,0.834,0.104,1708.814
-0.04,True,accelerated-peft-bnb,44,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,72829.0,68159977472.0,lora,4,16,q_proj k_proj v_proj o_proj,37346815488.0,1.118430455525716,1075.2044,0.112,0.028,457.141
-0.04,True,accelerated-peft-bnb,45,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,52632.0,46524471808.0,lora,2,16,q_proj k_proj v_proj o_proj,19172741120.0,1.040946865081787,586.651,0.205,0.051,418.92
-,True,accelerated-peft-bnb,46,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,80405.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-,True,accelerated-peft-bnb,47,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,80954.0,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,
-0.04,True,accelerated-peft-autogptq,48,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,20453.0,15890329088.0,lora,4,16,q_proj k_proj v_proj o_proj,4873693696.0,1.3805528958638509,151.0359,0.795,0.199,3254.326
-0.04,True,accelerated-peft-autogptq,49,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,17198.0,9952175616.0,lora,2,16,q_proj k_proj v_proj o_proj,3005709312.0,1.1706618309020995,87.4109,1.373,0.343,2811.548
-0.09,True,accelerated-peft-autogptq,50,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,34247.0,26880060928.0,lora,8,16,q_proj k_proj v_proj o_proj,4874086912.0,1.2741642634073893,282.6391,0.849,0.106,3478.076
-0.09,True,accelerated-peft-autogptq,51,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,24783.0,16262768128.0,lora,4,16,q_proj k_proj v_proj o_proj,3005905920.0,1.043952751159668,152.5473,1.573,0.197,3222.083
-0.04,True,accelerated-peft-autogptq,52,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,37461.0,35528093184.0,lora,4,16,q_proj k_proj v_proj o_proj,24511457792.0,0.9936613400777181,263.6066,0.455,0.114,1864.597
-0.04,True,accelerated-peft-autogptq,53,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,46641.0,25708175360.0,lora,2,16,q_proj k_proj v_proj o_proj,12788874240.0,0.9420519828796386,167.065,0.718,0.18,1471.045
-0.09,True,accelerated-peft-autogptq,54,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,49925.0,46517825024.0,lora,8,16,q_proj k_proj v_proj o_proj,24511851008.0,0.9855653127034505,498.9022,0.481,0.06,1970.406
-0.09,True,accelerated-peft-autogptq,55,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,52358.0,27739090432.0,lora,4,16,q_proj k_proj v_proj o_proj,12789070848.0,0.9389812151590983,281.8034,0.852,0.106,1744.195
-0.04,True,accelerated-peft-autogptq,56,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,71565.0,65895347200.0,lora,4,16,q_proj k_proj v_proj o_proj,36290144768.0,1.0755928039550782,1060.8387,0.113,0.028,463.331
-0.04,True,accelerated-peft-autogptq,57,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80387.0,45397678592.0,lora,2,16,q_proj k_proj v_proj o_proj,18649885696.0,1.0256956418355305,576.0422,0.208,0.052,426.635
-,True,accelerated-peft-autogptq,58,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,1,80293.0,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,
-0.08,True,accelerated-peft-autogptq,59,2e-4,16,0.0,TheBloke/Llama-2-70B-GPTQ,2,80363.0,70667573760.0,lora,4,16,q_proj k_proj v_proj o_proj,18650082304.0,1.0266701062520345,1089.3291,0.22,0.028,451.214
+epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second
+0.15,,none,2e-5,,,76679.0,72971724288,44004763136,mistralai/Mistral-7B-v0.1,1,,4,,,float16,0.9112484455108643,565.9213,0.707,0.177,2895.102
+0.15,,none,2e-5,,,43702.0,36762859520,29521119232,mistralai/Mistral-7B-v0.1,2,,2,,,float16,0.8622726058959961,307.6782,1.3,0.325,2662.522
+0.29,,none,2e-5,,,70669.0,72972117504,44005156352,mistralai/Mistral-7B-v0.1,1,,8,,,float16,1.017976951599121,1094.9632,0.731,0.091,2992.612
+0.29,,none,2e-5,,,52882.0,36763056128,29521315840,mistralai/Mistral-7B-v0.1,2,,4,,,float16,0.8944576263427735,576.1931,1.388,0.174,2843.491
+,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,float16,,,,,
+,,none,2e-5,,,79169.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,float16,,,,,
+,,none,2e-5,,,80969.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,float16,,,,,
+,,none,2e-5,,,80083.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,float16,,,,,
+,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,4,,,float16,,,,,
+,,none,2e-5,,,80923.0,0,0,NousResearch/Llama-2-70b-hf,2,,2,,,float16,,,,,
+,,none,2e-5,,,80987.0,0,0,NousResearch/Llama-2-70b-hf,1,,8,,,float16,,,,,
+,,none,2e-5,,,81006.0,0,0,NousResearch/Llama-2-70b-hf,2,,4,,,float16,,,,,
+0.15,,none,2e-4,16,0.0,28703.0,26108963328,15119590912,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8818108749389648,458.2667,0.873,0.218,3575.21
+0.15,,none,2e-4,16,0.0,17669.0,15123161088,7850391552,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8540384006500245,270.1999,1.48,0.37,3031.829
+0.29,,none,2e-4,16,0.0,42167.0,37098695168,15119984128,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,1.0028394603729247,912.5081,0.877,0.11,3590.982
+0.29,,none,2e-4,16,0.0,25207.0,21433753600,7850588160,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8833828353881836,482.6901,1.657,0.207,3394.311
+,,none,2e-4,16,0.0,80990.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.15,,none,2e-4,16,0.0,61532.0,57546370048,47311452160,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8696129798889161,561.2483,0.713,0.178,1459.604
+,,none,2e-4,16,0.0,80207.0,0,0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.29,,none,2e-4,16,0.0,69171.0,64398757376,47311648768,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.885084867477417,938.9714,0.852,0.106,1744.888
+,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,,none,2e-4,16,0.0,80907.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,,none,2e-4,16,0.0,80617.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+,,none,2e-4,16,0.0,80783.0,0,0,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,25995.0,22825932800,5368221184,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8698946189880371,586.9178,0.682,0.17,2791.532
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,12476.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8552890300750733,284.376,1.407,0.352,2880.693
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,46117.0,40278956032,5368614400,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8654958820343017,1148.1408,0.697,0.087,2854.005
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,20405.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8869294357299805,503.0597,1.59,0.199,3256.87
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,47189.0,46475660288,25726225920,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8893787956237793,1185.2488,0.337,0.084,1382.326
+0.15,True,baseline-peft-bnb,2e-4,16,0.0,24751.0,21932720128,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8617707204818725,568.5808,0.704,0.176,1440.78
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,68683.0,67165218816,25726619136,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8893123245239258,2124.0668,0.377,0.047,1542.701
+0.29,True,baseline-peft-bnb,2e-4,16,0.0,32064.0,29353074176,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8585504531860352,962.8971,0.831,0.104,1701.532
+,True,baseline-peft-bnb,2e-4,16,0.0,80121.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.14,True,baseline-peft-bnb,2e-4,16,0.0,51701.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9204118633270264,1981.2518,0.202,0.05,413.476
+,True,baseline-peft-bnb,2e-4,16,0.0,79555.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,baseline-peft-bnb,2e-4,16,0.0,80394.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9444941711425782,3760.1788,0.213,0.027,435.724
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,18903.0,15860019712,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8704616069793701,479.6819,0.834,0.208,3415.597
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,12533.0,9974622720,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8528211212158203,282.8845,1.414,0.354,2895.882
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,33327.0,26849751552,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8675907611846924,945.5376,0.846,0.106,3465.542
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,20423.0,16587205120,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.854712610244751,502.3584,1.592,0.199,3261.417
+0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19257.0,13636909056,4843384320,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8722561931610108,420.8819,0.95,0.238,3892.778
+0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,12118.0,9796856320,2727018496,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8581914234161377,232.51,1.72,0.43,3523.289
+0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,32209.0,22430791680,4843777536,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8683128643035889,821.991,0.973,0.122,3986.418
+0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,19463.0,16207063552,2727215104,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.852388572692871,427.1268,1.873,0.234,3835.864
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,37417.0,36218024448,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8887558174133301,913.0381,0.438,0.11,1794.449
+0.15,True,accelerated-peft-bnb,2e-4,16,0.0,24952.0,21921468928,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8612120914459228,572.3054,0.699,0.175,1431.404
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,49893.0,47207756288,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8909227275848388,1711.7453,0.467,0.058,1914.303
+0.29,True,accelerated-peft-bnb,2e-4,16,0.0,32207.0,29359173632,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8591176319122314,959.9538,0.833,0.104,1706.749
+0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,37547.0,35651058176,25201389056,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8895366668701172,854.9879,0.468,0.117,1916.284
+0.15,True,accelerated-peft-bnb-foak,2e-4,16,0.0,24572.0,21746056192,13219233792,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8630767631530761,514.5553,0.777,0.194,1592.054
+0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,49861.0,46058696192,25201782272,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8951810073852539,1601.6113,0.499,0.062,2045.94
+0.29,True,accelerated-peft-bnb-foak,2e-4,16,0.0,31701.0,29043888640,13219430400,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8600863265991211,880.114,0.909,0.114,1861.577
+0.14,True,accelerated-peft-bnb,2e-4,16,0.0,71801.0,68159977472,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9996430969238281,3700.3604,0.108,0.027,442.768
+0.14,True,accelerated-peft-bnb,2e-4,16,0.0,51579.0,46524471808,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9264963436126709,1955.4907,0.205,0.051,418.923
+,True,accelerated-peft-bnb,2e-4,16,0.0,79375.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-bnb,2e-4,16,0.0,80815.0,72398346752,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9262647342681884,3714.7153,0.215,0.027,441.057
+0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,71995.0,67350935552,37346815488,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9998687934875489,3351.04,0.119,0.03,488.923
+0.14,True,accelerated-peft-bnb-foak,2e-4,16,0.0,51141.0,46250760704,19172741120,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9389877033233642,1747.6289,0.229,0.057,468.749
+,True,accelerated-peft-bnb-foak,2e-4,16,0.0,80303.0,0,0,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-bnb-foak,2e-4,16,0.0,79861.0,71720933888,19172937728,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9403298473358155,3375.4111,0.237,0.03,485.393
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,19425.0,15890329088,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.009563512802124,491.6352,0.814,0.203,3332.552
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,12230.0,9690031616,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9266629409790039,294.4237,1.359,0.34,2782.385
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,33219.0,26880060928,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9904310989379883,953.3973,0.839,0.105,3436.972
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,19477.0,16000624128,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8998308277130127,506.1818,1.58,0.198,3236.781
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19065.0,13631990784,4873693696,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.003525791168213,414.297,0.965,0.241,3954.651
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,11879.0,9512265216,2743565312,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9293491744995117,224.6767,1.78,0.445,3646.128
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,32721.0,22390647808,4874086912,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.992929859161377,810.9726,0.986,0.123,4040.581
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,19063.0,15620482560,2743761920,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9045120429992676,418.8226,1.91,0.239,3911.919
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,36389.0,35528093184,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.89991379737854,897.8879,0.445,0.111,1824.727
+0.15,True,accelerated-peft-autogptq,2e-4,16,0.0,22882.0,20691720192,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8638970375061035,557.2929,0.718,0.179,1469.963
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,48959.0,46517825024,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.893577823638916,1673.2594,0.478,0.06,1958.334
+0.29,True,accelerated-peft-autogptq,2e-4,16,0.0,29704.0,27482931712,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.864154224395752,938.3626,0.853,0.107,1746.02
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,36607.0,33649802752,24511457792,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8993340969085694,811.6061,0.493,0.123,2018.713
+0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,22801.0,20438869504,12526730240,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.8660580062866211,478.0288,0.837,0.209,1713.704
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49669.0,42707730944,24511851008,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.8937735366821289,1533.2657,0.522,0.065,2137.138
+0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,29370.0,26951336960,12526926848,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.8651807403564453,838.8338,0.954,0.119,1953.188
+0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,71177.0,65895347200,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9811842250823974,3639.6437,0.11,0.027,450.154
+0.14,True,accelerated-peft-autogptq,2e-4,16,0.0,49475.0,44873390592,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9557892894744873,1923.445,0.208,0.052,425.902
+,True,accelerated-peft-autogptq,2e-4,16,0.0,79265.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-autogptq,2e-4,16,0.0,79187.0,70143285760,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9580207633972168,3685.3642,0.217,0.027,444.569
+0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,71223.0,65086305280,36290144768,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.981500825881958,3273.1958,0.122,0.031,500.551
+0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,49187.0,44599679488,18125597696,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9558010864257812,1682.0158,0.238,0.059,487.035
+,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,80945.0,0,0,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,,
+0.28,True,accelerated-peft-autogptq-foak,2e-4,16,0.0,78208.0,69465872896,18125794304,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9556115436553955,3298.135,0.243,0.03,496.766
diff --git a/scripts/benchmarks/refs/l40_40gb.csv b/scripts/benchmarks/refs/l40_40gb.csv
deleted file mode 100644
index 2158c782..00000000
--- a/scripts/benchmarks/refs/l40_40gb.csv
+++ /dev/null
@@ -1,49 +0,0 @@
-acceleration_framework_config_file,epoch,error_messages,fp16,framework_config,index,learning_rate,lora_alpha,lora_dropout,model_name_or_path,num_gpus,output_dir,peft_method,per_device_train_batch_size,r,target_modules,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second,training_data_path
-,,,,none,0,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,0.03,,,none,1,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,2,,,0.9020393848419189,102.4493,0.781,0.195,1599.23,benchmark_outputs/data/cache.json
-,,,,none,2,2e-5,,,mistralai/Mistral-7B-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json
-,0.06,,,none,3,2e-5,,,mistralai/Mistral-7B-v0.1,2,,,4,,,0.936076545715332,170.7722,0.937,0.117,1918.814,benchmark_outputs/data/cache.json
-,,,,none,4,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,5,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,2,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,6,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,,8,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,7,2e-5,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,8,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,9,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,2,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,10,2e-5,,,NousResearch/Llama-2-70b-hf,1,,,8,,,,,,,,benchmark_outputs/data/cache.json
-,,,,none,11,2e-5,,,NousResearch/Llama-2-70b-hf,2,,,4,,,,,,,,benchmark_outputs/data/cache.json
-,0.03,,,none,12,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9326287746429444,120.2794,0.665,0.166,2724.324,benchmark_outputs/data/cache.json
-,0.03,,,none,13,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9157441139221192,78.5825,1.018,0.255,2084.943,benchmark_outputs/data/cache.json
-,0.06,,,none,14,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.0113807678222657,241.3246,0.663,0.083,2715.679,benchmark_outputs/data/cache.json
-,0.06,,,none,15,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9433841228485107,133.2158,1.201,0.15,2459.768,benchmark_outputs/data/cache.json
-,,,,none,16,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,17,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,18,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,19,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,20,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,21,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,22,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-,,,,none,23,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,36,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.6183419704437256,137.2634,0.583,0.146,2387.235,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,37,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.7251328945159912,73.906,1.082,0.271,2216.871,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,38,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,1.5904263019561768,272.1958,0.588,0.073,2407.679,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,39,2e-4,16,0.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.515465259552002,138.6152,1.154,0.144,2363.954,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,40,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.012540912628174,227.0536,0.352,0.088,1443.183,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.03,,True,accelerated-peft-autogptq,41,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,1.0235525131225587,121.7118,0.657,0.164,1346.13,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,42,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,0.06,,True,accelerated-peft-autogptq,43,2e-4,16,0.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,1.0152217864990234,229.6679,0.697,0.087,1426.756,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,44,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,45,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,46,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-autogptq-sample-configuration.yaml,,,True,accelerated-peft-autogptq,47,2e-4,16,0.0,TheBloke/Nous-Hermes-Llama2-70B-GPTQ,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,0,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,0.9979345798492432,130.1845,0.615,0.154,2517.044,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,1,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.942676591873169,69.8209,1.146,0.286,2346.575,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,2,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,0.9919514656066895,259.8776,0.616,0.077,2521.802,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,3,2e-4,16,0.0,mistralai/Mistral-7B-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.933735466003418,133.6157,1.197,0.15,2452.406,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,4,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,4,16,q_proj k_proj v_proj o_proj,1.015654945373535,218.3215,0.366,0.092,1500.906,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.03,,True,accelerated-peft-bnb,5,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,2,16,q_proj k_proj v_proj o_proj,0.9546889305114746,173.2373,0.462,0.115,945.755,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,6,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,0.06,,True,accelerated-peft-bnb,7,2e-4,16,0.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,lora,4,16,q_proj k_proj v_proj o_proj,0.9585415840148925,273.4507,0.585,0.073,1198.315,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,8,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,9,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,2,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,10,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,1,,lora,8,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
-sample-configurations/accelerated-peft-bnb-nf4-sample-configuration.yaml,,,True,accelerated-peft-bnb,11,2e-4,16,0.0,NousResearch/Llama-2-70b-hf,2,,lora,4,16,q_proj k_proj v_proj o_proj,,,,,,benchmark_bnb_outputs/data/cache.json
diff --git a/scripts/benchmarks/scenarios.yaml b/scripts/benchmarks/scenarios.yaml
index 248eacb2..42f7c753 100644
--- a/scripts/benchmarks/scenarios.yaml
+++ b/scripts/benchmarks/scenarios.yaml
@@ -52,6 +52,7 @@ scenarios:
- name: accelerated-peft-bnb
framework_config:
- accelerated-peft-bnb
+ - accelerated-peft-bnb-foak
arguments:
fp16: True
learning_rate: 2e-4
@@ -69,6 +70,7 @@ scenarios:
- name: accelerated-peft-gptq
framework_config:
- accelerated-peft-autogptq
+ - accelerated-peft-autogptq-foak
arguments:
learning_rate: 2e-4
fp16: True
@@ -81,4 +83,4 @@ scenarios:
model_name_or_path:
- 'TheBloke/Mistral-7B-v0.1-GPTQ'
- 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ'
- - 'TheBloke/Llama-2-70B-GPTQ'
\ No newline at end of file
+ - 'TheBloke/Llama-2-70B-GPTQ'
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index 67ad4058..b3485e3c 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -139,9 +139,11 @@ def read_configuration(path: str) -> Dict:
#
# NOTE: an augmentation (path, value) will augment a config at the
# specified key path, with the value.
-KEY_AUTO_GPTQ = "auto_gptq"
+KEY_AUTO_GPTQ = "auto-gptq"
KEY_BNB_NF4 = "bnb-nf4"
KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4"
+KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak"
+KEY_BNB_NF4_FOAK = "bnb-nf4-foak"
CONFIGURATIONS = {
KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml",
@@ -152,10 +154,18 @@ def read_configuration(path: str) -> Dict:
KEY_BNB_NF4_BASELINE: (
"plugins/accelerated-peft/configs/bnb.yaml",
[
- ("peft.quantization.bitsandbytes.quant_type", "nf4"),
- ("peft.quantization.bitsandbytes.no_peft_model", True),
+ ("peft.quantization.bitsandbytes.quant_type", "nf4"),
+ ("peft.quantization.bitsandbytes.no_peft_model", True),
],
),
+ KEY_AUTO_GPTQ_FOAK: (
+ "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
+ [("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")],
+ ),
+ KEY_BNB_NF4_FOAK: (
+ "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml",
+ [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")],
+ ),
}
# list of (tag, combi) tuples
@@ -167,6 +177,8 @@ def read_configuration(path: str) -> Dict:
("accelerated-peft-autogptq", (KEY_AUTO_GPTQ,)),
("accelerated-peft-bnb-nf4", (KEY_BNB_NF4,)),
("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)),
+ ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)),
+ ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
]
@@ -176,10 +188,15 @@ def merge_configs(config_contents: List[Dict]):
# merge in place
def _merge(result: Dict, new_contents: Dict):
- for k in new_contents:
+ for k, v in new_contents.items():
if k not in result:
- result[k] = {}
- _merge(result[k], new_contents)
+ # if k is not in result, then v does not yet exist
+ # as a subtree under result, so we simply assign it
+ result[k] = v
+ else:
+ # otherwise, recursively merge the two subtrees
+ _merge(result[k], v)
if len(config_contents) == 0:
return {}
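The corrected `_merge` above assigns a value when the key is absent and recurses when both sides share a subtree. A standalone copy of that logic, exercised on toy nested configs (the YAML contents are invented; only the `base_layer: auto_gptq` augmentation comes from the table above):

```python
def merge(result: dict, new_contents: dict) -> dict:
    # mirrors the fixed _merge: assign new subtrees, recurse into shared ones
    for k, v in new_contents.items():
        if k not in result:
            result[k] = v
        else:
            merge(result[k], v)
    return result

base = {"peft": {"quantization": {"auto_gptq": {"kernel": "toy_kernel"}}}}
foak = {"peft": {"quantization": {"fused_ops_and_kernels": {"base_layer": "auto_gptq"}}}}
merged = merge(base, foak)
# both plugin configs now sit side by side under peft.quantization
assert set(merged["peft"]["quantization"]) == {"auto_gptq", "fused_ops_and_kernels"}
```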
diff --git a/scripts/run_benchmarks.sh b/scripts/run_benchmarks.sh
index e08125b3..8f8a1f9b 100644
--- a/scripts/run_benchmarks.sh
+++ b/scripts/run_benchmarks.sh
@@ -38,7 +38,7 @@ PIP_REQUIREMENTS_FILE=requirements.txt
DRY_RUN=${DRY_RUN:-"false"}
NO_DATA_PROCESSING=${NO_DATA_PROCESSING:-"false"}
NO_OVERWRITE=${NO_OVERWRITE:-"false"}
-MEMORY_LOGGING=${MEMORY_LOGGING:-"huggingface"}
+MEMORY_LOGGING=${MEMORY_LOGGING:-"all"}
# inputs
NUM_GPUS_MATRIX=${1-"1 2"}
@@ -58,10 +58,10 @@ if [ -n "$RESULT_DIR" ]; then
echo "Results dir $RESULT_DIR is not empty, but NO_OVERWRITE=true"
echo "If intending to overwrite please delete the folder manually"
echo "or do not set NO_OVERWRITE"
- exit 1
+ else
+ echo "Deleting $RESULT_DIR"
+ rm -rf $RESULT_DIR
fi
- echo "Deleting $RESULT_DIR"
- rm -rf $RESULT_DIR
fi
# tag on the directories
@@ -98,7 +98,11 @@ elif [ "$MEMORY_LOGGING" = "all" ]; then
fi
# dump out the environment
-pip freeze > $PIP_REQUIREMENTS_FILE
+if [ ! "$NO_OVERWRITE" = "true" ]; then
+ echo "Creating $RESULT_DIR"
+ mkdir -p $RESULT_DIR
+ pip freeze > $PIP_REQUIREMENTS_FILE
+fi
# run the bench
python $WORKING_DIR/benchmark.py \
@@ -114,8 +118,10 @@ python $WORKING_DIR/benchmark.py \
# this will write to the BENCH_RESULT_FILE
# Remove the columns with values already represented by other metrics in the summary report
PYTHONPATH=. \
- python $WORKING_DIR/display_bench_results.py benchmark_outputs \
+ python $WORKING_DIR/display_bench_results.py $RESULT_DIR \
--result_file $BENCH_RESULT_FILE \
+ --keep_columns \
+ 'torch_dtype' \
--remove_columns \
'before_init_mem_cpu' \
'before_init_mem_gpu' \
@@ -127,5 +133,7 @@ PYTHONPATH=. \
'train_mem_cpu_peaked_delta' \
'train_mem_gpu_alloc_delta' \
'train_mem_gpu_peaked_delta' \
+ 'training_data_path' \
+ 'error_messages' \
'acceleration_framework_config_file'
diff --git a/tox.ini b/tox.ini
index d719cb3e..e8d8aa92 100644
--- a/tox.ini
+++ b/tox.ini
@@ -36,6 +36,7 @@ commands =
# install the plugins for test
# NOTE: when there are more plugins install here
python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft
+ python -m fms_acceleration.cli install -e {toxinidir}/plugins/fused-ops-and-kernels
# run the benchmark script
bash scripts/run_benchmarks.sh {posargs:"1 2" benchmark_outputs}