# Polish test/ and others (#26)

## Summary

As the title says: minor cleanup of `test/` and a few related files (README acknowledgements, stale comments, and commented-out test parameters).

## Testing Done

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

```
jobuser [ ~/Liger-Kernel ]$ make checkstyle && make test && make test-convergence
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 1 files
All done! ✨ 🍰 ✨
45 files left unchanged.
pytest --disable-warnings test/ --ignore=test/convergence
================================================================================================== test session starts ===================================================================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 111 items                                                                                                                                                                                                      

test/transformers/test_cross_entropy.py ..........................................................                                                                                                                 [ 52%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                                                                                                        [ 57%]
test/transformers/test_geglu.py ........                                                                                                                                                                           [ 64%]
test/transformers/test_rms_norm.py ................                                                                                                                                                                [ 79%]
test/transformers/test_rope.py ............                                                                                                                                                                        [ 90%]
test/transformers/test_swiglu.py ........                                                                                                                                                                          [ 97%]
test/transformers/test_transformers_monkey_patch.py .                                                                                                                                                              [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                                         [100%]

============================================================================================= 111 passed in 60.64s (0:01:00) =============================================================================================
HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence
================================================================================================== test session starts ===================================================================================================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 8 items                                                                                                                                                                                                        

test/convergence/test_mini_models.py ......                                                                                                                                                                        [ 75%]
test/convergence/test_mini_models_no_logits.py ..                                                                                                                                                                  [100%]

============================================================================================== 8 passed in 95.88s (0:01:35) ==============================================================================================
```
ByronHsu authored Aug 15, 2024
1 parent f7f8384 commit 030cb71
Showing 7 changed files with 10 additions and 12 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -183,6 +183,12 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with Torch Compile.

[CONTRIBUTING GUIDE](https://github.com/linkedin/Liger-Kernel/blob/main/CONTRIBUTING.md)

## Acknowledgement

- [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth), which inspired the Triton kernels used for training
- Andrej Karpathy's [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt), used for convergence testing


## License

[BSD 2-CLAUSE](https://github.com/linkedin/Liger-Kernel/blob/main/LICENSE)
1 change: 1 addition & 0 deletions test/convergence/test_mini_models_no_logits.py
@@ -26,6 +26,7 @@ class MiniModelConfig:

MINI_MODEL_SETUPS = {
"mini_llama3": MiniModelConfig(
# TODO (easy): replace with oss public path
tokenizer_path="/shared/public/models/Meta-Llama-3-8B/",
liger_kernel_patch_func=apply_liger_kernel_to_llama,
model_class=LlamaForCausalLM,
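For readers unfamiliar with the file, a rough sketch of the structure this hunk sits in, reconstructed only from the visible lines. The field types and import paths are assumptions, and the real `MiniModelConfig` almost certainly has additional fields.

```python
from dataclasses import dataclass
from typing import Callable

from transformers import LlamaForCausalLM
# Import path is an assumption; the patch function name is taken from the diff.
from liger_kernel.transformers import apply_liger_kernel_to_llama


@dataclass
class MiniModelConfig:
    # Only the fields visible in the hunk above; the real class has more.
    tokenizer_path: str
    liger_kernel_patch_func: Callable
    model_class: type


MINI_MODEL_SETUPS = {
    "mini_llama3": MiniModelConfig(
        tokenizer_path="/shared/public/models/Meta-Llama-3-8B/",  # the path the TODO asks to replace
        liger_kernel_patch_func=apply_liger_kernel_to_llama,
        model_class=LlamaForCausalLM,
    ),
}
```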
8 changes: 0 additions & 8 deletions test/transformers/test_cross_entropy.py
@@ -155,16 +155,8 @@ def test_correctness_with_ignore_index(
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
# (0.01, torch.bfloat16, 1e-8, 5e-2),
# (0.1, torch.bfloat16, 1e-8, 5e-2),
(1.0, torch.bfloat16, 1e-8, 5e-2),
# (10.0, torch.bfloat16, 1e-8, 5e-2),
# (100.0, torch.bfloat16, 1e-8, 5e-2),
# (0.01, torch.float32, 1e-8, 1e-6),
# (0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
# (10.0, torch.float32, 1e-8, 1e-6),
# (100.0, torch.float32, 1e-8, 1e-6),
],
)
def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol):
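For context on how `(scalar, dtype, atol, rtol)` tuples like these are typically consumed: logits are scaled by `scalar`, cast to `dtype`, and the candidate loss is compared against a PyTorch reference within the given tolerances. The sketch below is illustrative only, not the repo's actual test; `custom_ce` and the shapes are hypothetical stand-ins.

```python
import torch
import torch.nn.functional as F

def check_ce_close(custom_ce, scalar, dtype, atol, rtol, n_rows=128, vocab=32000):
    # Hypothetical shapes; the real test parametrizes B, T, V separately.
    logits = (torch.randn(n_rows, vocab) * scalar).to(dtype)
    target = torch.randint(0, vocab, (n_rows,))

    expected = F.cross_entropy(logits.float(), target)  # fp32 reference
    actual = custom_ce(logits, target)                   # kernel under test (stand-in)

    assert torch.allclose(actual.float(), expected, atol=atol, rtol=rtol)

# Example call, comparing the reference against itself just to show the shape of the check.
check_ce_close(lambda x, y: F.cross_entropy(x.float(), y), 1.0, torch.bfloat16, 1e-8, 5e-2)
```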
1 change: 1 addition & 0 deletions test/transformers/test_fused_linear_cross_entropy.py
@@ -58,6 +58,7 @@ def forward(self, x, y):
[
(2, 4, 512, 512),
(8, 2048, 4096, 32000), # llama2, mistral
# Comment out to speed up testing
# (4, 2048, 4096, 128256), # llama3 8B
# (4, 1024, 8192, 128256), # llama3 70B
(4, 423, 8192, 32000), # random shape
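Reading the tuples: given the model names in the comments, the four numbers are almost certainly (batch size, sequence length, hidden size, vocab size). A small sketch of what such a shape implies for a fused linear + cross entropy test, using the smallest tuple to stay cheap (hypothetical code, not the repo's test):

```python
import torch
import torch.nn.functional as F

B, T, H, V = 2, 4, 512, 512  # smallest (B, T, H, V) tuple from the parametrization above

hidden = torch.randn(B * T, H)               # flattened hidden states
lm_head = torch.nn.Linear(H, V, bias=False)  # projection to vocab logits
target = torch.randint(0, V, (B * T,))

# The naive path materializes a (B*T, V) logits tensor before the loss;
# a fused linear-cross-entropy kernel exists to avoid exactly this allocation.
logits = lm_head(hidden)
loss = F.cross_entropy(logits, target)
```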
2 changes: 0 additions & 2 deletions test/transformers/test_geglu.py
@@ -12,8 +12,6 @@
)
SLEEP_SECONDS = 0.1

# TODO (yun dai): triton 3.0.0 breaks geglu due to tanh module issue


@pytest.mark.parametrize(
"bsz, seq_len, hidden_size, intermediate_size",
2 changes: 1 addition & 1 deletion test/transformers/test_rms_norm.py
@@ -72,5 +72,5 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol):
)
is True
)
# import pdb; pdb.set_trace()

assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol) is True
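The `assert torch.allclose(h1.grad, h2.grad, ...) is True` line is the tail of a forward/backward comparison between two implementations. Below is a minimal, self-contained sketch of that pattern; it runs the reference on both sides purely to show the structure, whereas the real test substitutes the Liger kernel for one side and parametrizes the shapes and tolerances.

```python
import torch

def rms_norm_ref(x, weight, eps=1e-6):
    # Plain PyTorch RMSNorm: x / rms(x) * weight
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

bs, sl, hd = 2, 128, 512              # hypothetical shapes
x = torch.randn(bs, sl, hd)
w = torch.ones(hd)

h1 = x.clone().requires_grad_(True)   # input to the reference implementation
h2 = x.clone().requires_grad_(True)   # input to the implementation under test

rms_norm_ref(h1, w).sum().backward()
rms_norm_ref(h2, w).sum().backward()  # stand-in for the Liger RMSNorm call

assert torch.allclose(h1.grad, h2.grad, atol=1e-5, rtol=1e-5) is True
```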
2 changes: 1 addition & 1 deletion test/transformers/test_swiglu.py
@@ -29,7 +29,7 @@
# atol is for small values: they have more difference, so set atol higher
# rtol is for larger values: they are very close, so set rtol lower
(torch.float32, 1e-0, 1e-5),
# TODO: we should find a better way to tune this lol. 1e4 is too large apparently
# TODO: we should find a better way to tune this. 1e4 is too large apparently
(torch.bfloat16, 1e4, 1e-2),
],
)
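As background for the tolerance comments above: `torch.allclose(a, b, rtol, atol)` passes when `|a - b| <= atol + rtol * |b|` elementwise, so `atol` governs values near zero and `rtol` governs large magnitudes. A tiny illustration with made-up numbers:

```python
import torch

# |a - b| <= atol + rtol * |b| is the elementwise condition torch.allclose checks.
small, small_noisy = torch.tensor([1e-3]), torch.tensor([1.5e-3])
large, large_noisy = torch.tensor([1e4]), torch.tensor([1.001e4])

# Near zero, the absolute term carries the comparison...
print(torch.allclose(small, small_noisy, atol=1e-3, rtol=1e-5))  # True: 5e-4 <= ~1e-3
# ...while for large values the relative term does.
print(torch.allclose(large, large_noisy, atol=1e-3, rtol=1e-2))  # True: 10 <= ~100
```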
