
Skip Tests for GPUs Not Supporting bf16 #159

Merged
4 commits merged into linkedin:main on Aug 29, 2024

Conversation

@austin362667 (Contributor) commented on Aug 29, 2024

Summary

Closes #87

Skipped tests for bfloat16 on GPUs with compute capability below Ampere architecture (sm_80).
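The gate is a simple compute-capability comparison; here is a minimal sketch of the logic (the capability tuple is passed in for illustration only — the actual tests query `torch.cuda.get_device_capability()` on the live device):

```python
def supports_bfloat16(capability):
    """bfloat16 kernels require Ampere (sm_80) or newer.

    `capability` is a (major, minor) tuple in the format returned by
    torch.cuda.get_device_capability(), e.g. (7, 5) for a T4.
    """
    return capability >= (8, 0)


print(supports_bfloat16((7, 5)))  # T4 (Turing, sm_75): bf16 tests are skipped
print(supports_bfloat16((8, 9)))  # L4 (Ada, sm_89): bf16 tests run
```

Tuple comparison does the right thing here: `(7, 5) >= (8, 0)` is false because the major version is compared first.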

Testing Done

  • Hardware Type: NVIDIA T4 (should skip most cases)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
⚡ main ~/Liger-Kernel make all
python -m pytest --disable-warnings test/ --ignore=test/convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
=================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... =================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... Skipped 1 files
All done! ✨ 🍰 ✨
58 files left unchanged.
collected 163 items                                                                                                                                        

test/transformers/test_auto_model.py .                                                                                                               [  0%]
test/transformers/test_cross_entropy.py ssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                   [ 36%]
collected 28 items                                                                                                                                         

test/convergence/test_mini_models.py .....s.....s....                                                                                    [ 43%]
test/transformers/test_geglu.py .s....ssss                                                                                                             [ 48%]
test/transformers/test_monkey_patch.py .....                                                                                                         [ 51%]
test/transformers/test_rms_norm.py ........ssssssss...............ssssssss........                                                                  [ 80%]
test/transformers/test_rope.py ......ssssss                                                                                                          [ 88%]
test/transformers/test_swiglu.py ....ssss.s....ssss                                                                                                    [ 98%]
test/transformers/test_trainer_integration.py .                                                                                                      [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                           [100%]

======================================================== 71 passed, 92 skipped in 136.69s (0:02:16) ========================================================
.s.s.s                                                                                                  [ 50%]
test/convergence/test_mini_models_no_logits.py .s.s.s.s.s.s.s                                                                                        [100%]

======================================================== 14 passed, 14 skipped in 353.27s (0:05:53) ========================================================
  • Hardware Type: NVIDIA L4 (should skip only a few cases)
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence
⚡ main ~/Liger-Kernel make all
python -m pytest --disable-warnings test/ --ignore=test/convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
=================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... =================================================================== test session starts ====================================================================
platform linux -- Python 3.10.10, pytest-8.3.2, pluggy-1.5.0
rootdir: /teamspace/studios/this_studio/Liger-Kernel
plugins: anyio-4.4.0
collecting ... Skipped 1 files
All done! ✨ 🍰 ✨
58 files left unchanged.
collected 163 items                                                                                                                                        

test/transformers/test_auto_model.py .                                                                                                               [  0%]
collected 28 items                                                                                                                                         

test/convergence/test_mini_models.py ........................................................ss                                                   [ 36%]
test/transformers/test_fused_linear_cross_entropy.py ...............                                                                                    [ 43%]
test/transformers/test_geglu.py .........                                                                                                             [ 48%]
test/transformers/test_monkey_patch.py .....                                                                                                         [ 51%]
test/transformers/test_rms_norm.py .................................................                                                                  [ 80%]
test/transformers/test_rope.py ............                                                                                                          [ 88%]
test/transformers/test_swiglu.py ..................                                                                                                    [ 98%]
test/transformers/test_trainer_integration.py .                                                                                                      [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                           [100%]

======================================================== 161 passed, 2 skipped in 90.45s (0:01:30) =========================================================
.......                                                                                                  [ 50%]
test/convergence/test_mini_models_no_logits.py ..............                                                                                        [100%]

============================================================== 28 passed in 290.65s (0:04:50) ==============================================================

Additional Context

For your reference, here's a list of NVIDIA architecture names and the compute capabilities they support:

[Screenshot: table of NVIDIA GPU architectures and their compute capabilities]
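In lieu of the screenshot, the commonly cited mapping (summarized from NVIDIA's public CUDA documentation; only the mainstream recent architectures are listed) can be captured in code:

```python
# Compute capability (major, minor) per NVIDIA architecture generation.
ARCH_COMPUTE_CAPABILITY = {
    "Pascal": (6, 0),
    "Volta": (7, 0),
    "Turing": (7, 5),
    "Ampere": (8, 0),
    "Ada Lovelace": (8, 9),
    "Hopper": (9, 0),
}

# Architectures whose GPUs support bfloat16 (sm_80 and newer).
bf16_archs = [a for a, cc in ARCH_COMPUTE_CAPABILITY.items() if cc >= (8, 0)]
print(bf16_archs)  # ['Ampere', 'Ada Lovelace', 'Hopper']
```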

@ByronHsu (Collaborator) commented:
import pytest
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss


def supports_bfloat16():
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 4096, 32000),  # llama2, mistral
        (1, 4096, 128256),  # llama3
        # weird shapes
        (3, 423, 32000),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            0.1, torch.bfloat16, 1e-8, 5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(),
                reason="bfloat16 not supported on this GPU",
            ),
        ),
        pytest.param(
            1.0, torch.bfloat16, 1e-8, 5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(),
                reason="bfloat16 not supported on this GPU",
            ),
        ),
        pytest.param(
            10.0, torch.bfloat16, 1e-7, 5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(),
                reason="bfloat16 not supported on this GPU",
            ),
        ),
        (0.1, torch.float32, 1e-8, 1e-6),
        (1.0, torch.float32, 1e-8, 1e-6),
        (10.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness(B, T, V, scalar, dtype, atol, rtol):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    liger_ce = LigerCrossEntropyLoss()
    _test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol)


# Leading underscore so pytest does not collect this helper as a test.
def _test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol):
    # Compare the Liger kernel against PyTorch's reference cross entropy.
    logits = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
    labels = torch.randint(0, V, (B * T,), device="cuda")

    loss = liger_ce(logits, labels)
    expected_loss = torch.nn.functional.cross_entropy(logits.float(), labels)
    torch.testing.assert_close(loss, expected_loss.to(dtype), atol=atol, rtol=rtol)
We can do this but simplify it a bit. (P.S. thanks, Claude)
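One way to simplify, in the spirit of the comment above, is to define the `skipif` marker once and reuse it across the bf16 parameter sets instead of repeating it on every `pytest.param`. A sketch (the capability is hard-coded here as an assumption so the idea is visible without a GPU; real code would call `torch.cuda.get_device_capability()`):

```python
import pytest

# Assumption for illustration: pretend the current GPU reports sm_75 (a T4).
GPU_CAPABILITY = (7, 5)


def supports_bfloat16():
    # Real code would query torch.cuda.get_device_capability() after
    # checking torch.cuda.is_available().
    return GPU_CAPABILITY >= (8, 0)


# Define the skip condition once instead of repeating it per parameter set.
requires_bf16 = pytest.mark.skipif(
    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
)

# Usage in a parametrize list:
#     pytest.param(0.1, torch.bfloat16, 1e-8, 5e-2, marks=requires_bf16)
```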

@austin362667 changed the title from "Add compute capability marker to skip tests run on old GPU arch" to "Skipped tests for bf16" on Aug 29, 2024
@austin362667 changed the title from "Skipped tests for bf16" to "Skip Tests for GPUs Not Supporting bf16" on Aug 29, 2024
@austin362667 (Contributor, Author) commented:

Sure, updated. Thanks for reviewing. That’s neat — much appreciated!

@ByronHsu (Collaborator) left a comment:

lgtm. cc @lancerts @helloworld1 to do a 2nd pass

@lancerts (Collaborator) left a comment:

lgtm

@lancerts lancerts merged commit cbc4f85 into linkedin:main Aug 29, 2024
2 checks passed
Development

Successfully merging this pull request may close these issues.

Add pytest filtering based on GPU types
3 participants