Skip to content

Conversation

@adarshxs
Copy link
Collaborator

Motivation

For: #4690

Modifications

test files in: https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests

Checklist

@adarshxs adarshxs requested a review from zhyncs March 26, 2025 17:37
@zhyncs
Copy link
Member

zhyncs commented Mar 26, 2025

speculative/test_eagle_utils.py and speculative/test_speculative_sampling.py should also be updated

@adarshxs
Copy link
Collaborator Author

@zhyncs one test is failing. Increasing tolerance should help?

@zhyncs zhyncs requested a review from FlamingoPg as a code owner March 29, 2025 17:02
@FlamingoPg
Copy link
Collaborator

nice work~

Copy link
Collaborator

@FlamingoPg FlamingoPg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM~

)
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
def test_grouped_gemm_accuracy(out_dtype):
Ms = [1, 16, 32, 256, 1024]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For different shapes you should use @pytest.mark.parametrize instead of for loop

)
def test_gemm():
print("Testing GEMM:")
for m in (64, 128, 4096):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same


def test_m_grouped_gemm_contiguous():
print("Testing grouped contiguous GEMM:")
for num_groups, m, k, n in (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same



def test_accuracy():
Ms = [1, 128, 512, 1024, 4096]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same



def test_accuracy():
Ms = [1, 128, 512, 1024, 4096]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same



def test_accuracy():
Ms = [1, 16, 32, 64, 128, 512, 1024, 4096, 8192]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

@adarshxs adarshxs closed this Mar 29, 2025
@adarshxs
Copy link
Collaborator Author

adarshxs commented Mar 29, 2025

apologies messed up. opening a new PR

@adarshxs adarshxs deleted the pytest_transition branch April 19, 2025 14:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants