
Commit adb9117

Authored by: JasonZhu1313, shimizust, SunMarc, ByronHsu
Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer (#32860)
* add liger integration
* fix syntax
* fix import issue
* add trainer.md
* Use _apply_liger_kernel()
* Fixed log message
* Update docs/source/en/trainer.md (Co-authored-by: Marc Sun <[email protected]>)
* Update docs/source/en/trainer.md (Co-authored-by: Marc Sun <[email protected]>)
* Update src/transformers/training_args.py (Co-authored-by: Byron Hsu <[email protected]>)
* Update src/transformers/trainer.py (Co-authored-by: Marc Sun <[email protected]>)
* Update src/transformers/training_args.py (Co-authored-by: Byron Hsu <[email protected]>)
* Update docs/source/en/trainer.md (Co-authored-by: Byron Hsu <[email protected]>)
* Fixed checkstyle and updated readme
* Added test
* Fixed checkstyle
* fix docstring
* rename use_liger to use_liger_kernel
* Trigger Build
* Added test
* add fix-copies
* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <[email protected]>
Co-authored-by: Steven Shimizu <[email protected]>
Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Byron Hsu <[email protected]>
1 parent 970a16e commit adb9117

File tree

7 files changed: +118 -0 lines changed

docs/source/en/trainer.md

Lines changed: 35 additions & 0 deletions
@@ -382,6 +382,41 @@ trainer.train()

 Note layerwise optimization is a bit experimental and does not support DDP (Distributed Data Parallel), thus you can run the training script only on a single GPU. Please see [this appropriate section](https://github.com/jiaweizzhao/GaLore?tab=readme-ov-file#train-7b-model-with-a-single-gpu-with-24gb-memory) for more details. Other features such as gradient clipping, DeepSpeed, etc might not be supported out of the box. Please [raise an issue on GitHub](https://github.com/huggingface/transformers/issues) if you encounter such issue.

+## Liger Kernel
+
+[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels developed by LinkedIn specifically for LLM training. It provides Hugging Face-compatible RMSNorm, RoPE, SwiGLU, CrossEntropy, FusedLinearCrossEntropy, and more to come. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The kernels work out of the box with Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed.
+
+<Tip>
+Gain +20% throughput and a 60% reduction in memory usage when training LLaMA 3-8B, enabling longer context lengths and larger batch sizes. The kernels are also useful for scaling to larger vocabulary sizes and to multi-head training setups such as Medusa. See details and examples in the [Liger Kernel examples](https://github.com/linkedin/Liger-Kernel/tree/main/examples).
+</Tip>
+
+First make sure to install the official Liger Kernel package:
+```bash
+pip install liger-kernel
+```
+
+Then pass `use_liger_kernel=True` to apply the Liger kernels to your model, for example:
+
+```py
+from transformers import TrainingArguments
+
+training_args = TrainingArguments(
+    output_dir="your-model",
+    learning_rate=2e-5,
+    per_device_train_batch_size=16,
+    per_device_eval_batch_size=16,
+    num_train_epochs=2,
+    weight_decay=0.01,
+    eval_strategy="epoch",
+    save_strategy="epoch",
+    load_best_model_at_end=True,
+    push_to_hub=True,
+    use_liger_kernel=True,
+)
+```
+
+The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model are patched with Liger's efficient implementations, so you don't need to do anything other than set the argument value.
+
 ## LOMO optimizer

 The LOMO optimizers have been introduced in [Full Parameter Fine-Tuning for Large Language Models with Limited Resources](https://hf.co/papers/2306.09782) and [AdaLomo: Low-memory Optimization with Adaptive Learning Rate](https://hf.co/papers/2310.10195).
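Putting the new argument together with a `Trainer` run, here is a minimal end-to-end sketch modeled on the tiny-Llama setup used in the tests added by this commit; the `RandomTokenDataset`, output directory, and hyperparameters are illustrative rather than part of the commit, and Liger's Triton kernels need a CUDA GPU to actually execute:

```py
import torch
from torch.utils.data import Dataset

from transformers import LlamaConfig, LlamaForCausalLM, Trainer, TrainingArguments


class RandomTokenDataset(Dataset):
    """Tiny random-token dataset, for illustration only."""

    def __init__(self, vocab_size=100, seq_len=64, length=64):
        self.data = torch.randint(0, vocab_size, (length, seq_len))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = self.data[idx]
        return {"input_ids": ids, "labels": ids}


# Tiny Llama model, mirroring the configuration used in the new tests.
config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
model = LlamaForCausalLM(config)

args = TrainingArguments(
    output_dir="liger-demo",  # illustrative output directory
    max_steps=20,
    per_device_train_batch_size=4,
    use_liger_kernel=True,  # Trainer patches supported layers with Liger kernels
)

# Note: running the Triton kernels requires a CUDA GPU (the commit's
# end-to-end test is gated with @require_torch_gpu).
trainer = Trainer(model=model, args=args, train_dataset=RandomTokenDataset())
trainer.train()
```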

src/transformers/testing_utils.py

Lines changed: 8 additions & 0 deletions
@@ -84,6 +84,7 @@
     is_keras_nlp_available,
     is_levenshtein_available,
     is_librosa_available,
+    is_liger_kernel_available,
     is_lomo_available,
     is_natten_available,
     is_nltk_available,

@@ -1162,6 +1163,13 @@ def require_librosa(test_case):
     return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)


+def require_liger_kernel(test_case):
+    """
+    Decorator marking a test that requires liger_kernel
+    """
+    return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
+
+
 def require_essentia(test_case):
     """
     Decorator marking a test that requires essentia
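As with the other `require_*` helpers, the new decorator simply skips a test when a compatible `liger_kernel` is not installed. A minimal usage sketch (the test class below is illustrative, not part of this commit):

```py
import unittest

from transformers.testing_utils import require_liger_kernel


class LigerKernelAvailabilityTest(unittest.TestCase):
    @require_liger_kernel
    def test_liger_kernel_import(self):
        # Runs only when liger_kernel >= 0.1.0 is installed; otherwise the
        # decorator marks the test as skipped.
        import liger_kernel  # noqa: F401
```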

src/transformers/trainer.py

Lines changed: 19 additions & 0 deletions
@@ -155,6 +155,7 @@
     is_grokadamw_available,
     is_in_notebook,
     is_ipex_available,
+    is_liger_kernel_available,
     is_lomo_available,
     is_peft_available,
     is_safetensors_available,

@@ -464,6 +465,24 @@ def __init__(
                 " to `True` to avoid any unexpected behavior such as device placement mismatching."
             )

+        if self.args.use_liger_kernel:
+            if is_liger_kernel_available():
+                from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
+
+                model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
+                if model_type:
+                    # Monkey patch the model with liger kernels. Use the default kernel configurations.
+                    _apply_liger_kernel(model_type=model_type)
+                else:
+                    logger.warning(
+                        "The model does not have a valid `model_type` specified. No liger kernels will be applied."
+                    )
+            else:
+                raise ImportError(
+                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
+                    "Please install it with `pip install liger-kernel`"
+                )
+
         _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
             model, "_hf_peft_config_loaded", False
         )
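Outside of `Trainer`, the same patching can be triggered directly through the helper the diff imports. A minimal sketch, mirroring the module-level check in the new test (note that `_apply_liger_kernel` is a private helper of liger-kernel, not a public Transformers API):

```py
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel

from transformers.models.llama import modeling_llama

# Monkey patch the Llama modeling code with Liger's kernels, using the default
# kernel configuration; this is what Trainer does when `use_liger_kernel=True`
# and `model.config.model_type == "llama"`.
_apply_liger_kernel(model_type="llama")

assert modeling_llama.LlamaRMSNorm == LigerRMSNorm
```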

src/transformers/training_args.py

Lines changed: 10 additions & 0 deletions
@@ -793,6 +793,11 @@ class TrainingArguments:

         eval_use_gather_object (`bool`, *optional*, defaults to `False`):
             Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.
+
+        use_liger_kernel (`bool`, *optional*, defaults to `False`):
+            Whether to enable the [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
+            It can effectively increase multi-GPU training throughput by ~20% and reduce memory usage by ~60%, and it works out of the box with
+            Flash Attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
     """

     framework = "pt"

@@ -1493,6 +1498,11 @@ class TrainingArguments:
         },
     )

+    use_liger_kernel: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
+    )
+
     eval_use_gather_object: Optional[bool] = field(
         default=False,
         metadata={
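Because `use_liger_kernel` is an ordinary `TrainingArguments` field, it should also surface as a command-line flag in scripts that parse their arguments with `HfArgumentParser`; a minimal sketch (the script invocation in the comment is illustrative):

```py
from transformers import HfArgumentParser, TrainingArguments

# e.g. `python train.py --output_dir out --use_liger_kernel True`
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.use_liger_kernel)
```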

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@
     is_keras_nlp_available,
     is_levenshtein_available,
     is_librosa_available,
+    is_liger_kernel_available,
     is_lomo_available,
     is_mlx_available,
     is_natten_available,

src/transformers/utils/import_utils.py

Lines changed: 8 additions & 0 deletions
@@ -177,6 +177,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _torchvision_available = _is_package_available("torchvision")
 _mlx_available = _is_package_available("mlx")
 _hqq_available = _is_package_available("hqq")
+_liger_kernel_available = _is_package_available("liger_kernel")


 _torch_version = "N/A"

@@ -1164,6 +1165,13 @@ def is_mlx_available():
     return _mlx_available


+def is_liger_kernel_available():
+    if not _liger_kernel_available:
+        return False
+
+    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
+
+
 # docstyle-ignore
 AV_IMPORT_ERROR = """
 {0} requires the PyAv library but it was not found in your environment. You can install it with:
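The availability check is meant to be used as a version-gated import guard, in the same way the `Trainer` change above uses it; a minimal sketch:

```py
from transformers.utils import is_liger_kernel_available

# Only import liger_kernel when a compatible version (>= 0.1.0) is installed,
# mirroring the guard added to Trainer.__init__.
if is_liger_kernel_available():
    from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
else:
    _apply_liger_kernel = None  # degrade gracefully when liger-kernel is absent
```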

tests/trainer/test_trainer.py

Lines changed: 37 additions & 0 deletions
@@ -64,6 +64,7 @@
     require_galore_torch,
     require_grokadamw,
     require_intel_extension_for_pytorch,
+    require_liger_kernel,
     require_lomo,
     require_optuna,
     require_peft,

@@ -1325,6 +1326,42 @@ def test_get_eval_dataloader_with_persistent_workers(self):
         self.assertEqual(first_dataloader, first_dataloader_repeated)
         self.assertEqual(second_dataloader, second_dataloader_repeated)

+    @require_liger_kernel
+    def test_use_liger_kernel_patching(self):
+        # Test that the model code actually gets patched with Liger kernel
+        from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+        from transformers.models.llama import modeling_llama
+
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        args = TrainingArguments(
+            "./test",
+            use_liger_kernel=True,
+        )
+        Trainer(tiny_llama, args)
+
+        # Check that one of the Llama model layers has been correctly patched with Liger kernel
+        self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
+
+    @require_liger_kernel
+    @require_torch_gpu
+    def test_use_liger_kernel_trainer(self):
+        # Check that trainer still works with liger kernel applied
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, max_steps=20, use_liger_kernel=True)
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check this works
+            _ = trainer.train()
+
     @require_lomo
     @require_torch_gpu
     def test_lomo(self):
