Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions .github/workflows/pr-test-xpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ jobs:
docker build \
--build-arg SG_LANG_KERNEL_BRANCH=${{ github.head_ref }} \
--build-arg SG_LANG_KERNEL_REPO=${{ github.event.pull_request.head.repo.clone_url }} \
--no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:pvc .
--no-cache --progress=plain -f Dockerfile.xpu_kernel -t xpu_sglang:kernel .

- name: Run container
run: |
docker run -dt \
--device /dev/dri/ \
--name ci_sglang_xpu \
-e HF_TOKEN=$(cat ~/huggingface_token.txt) \
xpu_sglang:pvc
xpu_sglang:kernel

- name: Install Dependency
timeout-minutes: 20
Expand All @@ -45,17 +45,11 @@ jobs:
docker exec ci_sglang_xpu /bin/bash -c '/miniforge3/envs/py3.10/bin/huggingface-cli login --token ${HF_TOKEN} '
docker exec ci_sglang_xpu /bin/bash -c "ln -sf /miniforge3/envs/py3.10/bin/python3 /usr/bin/python3"

- name: Run Sglang Kernel Cases
timeout-minutes: 20
run: |
docker exec -w /root/sglang ci_sglang_xpu \
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 -m pytest -v -s test_awq_dequant.py test_topk_softmax.py test_flash_attention.py"

- name: Run Sglang Kernel Benchmarks
timeout-minutes: 20
run: |
docker exec -w /root/sglang ci_sglang_xpu \
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/benchmark && python3 bench_flash_attn.py "
/bin/bash -c "cd /root/sglang/sgl-kernel-xpu/tests && python3 run_suite.py --suite per-commit "

- name: Run E2E Bfloat16 tests
timeout-minutes: 20
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.xpu_kernel
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ RUN --mount=type=secret,id=github_token \
cd sgl-kernel-xpu && \
pip install -v . &&\
# Install required packages for sglang workloads
pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops matplotlib pandas --root-user-action=ignore && \
pip install msgspec blake3 py-cpuinfo compressed_tensors gguf partial_json_parser einops matplotlib pandas --root-user-action=ignore aiohttp && \
conda install libsqlite=3.48.0 -y && \
echo ". /miniforge3/bin/activate; conda activate py${PYTHON_VERSION}; cd /root/" >> /root/.bashrc;

Expand Down
121 changes: 121 additions & 0 deletions tests/run_suite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import argparse
import glob
from dataclasses import dataclass

from test_utils import run_unittest_files


@dataclass
class TestFile:
    # One test file to execute, plus a rough runtime estimate that
    # auto_partition() uses to balance files across CI shards.
    name: str  # path of the test file, relative to this directory
    estimated_time: float = 60  # estimated runtime in seconds


# Add Intel XPU Kernel tests
# Maps a suite name to the list of test files run for that suite.
suites = {
    "per-commit": [
        TestFile("test_awq_dequant.py"),
        TestFile("test_topk_softmax.py"),
    ],
}


def auto_partition(files, rank, size):
    """
    Partition files into `size` sublists with approximately equal sums of
    estimated times using stable sorting, and return the partition for the
    specified rank.

    Uses the greedy longest-processing-time heuristic: files are considered
    in descending order of estimated_time and each is assigned to the
    partition with the smallest current total.

    Args:
        files (list): List of file objects with an estimated_time attribute
        rank (int): Index of the partition to return (0 to size-1)
        size (int): Number of partitions

    Returns:
        list: List of file objects in the specified rank's partition
        (empty list when files is empty or size is invalid)
    """
    weights = [f.estimated_time for f in files]

    if not weights or size <= 0 or size > len(weights):
        return []

    # Create list of (weight, original_index) tuples.
    # Using the negated index as the secondary key keeps the sort stable:
    # for equal weights, the earlier original position sorts first when
    # sorting in descending order.
    indexed_weights = [(w, -i) for i, w in enumerate(weights)]
    indexed_weights = sorted(indexed_weights, reverse=True)

    # Negate the indices back to their original (positive) values.
    indexed_weights = [(w, -i) for w, i in indexed_weights]

    # Initialize partitions and their running sums.
    partitions = [[] for _ in range(size)]
    sums = [0.0] * size

    # Greedy assignment: each file goes to the partition with the
    # smallest current total estimated time.
    for weight, idx in indexed_weights:
        min_sum_idx = sums.index(min(sums))
        partitions[min_sum_idx].append(idx)
        sums[min_sum_idx] += weight

    # Map the chosen indices back to the original file objects.
    indices = partitions[rank]
    return [files[i] for i in indices]


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--timeout-per-file",
        type=int,
        default=1800,
        help="The time limit for running one file in seconds.",
    )
    arg_parser.add_argument(
        "--suite",
        type=str,
        default=list(suites.keys())[0],
        choices=list(suites.keys()) + ["all"],
        help="The suite to run",
    )
    arg_parser.add_argument(
        "--range-begin",
        type=int,
        default=0,
        help="The begin index of the range of the files to run.",
    )
    arg_parser.add_argument(
        "--range-end",
        type=int,
        default=None,
        help="The end index of the range of the files to run.",
    )
    arg_parser.add_argument(
        "--auto-partition-id",
        type=int,
        help="Use auto load balancing. The part id.",
    )
    arg_parser.add_argument(
        "--auto-partition-size",
        type=int,
        help="Use auto load balancing. The number of parts.",
    )
    args = arg_parser.parse_args()
    print(f"{args=}")

    if args.suite == "all":
        # glob returns plain path strings; wrap them in TestFile so the
        # downstream code can rely on .name / .estimated_time uniformly.
        # Sort for a deterministic run order across machines.
        files = [
            TestFile(name)
            for name in sorted(glob.glob("**/test_*.py", recursive=True))
        ]
    else:
        files = suites[args.suite]

    if args.auto_partition_size:
        # Both partition arguments are needed; fail fast with a clear
        # message instead of crashing inside auto_partition.
        if args.auto_partition_id is None:
            arg_parser.error(
                "--auto-partition-id is required when --auto-partition-size is set"
            )
        files = auto_partition(files, args.auto_partition_id, args.auto_partition_size)
    else:
        files = files[args.range_begin : args.range_end]

    print("The running tests are ", [f.name for f in files])

    exit_code = run_unittest_files(files, args.timeout_per_file)

    exit(exit_code)
