Skip Tests for GPUs Not Supporting bf16 #159
Conversation
```python
import pytest
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss


def supports_bfloat16():
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 0)  # Ampere and newer


@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 4096, 32000),  # llama2, mistral
        (1, 4096, 128256),  # llama3
        # weird shapes
        (3, 423, 32000),
    ],
)
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(0.1, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU")),
        pytest.param(1.0, torch.bfloat16, 1e-8, 5e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU")),
        pytest.param(10.0, torch.bfloat16, 1e-7, 5e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU")),
        (0.1, torch.float32, 1e-8, 1e-6),
        (1.0, torch.float32, 1e-8, 1e-6),
        (10.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness(B, T, V, scalar, dtype, atol, rtol):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    liger_ce = LigerCrossEntropyLoss()
    _test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol)


# Leading underscore keeps pytest from collecting this helper as a test of its own.
def _test_correctness_once(liger_ce, B, T, V, scalar, dtype, atol, rtol):
    # Placeholder implementation. Flatten batch and sequence so logits are (N, V)
    # and labels are (N,), matching what torch.nn.functional.cross_entropy expects.
    logits = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
    labels = torch.randint(0, V, (B * T,), device="cuda")
    # The existing test logic goes here, for example:
    # loss = liger_ce(logits, labels)
    # expected_loss = torch.nn.functional.cross_entropy(logits.float(), labels)
    # torch.testing.assert_close(loss, expected_loss, atol=atol, rtol=rtol)
    # For now, just use a placeholder assertion:
    assert True, "Test passed"
```

We can do this but simplify a bit. (P.S. thanks Claude)
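One possible simplification, sketched here only for illustration (the `skip_if_no_bf16` name is introduced in this sketch and is not from the PR): define the skipif marker once and reuse it across the bfloat16 params.

```python
import pytest
import torch


def supports_bfloat16():
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 0)  # Ampere (sm_80) and newer


# Define the skip marker once instead of repeating pytest.mark.skipif for every param.
skip_if_no_bf16 = pytest.mark.skipif(
    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
)


@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(0.1, torch.bfloat16, 1e-8, 5e-2, marks=skip_if_no_bf16),
        pytest.param(1.0, torch.bfloat16, 1e-8, 5e-2, marks=skip_if_no_bf16),
        pytest.param(10.0, torch.bfloat16, 1e-7, 5e-2, marks=skip_if_no_bf16),
        (0.1, torch.float32, 1e-8, 1e-6),
        (1.0, torch.float32, 1e-8, 1e-6),
        (10.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness(scalar, dtype, atol, rtol):
    ...
```

This keeps the parametrize table readable while preserving the same skip behavior.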
Sure, updated. Thanks for reviewing. That's neat, much appreciated!
Signed-off-by: Austin Liu <[email protected]>
Force-pushed from 23788d1 to b58562d
lgtm. cc @lancerts @helloworld1 to do a 2nd pass
lgtm
Summary

Closes #87

Skipped tests for bfloat16 on GPUs with compute capability below the Ampere architecture (sm_80).

Testing Done

- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence

Additional Context
FYR, here's a list of NVIDIA architecture names and the compute capabilities they correspond to:
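A minimal sketch of that mapping, using the well-known minimum compute capability per architecture (illustrative, not exhaustive; the `ARCH_CAPABILITY` name exists only in this example):

```python
import torch

# Common NVIDIA architectures and the (major, minor) compute capability they start at.
ARCH_CAPABILITY = {
    "Pascal": (6, 0),
    "Volta": (7, 0),
    "Turing": (7, 5),
    "Ampere": (8, 0),        # first architecture with native bfloat16 support
    "Ada Lovelace": (8, 9),
    "Hopper": (9, 0),
}

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    print(f"Detected compute capability {cap}; bf16 supported: {cap >= (8, 0)}")
```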