Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pr-test-amd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ jobs:
fail-fast: false
matrix:
runner: [linux-mi35x-gpu-8]
part: [0, 1]
part: [0, 1, 2]
runs-on: ${{matrix.runner}}
steps:
- name: Checkout code
Expand All @@ -679,7 +679,7 @@ jobs:
- name: Run test
timeout-minutes: 60
run: |
bash scripts/ci/amd/amd_ci_exec.sh -w "/sglang-checkout/test" python3 run_suite.py --hw amd --suite stage-c-test-large-8-gpu-amd-mi35x --auto-partition-id ${{ matrix.part }} --auto-partition-size 2 --timeout-per-file 3600
bash scripts/ci/amd/amd_ci_exec.sh -w "/sglang-checkout/test" python3 run_suite.py --hw amd --suite stage-c-test-large-8-gpu-amd-mi35x --auto-partition-id ${{ matrix.part }} --auto-partition-size 3 --timeout-per-file 3600

stage-b-test-small-1-gpu-performance-amd:
needs: [check-changes, call-gate, stage-a-test-1-amd]
Expand Down
95 changes: 95 additions & 0 deletions test/registered/amd/test_kimi_k2_instruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import unittest
from types import SimpleNamespace
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To robustly parse the base URL, it's recommended to use Python's built-in urlparse function. Please add the necessary import. This will make the URL handling in the test methods less brittle and more maintainable.

Suggested change
from types import SimpleNamespace
from types import SimpleNamespace
from urllib.parse import urlparse


import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_amd_ci
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)

register_amd_ci(est_time=3600, suite="stage-c-test-large-8-gpu-amd-mi35x")

KIMI_K2_MODEL_PATH = "moonshotai/Kimi-K2-Instruct-0905"
SERVER_LAUNCH_TIMEOUT = 3600


class TestKimiK2Instruct0905(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = KIMI_K2_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
other_args = [
"--tp",
"8",
"--decode-attention-backend",
"triton",
"--prefill-attention-backend",
"aiter",
"--trust-remote-code",
"--model-loader-extra-config",
'{"enable_multithread_load": true}',
]
env = os.environ.copy()
env["SGLANG_USE_AITER"] = "1"
env["SGLANG_ROCM_FUSED_DECODE_MLA"] = "0"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=SERVER_LAUNCH_TIMEOUT,
other_args=other_args,
env=env,
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_a_gsm8k(
self,
): # Append an "a" to make this test run first (alphabetically) to warm up the server
requests.get(self.base_url + "/flush_cache")

args = SimpleNamespace(
num_shots=8,
data_path=None,
num_questions=1319,
parallel=1319,
max_new_tokens=512,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
Comment on lines +61 to +69
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Hardcoding the host URL and manually parsing the port is brittle. For instance, if DEFAULT_URL_FOR_TEST were to use localhost instead of an IP address, this test could fail. Using urlparse (with the import added at the top of the file) to deconstruct self.base_url is a more robust approach.

Suggested change
args = SimpleNamespace(
num_shots=8,
data_path=None,
num_questions=1319,
parallel=1319,
max_new_tokens=512,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
parsed_url = urlparse(self.base_url)
args = SimpleNamespace(
num_shots=8,
data_path=None,
num_questions=1319,
parallel=1319,
max_new_tokens=512,
host=f"{parsed_url.scheme}://{parsed_url.hostname}",
port=parsed_url.port,
)

metrics = run_eval_few_shot_gsm8k(args)
print(f"{metrics=}")

if is_in_ci():
write_github_step_summary(
f"### test_gsm8k (Kimi-K2-Instruct-0905)\n"
f'{metrics["accuracy"]=:.3f}\n'
)
self.assertGreater(metrics["accuracy"], 0.94)

def test_bs_1_speed(self):
args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to test_a_gsm8k, manually parsing the port from the URL is brittle. Using urlparse provides a more robust way to extract the port and improves maintainability.

Suggested change
args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
parsed_url = urlparse(self.base_url)
args = BenchArgs(port=parsed_url.port, max_new_tokens=2048)

_, speed = send_one_prompt(args)

print(f"{speed=:.2f}")

if is_in_ci():
write_github_step_summary(
f"### test_bs_1_speed (Kimi-K2-Instruct-0905)\n"
f"{speed=:.2f} token/s\n"
)
self.assertGreater(speed, 45)


if __name__ == "__main__":
unittest.main()
Loading