91 changes: 91 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Common utilities for testing and benchmarking"""

import os
import signal
import subprocess
import threading
import time
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class TestFile:
    """Metadata for a single unit-test file.

    Mirrors the TestFile dataclass in run_suite.py so both modules describe
    test files the same way. As a plain class the original could not be
    constructed as TestFile("name"); the dataclass decorator generates the
    expected __init__.
    """

    # Path of the test file (relative to the tests directory).
    name: str
    # Rough runtime in seconds, used for load-balanced partitioning.
    estimated_time: float = 60


def run_with_timeout(
    func: Callable,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    timeout: float = None,
):
    """Run ``func(*args, **kwargs)`` in a worker thread with a timeout.

    Args:
        func: The callable to execute.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func`` (``None`` means none).
        timeout: Maximum seconds to wait; ``None`` waits forever.

    Returns:
        Whatever ``func`` returns.

    Raises:
        TimeoutError: If ``func`` does not finish within ``timeout`` seconds.
        BaseException: Any exception raised by ``func`` is re-raised here
            instead of being masked by a generic RuntimeError.
        RuntimeError: If the worker produced neither a value nor an error.
    """
    ret_value = []
    errors = []

    def _target_func():
        try:
            ret_value.append(func(*args, **(kwargs or {})))
        except BaseException as e:  # capture to re-raise in the caller thread
            errors.append(e)

    # daemon=True so a timed-out worker cannot keep the process alive at exit.
    t = threading.Thread(target=_target_func, daemon=True)
    t.start()
    t.join(timeout=timeout)
    if t.is_alive():
        raise TimeoutError(f"{func} did not finish within {timeout} seconds")

    if errors:
        raise errors[0]
    if not ret_value:
        raise RuntimeError(f"{func} returned no value")

    return ret_value[0]


def _kill_process_tree(process):
    """Best-effort kill of a test subprocess and any children it spawned.

    The child is launched with start_new_session=True (see below), so it
    leads its own process group and killing that group takes down
    grandchildren too. Falls back to killing just the child if the group
    signal fails.
    """
    if process is None:
        return
    try:
        os.killpg(os.getpgid(process.pid), signal.SIGKILL)
    except (ProcessLookupError, PermissionError, OSError):
        try:
            process.kill()
        except OSError:
            pass  # already gone


def run_unittest_files(files: List[TestFile], timeout_per_file: float):
    """Run each test file as ``python3 <file>`` sequentially.

    Args:
        files: TestFile entries to execute, in order.
        timeout_per_file: Per-file wall-clock limit in seconds.

    Returns:
        0 if every file exits with return code 0, -1 otherwise.
        The first failing or timed-out file stops the run.
    """
    suite_start = time.perf_counter()
    success = True

    for i, file in enumerate(files):
        filename, estimated_time = file.name, file.estimated_time
        process = None

        def run_one_file(filename):
            nonlocal process

            filename = os.path.join(os.getcwd(), filename)
            print(
                f".\n.\nBegin ({i}/{len(files) - 1}):\npython3 {filename}\n.\n.\n",
                flush=True,
            )
            file_start = time.perf_counter()

            # start_new_session puts the child in its own process group so
            # _kill_process_tree can signal the whole tree on timeout.
            process = subprocess.Popen(
                ["python3", filename],
                stdout=None,
                stderr=None,
                env=os.environ,
                start_new_session=True,
            )
            process.wait()
            elapsed = time.perf_counter() - file_start

            print(
                f".\n.\nEnd ({i}/{len(files) - 1}):\n{filename=}, {elapsed=:.0f}, {estimated_time=}\n.\n.\n",
                flush=True,
            )
            return process.returncode

        try:
            ret_code = run_with_timeout(
                run_one_file, args=(filename,), timeout=timeout_per_file
            )
            assert (
                ret_code == 0
            ), f"expected return code 0, but {filename} returned {ret_code}"
        except TimeoutError:
            _kill_process_tree(process)
            # Give the killed tree a moment to release GPUs/ports before
            # reporting, matching upstream sglang behavior.
            time.sleep(5)
            print(
                f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
                flush=True,
            )
            success = False
            break

    if success:
        print(f"Success. Time elapsed: {time.perf_counter() - suite_start:.2f}s", flush=True)
    else:
        print(f"Fail. Time elapsed: {time.perf_counter() - suite_start:.2f}s", flush=True)

    return 0 if success else -1