diff --git a/README.md b/README.md
index 916c91479..27a467e81 100644
--- a/README.md
+++ b/README.md
@@ -183,6 +183,12 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with Torch Compile.
 
 [CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)
 
+## Acknowledgement
+
+- [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration on Triton kernels for training
+- [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy for convergence testing
+
+
 ## License
 
 [BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)
diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py
index f2b5f7f58..b36f655ad 100644
--- a/test/convergence/test_mini_models_no_logits.py
+++ b/test/convergence/test_mini_models_no_logits.py
@@ -26,6 +26,7 @@ class MiniModelConfig:
 
 MINI_MODEL_SETUPS = {
     "mini_llama3": MiniModelConfig(
+        # TODO (easy): replace with an OSS public path
         tokenizer_path="/shared/public/models/Meta-Llama-3-8B/",
         liger_kernel_patch_func=apply_liger_kernel_to_llama,
         model_class=LlamaForCausalLM,
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index 98f15abfd..71b81d7f0 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -155,16 +155,8 @@ def test_correctness_with_ignore_index(
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        # (0.01, torch.bfloat16, 1e-8, 5e-2),
-        # (0.1, torch.bfloat16, 1e-8, 5e-2),
         (1.0, torch.bfloat16, 1e-8, 5e-2),
-        # (10.0, torch.bfloat16, 1e-8, 5e-2),
-        # (100.0, torch.bfloat16, 1e-8, 5e-2),
-        # (0.01, torch.float32, 1e-8, 1e-6),
-        # (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        # (10.0, torch.float32, 1e-8, 1e-6),
-        # (100.0, torch.float32, 1e-8, 1e-6),
     ],
 )
 def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol):
diff --git a/test/transformers/test_fused_linear_cross_entropy.py b/test/transformers/test_fused_linear_cross_entropy.py
index 13af97a9e..c21070340 100644
--- a/test/transformers/test_fused_linear_cross_entropy.py
+++ b/test/transformers/test_fused_linear_cross_entropy.py
@@ -58,6 +58,7 @@ def forward(self, x, y):
     [
         (2, 4, 512, 512),
         (8, 2048, 4096, 32000),  # llama2, mistral
+        # Commented out to speed up testing
         # (4, 2048, 4096, 128256),  # llama3 8B
         # (4, 1024, 8192, 128256),  # llama3 70B
         (4, 423, 8192, 32000),  # random shape
diff --git a/test/transformers/test_geglu.py b/test/transformers/test_geglu.py
index 5329f3ca7..b06fa04bf 100644
--- a/test/transformers/test_geglu.py
+++ b/test/transformers/test_geglu.py
@@ -12,8 +12,6 @@
 )
 
 SLEEP_SECONDS = 0.1
-# TODO (yun dai): triton 3.0.0 breaks geglu due to tanh module issue
-
 
 @pytest.mark.parametrize(
     "bsz, seq_len, hidden_size, intermediate_size",
diff --git a/test/transformers/test_rms_norm.py b/test/transformers/test_rms_norm.py
index 0ede8c440..1d0a44905 100644
--- a/test/transformers/test_rms_norm.py
+++ b/test/transformers/test_rms_norm.py
@@ -72,5 +72,5 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol):
         )
         is True
     )
-    # import pdb; pdb.set_trace()
+
     assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol) is True
diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py
index c9e0a6afe..14132c2a9 100644
--- a/test/transformers/test_swiglu.py
+++ b/test/transformers/test_swiglu.py
@@ -29,7 +29,7 @@
         # atol is for small values: they have more difference, so set atol higher
         # rtol is for larger values: they are very close, so set rtol lower
         (torch.float32, 1e-0, 1e-5),
-        # TODO: we should find a better way to tune this lol. 1e4 is too large apparently
+        # TODO: we should find a better way to tune this. 1e4 is too large apparently
         (torch.bfloat16, 1e4, 1e-2),
     ],
 )