# Skip Tests for GPUs Not Supporting bf16 (#159)
## Summary

Closes #87

Skips `bfloat16` tests on GPUs whose compute capability is below the Ampere architecture (`sm_80`), since such GPUs lack native bf16 support.
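
The skip logic relies on a small `supports_bfloat16()` helper in `test/utils.py`, which the diff below imports in each test module. A minimal sketch of such a helper — an assumption about its implementation, checking only the CUDA compute capability of the current device — looks like this:

```python
import torch


def supports_bfloat16() -> bool:
    """Return True if the current CUDA device natively supports bfloat16."""
    if not torch.cuda.is_available():
        return False
    # Ampere (sm_80) and newer GPUs support bf16 natively;
    # e.g. T4 (sm_75) does not, L4 (sm_89) does.
    return torch.cuda.get_device_capability() >= (8, 0)
```

Each `bfloat16` case in the parametrized tests is then wrapped in `pytest.param(..., marks=pytest.mark.skipif(not supports_bfloat16(), ...))`, as shown in the diff below.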

## Testing Done

- Hardware Type: NVIDIA **T4** (most bf16 cases should be skipped)
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

```
⚡ main ~/Liger-Kernel make all
python -m pytest --disable-warnings test/ --ignore=test/convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
=================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... =================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... Skipped 1 files
All done! ✨ 🍰 ✨
58 files left unchanged.
collected 163 items                                                                                                                                        

test/transformers/test_auto_model.py .                                                                                                               [  0%]
test/transformers/test_cross_entropy.py ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                   [ 36%]
collected 28 items                                                                                                                                         

test/convergence/test_mini_models.py .....s.....s....                                                                                    [ 43%]
test/transformers/test_geglu.py .s....ssss                                                                                                             [ 48%]
test/transformers/test_monkey_patch.py .....                                                                                                         [ 51%]
test/transformers/test_rms_norm.py ........ssssssss...............ssssssss........                                                                  [ 80%]
test/transformers/test_rope.py ......ssssss                                                                                                          [ 88%]
test/transformers/test_swiglu.py ....ssss.s....ssss                                                                                                    [ 98%]
test/transformers/test_trainer_integration.py .                                                                                                      [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                           [100%]

======================================================== 71 passed, 92 skipped in 136.69s (0:02:16) ========================================================
.s.s.s                                                                                                  [ 50%]
test/convergence/test_mini_models_no_logits.py .s.s.s.s.s.s.s                                                                                        [100%]

======================================================== 14 passed, 14 skipped in 353.27s (0:05:53) ========================================================
```

- Hardware Type: NVIDIA **L4** (only a few cases should be skipped)
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

```
⚡ main ~/Liger-Kernel make all
python -m pytest --disable-warnings test/ --ignore=test/convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
=================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... =================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... Skipped 1 files
All done! ✨ 🍰 ✨
58 files left unchanged.
collected 163 items                                                                                                                                        

test/transformers/test_auto_model.py .                                                                                                               [  0%]
collected 28 items                                                                                                                                         

test/convergence/test_mini_models.py ........................................................ss                                                   [ 36%]
test/transformers/test_fused_linear_cross_entropy.py ...............                                                                                    [ 43%]
test/transformers/test_geglu.py .........                                                                                                             [ 48%]
test/transformers/test_monkey_patch.py .....                                                                                                         [ 51%]
test/transformers/test_rms_norm.py .................................................                                                                  [ 80%]
test/transformers/test_rope.py ............                                                                                                          [ 88%]
test/transformers/test_swiglu.py ..................                                                                                                    [ 98%]
test/transformers/test_trainer_integration.py .                                                                                                      [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                           [100%]

======================================================== 161 passed, 2 skipped in 90.45s (0:01:30) =========================================================
.......                                                                                                  [ 50%]
test/convergence/test_mini_models_no_logits.py ..............                                                                                        [100%]

============================================================== 28 passed in 290.65s (0:04:50) ==============================================================
```

## Additional Context
For reference, here is a list of NVIDIA architecture names and the compute capabilities they correspond to:

<img width="1268" alt="Screenshot 2024-08-29 at 6 04 56 PM"
src="https://github.com/user-attachments/assets/6675ae9e-9137-4adb-8af7-ee1226733353">
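
To check what a given machine reports — and therefore whether the bf16 cases will run — a quick query via PyTorch:

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    # T4 (Turing) reports (7, 5) -> bf16 cases skipped; L4 (Ada) reports (8, 9) -> bf16 cases run.
    status = "run" if (major, minor) >= (8, 0) else "skipped"
    print(f"sm_{major}{minor}: bf16 tests will be {status}")
```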

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
austin362667 and lancerts authored Aug 29, 2024
1 parent e5d6ad7 commit cbc4f85
Showing 8 changed files with 342 additions and 27 deletions.
113 changes: 106 additions & 7 deletions test/convergence/test_mini_models.py
@@ -6,6 +6,7 @@
assert_verbose_allclose,
set_seed,
simple_collate_fn,
supports_bfloat16,
)

import pytest
@@ -344,23 +345,121 @@ def run_mini_model(
[
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 6e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1.1",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma2",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_llama3",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
# TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine
# TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass.
# ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5),
# ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_mistral",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_qwen2",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_phi3",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
],
)
def test_mini_model(
113 changes: 106 additions & 7 deletions test/convergence/test_mini_models_no_logits.py
@@ -4,6 +4,7 @@
assert_verbose_allclose,
set_seed,
simple_collate_fn,
supports_bfloat16,
)

import pytest
@@ -291,20 +292,118 @@ def run_mini_model(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
[
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
("mini_llama3", 32, 1e-4, torch.bfloat16, 5e-3, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_llama3",
32,
1e-4,
torch.bfloat16,
5e-3,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_qwen2",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_phi3",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-2, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_mistral",
32,
1e-4,
torch.bfloat16,
1e-8,
1e-5,
1e-2,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1",
32,
1e-4,
torch.bfloat16,
1e-2,
1e-4,
2e-1,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma1.1",
32,
1e-4,
torch.bfloat16,
1e-2,
1e-4,
2e-1,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
pytest.param(
"mini_gemma2",
32,
1e-4,
torch.bfloat16,
1e-2,
1e-4,
2e-1,
1e-5,
1e-2,
1e-5,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
],
)
def test_mini_model(
84 changes: 77 additions & 7 deletions test/transformers/test_cross_entropy.py
@@ -1,3 +1,5 @@
from test.utils import supports_bfloat16

import pytest
import torch
from torch.nn import CrossEntropyLoss
@@ -99,14 +101,42 @@ def _test_correctness_not_last_layer_once(
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(0.1, torch.bfloat16, 1e-8, 5e-2),
(1.0, torch.bfloat16, 1e-8, 5e-2),
(10.0, torch.bfloat16, 1e-7, 5e-2),
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-7,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness(B, T, V, scalar, dtype, atol, rtol):
liger_ce = LigerCrossEntropyLoss()
_test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol)
@@ -125,14 +155,42 @@ def test_correctness(B, T, V, scalar, dtype, atol, rtol):
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(0.1, torch.bfloat16, 1e-8, 5e-2),
(1.0, torch.bfloat16, 1e-8, 5e-2),
(10.0, torch.bfloat16, 1e-8, 5e-2),
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_ignore_index(
B, T, V, ignore_index, scalar, dtype, atol, rtol
):
Expand All @@ -155,10 +213,22 @@ def test_correctness_with_ignore_index(
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(1.0, torch.bfloat16, 1e-8, 5e-2),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(1.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_not_last_layer(B, T, V, scalar, dtype, atol, rtol):
liger_ce = LigerCrossEntropyLoss()
_test_correctness_not_last_layer_once(liger_ce, B, T, V, scalar, dtype, atol, rtol)
*(Diffs for the remaining changed files are not shown here.)*