-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[Feature] use pytest for sgl-kernel #4697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
|
|
@zhyncs one test is failing. Increasing tolerance should help? |
|
nice work~ |
FlamingoPg
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM~
| ) | ||
| @pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16]) | ||
| def test_grouped_gemm_accuracy(out_dtype): | ||
| Ms = [1, 16, 32, 256, 1024] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For different shapes you should use @pytest.mark.parametrize instead of for loop
sgl-kernel/tests/test_deep_gemm.py
Outdated
| ) | ||
| def test_gemm(): | ||
| print("Testing GEMM:") | ||
| for m in (64, 128, 4096): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
sgl-kernel/tests/test_deep_gemm.py
Outdated
|
|
||
| def test_m_grouped_gemm_contiguous(): | ||
| print("Testing grouped contiguous GEMM:") | ||
| for num_groups, m, k, n in ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
|
|
||
|
|
||
| def test_accuracy(): | ||
| Ms = [1, 128, 512, 1024, 4096] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
sgl-kernel/tests/test_fp8_gemm.py
Outdated
|
|
||
|
|
||
| def test_accuracy(): | ||
| Ms = [1, 128, 512, 1024, 4096] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
sgl-kernel/tests/test_int8_gemm.py
Outdated
|
|
||
|
|
||
| def test_accuracy(): | ||
| Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same
|
apologies messed up. opening a new PR |
Motivation
For: #4690
Modifications
test files in: https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests
Checklist