@@ -64,6 +64,7 @@
     require_galore_torch,
     require_grokadamw,
     require_intel_extension_for_pytorch,
+    require_liger_kernel,
     require_lomo,
     require_optuna,
     require_peft,
@@ -1325,6 +1326,42 @@ def test_get_eval_dataloader_with_persistent_workers(self):
         self.assertEqual(first_dataloader, first_dataloader_repeated)
         self.assertEqual(second_dataloader, second_dataloader_repeated)
 
+    @require_liger_kernel
+    def test_use_liger_kernel_patching(self):
+        # Test that the model code actually gets patched with the Liger kernel
+        from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+        from transformers.models.llama import modeling_llama
+
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        args = TrainingArguments(
+            "./test",
+            use_liger_kernel=True,
+        )
+        Trainer(tiny_llama, args)
+
+        # Check that one of the Llama model layers has been correctly patched with the Liger kernel
+        self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
+
+    @require_liger_kernel
+    @require_torch_gpu
+    def test_use_liger_kernel_trainer(self):
+        # Check that the Trainer still works with the Liger kernel applied
+        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+        tiny_llama = LlamaForCausalLM(config)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            args = TrainingArguments(tmpdir, learning_rate=1e-2, logging_steps=5, max_steps=20, use_liger_kernel=True)
+            trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
+
+            # Check that training runs end-to-end without error
+            _ = trainer.train()
+
     @require_lomo
     @require_torch_gpu
     def test_lomo(self):
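For context, a minimal sketch of the monkey-patching these tests exercise, using liger-kernel's documented apply_liger_kernel_to_llama helper; the Trainer path added in this diff (use_liger_kernel=True) is assumed to trigger the same module-level patch:

    # Sketch of the patching behavior that test_use_liger_kernel_patching asserts.
    # Assumes liger-kernel's public API apply_liger_kernel_to_llama (documented in
    # its README); Trainer(..., use_liger_kernel=True) is expected to be equivalent.
    from liger_kernel.transformers import apply_liger_kernel_to_llama
    from liger_kernel.transformers.rms_norm import LigerRMSNorm

    from transformers.models.llama import modeling_llama

    # Swaps the Llama modules for Liger implementations in place (RMSNorm,
    # RoPE, SwiGLU, cross entropy by default).
    apply_liger_kernel_to_llama()

    # After patching, the module-level class is the Liger implementation.
    assert modeling_llama.LlamaRMSNorm is LigerRMSNorm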