Perf tuning and expansion of cases covered for wvSplitKrc #33493

vllm-bot merged 7 commits into vllm-project:main
Conversation
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
Code Review
This pull request introduces performance tuning for the wvSplitKrc kernel and expands the cases it covers. The changes are mainly in csrc/rocm/skinny_gemms.cu, with corresponding updates in the dispatch logic in vllm/model_executor/layers/utils.py and test cases in tests/kernels/quantization/test_rocm_skinny_gemms.py. While the performance optimizations seem promising, I've identified a few critical issues. There's a logic mismatch between the Python dispatch code and the C++ kernel implementation that could lead to incorrect kernel dispatching. Additionally, a crucial out-of-bounds check appears to have been incorrectly removed in the kernel, which could lead to incorrect computations. I've provided detailed comments on these issues.
Is this PR related to #33527?
No, they are not related. They change different skinny GEMMs. This PR targets the test scenarios where cross-wave atomic reduction is used to fill the machine (cases seen in gpt-oss). FYI, there'll be another similar PR soon targeting padded activations in the non-quantized skinny GEMM solution.
I'll launch a CI cycle with this one tomorrow, then, to see whether any tests regress.
When you post it, CC me if possible so we can prevent any possible regressions :) There has been a huge effort to keep AMD CI green.
@AndreasKaratzas This is the 3rd PR. It adds padding support to the fp16/bf16 version of the skinny GEMM solutions.
@amd-hhashemi Thanks for the kernel improvements in this PR. We've been investigating a persistent test failure.

Estimated root cause: The old code had a conditional direct-write path that was deterministic:

```cpp
bool doRdc = (kfitsPerRdc * kFit < K);
```

When `doRdc` was false, results were written out directly without a cross-wave reduction. This PR changed it to:

```cpp
bool doRdc = true; // Assuming (kfitsPerRdc * kFit < K) is always true
```

and removed the direct-write path entirely.

Reproduction: After reverting this PR, both tests passed.
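For context on why the atomic-reduction path can be non-deterministic while the direct-write path is not: floating-point addition is order-sensitive, and cross-wave atomics give no ordering guarantee. A minimal Python illustration of the effect (illustrative only, not the kernel code):

```python
# Floating-point addition is not associative, so the order in which
# per-wave partial sums are combined by atomic adds can change the
# rounded result from run to run.
partials = [1e16, 1.0, -1e16]  # partial sums produced by three "waves"

# One arrival order: the small term is absorbed by the large one.
order_a = (partials[0] + partials[1]) + partials[2]  # -> 0.0

# Another arrival order: the large terms cancel first.
order_b = (partials[0] + partials[2]) + partials[1]  # -> 1.0

print(order_a, order_b)  # same inputs, different results
```

The direct-write path avoids this entirely because each output element is produced by exactly one summation order.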
Is it possible to make the kernels deterministic by default, and only use the non-deterministic path when an option is passed?
@AndreasKaratzas Hi, can you please try with #34410? It's a one-liner. I root-caused an issue that shows up on some vLLM dockers on N<=16 GEMMs (as seen in single-prompt gpt-oss). It was occurring only in non-eager-mode prompts for me, and I was never able to get an out-of-threshold test on any of the GEMM sizes.
@amd-hhashemi Back at it again 😅 So there is another inaccuracy observed. Reproduction script, `test_cudagraph_divergence.py`:

```python
import math
from dataclasses import dataclass

import torch

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.distributed import cleanup_dist_env_and_memory

MODEL = "meta-llama/Llama-3.2-1B-Instruct"
MAX_MODEL_LEN = 256
SEED = 42
GPU_MEM_UTIL = 0.4
MAX_LOGPROBS = 5
TOP_LOGPROBS = 3
MAX_TOKENS = 10
PROMPT = "Hello world " * 50


@dataclass
class RunConfig:
    name: str
    enforce_eager: bool
    compile_ranges_split_points: list[int] | None  # None = use default


def make_sampling_params():
    normal = SamplingParams(
        temperature=0,
        logprobs=TOP_LOGPROBS,
        max_tokens=MAX_TOKENS,
        ignore_eos=False,
    )
    penalty = SamplingParams(
        temperature=0,
        logprobs=TOP_LOGPROBS,
        max_tokens=MAX_TOKENS,
        ignore_eos=False,
        presence_penalty=-1.0,
    )
    return normal, penalty


def run_config(config: RunConfig):
    print(f"\n{'='*60}")
    print(f"Running: {config.name}")
    print(f"  enforce_eager={config.enforce_eager}")
    print(f"  compile_ranges_split_points={config.compile_ranges_split_points}")
    print(f"{'='*60}")
    kwargs = dict(
        model=MODEL,
        max_logprobs=MAX_LOGPROBS,
        max_model_len=MAX_MODEL_LEN,
        seed=SEED,
        gpu_memory_utilization=GPU_MEM_UTIL,
        enable_prefix_caching=False,
        enable_chunked_prefill=True,
        max_num_batched_tokens=32,
        enforce_eager=config.enforce_eager,
    )
    if config.compile_ranges_split_points is not None and not config.enforce_eager:
        kwargs["compilation_config"] = CompilationConfig(
            compile_ranges_split_points=config.compile_ranges_split_points,
        )
    llm = LLM(**kwargs)
    normal_params, penalty_params = make_sampling_params()
    results = llm.generate([PROMPT, PROMPT], [normal_params, penalty_params])
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()
    return results


def extract_logprobs(results):
    per_request = []
    for result in results:
        positions = []
        for lp_dict in result.outputs[0].logprobs:
            positions.append(lp_dict)
        per_request.append(positions)
    return per_request


def compare(name_a, lps_a, name_b, lps_b):
    labels = ["no_penalty", "with_penalty"]
    print(f"\n{'#'*70}")
    print(f"COMPARISON: {name_a} vs {name_b}")
    print(f"{'#'*70}")
    max_diff = 0.0
    total = 0
    fail_5 = 0
    fail_10 = 0
    for req_idx in range(len(lps_a)):
        label = labels[req_idx]
        a = lps_a[req_idx]
        b = lps_b[req_idx]
        if len(a) != len(b):
            print(f"  [{label}] LENGTH MISMATCH: {len(a)} vs {len(b)}")
            continue
        print(f"\n  [{label}] {len(a)} positions")
        print(f"  {'pos':>4} {'token':>15} {'rank':>5} "
              f"{'lp_A':>12} {'lp_B':>12} {'diff':>10} {'rel%':>8}")
        print(f"  {'-'*72}")
        for pos in range(len(a)):
            common = set(a[pos].keys()) & set(b[pos].keys())
            for tid in sorted(common):
                la = a[pos][tid]
                lb = b[pos][tid]
                diff = abs(la.logprob - lb.logprob)
                denom = max(abs(la.logprob), abs(lb.logprob), 1e-10)
                rel = (diff / denom) * 100
                max_diff = max(max_diff, diff)
                total += 1
                c5 = math.isclose(la.logprob, lb.logprob,
                                  rel_tol=5e-2, abs_tol=1e-1)
                c10 = math.isclose(la.logprob, lb.logprob,
                                   rel_tol=1e-1, abs_tol=1e-1)
                if not c5:
                    fail_5 += 1
                if not c10:
                    fail_10 += 1
                flag = ""
                if not c5:
                    flag = " <-- FAIL@5%"
                if not c10:
                    flag = " <-- FAIL@10%"
                print(f"  {pos:>4} {la.decoded_token!r:>15} "
                      f"{la.rank:>3} "
                      f"{la.logprob:>12.6f} {lb.logprob:>12.6f} "
                      f"{diff:>10.6f} {rel:>7.2f}%{flag}")
    print(f"\n  SUMMARY: {total} comparisons, max_diff={max_diff:.6f}, "
          f"fail@5%={fail_5}, fail@10%={fail_10}")
    return max_diff, fail_5, fail_10


def main():
    configs = [
        RunConfig("eager", enforce_eager=True, compile_ranges_split_points=None),
        RunConfig("eager2", enforce_eager=True, compile_ranges_split_points=None),
        RunConfig("graph_sp32", enforce_eager=False, compile_ranges_split_points=[32]),
        RunConfig("graph_sp64", enforce_eager=False, compile_ranges_split_points=[64]),
    ]
    all_results = {}
    all_lps = {}
    for cfg in configs:
        all_results[cfg.name] = run_config(cfg)
        all_lps[cfg.name] = extract_logprobs(all_results[cfg.name])
    comparisons = [
        ("eager", "eager2", "Eager vs Eager (sanity: expect zero diff)"),
        ("eager", "graph_sp32", "Eager vs CUDA graph split=[32]"),
        ("eager", "graph_sp64", "Eager vs CUDA graph split=[64]"),
        ("graph_sp32", "graph_sp64", "CUDA graph split=[32] vs split=[64] (the confound)"),
    ]
    print(f"\n\n{'='*70}")
    print("ALL COMPARISONS")
    print(f"{'='*70}")
    summary = []
    for a, b, desc in comparisons:
        print(f"\n--- {desc} ---")
        md, f5, f10 = compare(a, all_lps[a], b, all_lps[b])
        summary.append((desc, md, f5, f10))
    print(f"\n\n{'='*70}")
    print("FINAL SUMMARY")
    print(f"{'='*70}")
    print(f"\n{'Description':<55} {'MaxDiff':>10} {'F@5%':>6} {'F@10%':>6}")
    print(f"{'-'*80}")
    for desc, md, f5, f10 in summary:
        print(f"{desc:<55} {md:>10.6f} {f5:>6} {f10:>6}")


if __name__ == "__main__":
    main()
```

If you run this with default settings, i.e. with `VLLM_ROCM_USE_SKINNY_GEMM=1`:

**Running with `VLLM_ROCM_USE_SKINNY_GEMM=1`** (output collapsed)

But if you run it with `VLLM_ROCM_USE_SKINNY_GEMM=0`:

**Running with `VLLM_ROCM_USE_SKINNY_GEMM=0`** (output collapsed)

These experiments were conducted on an MI355 machine. Btw, I would like some help with revamping the skinny GEMMs test, since these failures should be caught there. Can you help with those tasks?

EDIT: While the above is a custom script, the motivation for this was the
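As a side note on the script's pass/fail criterion: `compare()` flags a logprob pair as `FAIL@5%` or `FAIL@10%` using `math.isclose` with a combined relative/absolute bound, so a pair can fail the 5% check yet pass the 10% one. A standalone illustration:

```python
import math

def isclose_at(rel_tol, la, lb, abs_tol=1e-1):
    # Same criterion as compare(): close if |la - lb| is within rel_tol
    # of the larger magnitude OR within abs_tol.
    return math.isclose(la, lb, rel_tol=rel_tol, abs_tol=abs_tol)

la, lb = -2.00, -2.16  # |diff| = 0.16, roughly 7.4% of the larger magnitude
print(isclose_at(5e-2, la, lb))  # False: fails the 5% check
print(isclose_at(1e-1, la, lb))  # True: passes the 10% check
```

Because `math.isclose` takes the max of the relative and absolute bounds, the `abs_tol=1e-1` floor dominates for logprobs near zero, while the relative bound dominates for large-magnitude ones.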
MI355 measurements before and after changes:

| M | N | K | before (us) | after (us) |
|-----|-----|------|------|------|
| 128 | 16  | 2880 | 4.55 | 4.56 |
| 640 | 16  | 2880 | 4.80 | 4.83 |
| 128 | 32  | 2880 | 3.91 | 3.21 |
| 640 | 32  | 2880 | 4.13 | 4.05 |
| 128 | 64  | 2880 | 4.42 | 3.23 |
| 640 | 64  | 2880 | 4.88 | 4.43 |
| 128 | 128 | 2880 | 4.51 | 3.98 |
| 640 | 128 | 2880 | 5.89 | 5.92 |
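To quantify these numbers: the largest gains are at M=128 with N=32–64 (up to roughly 1.37x), while the N=16 shapes are essentially flat. A quick script over the measurements above:

```python
# Speedups computed from the MI355 measurements: before_us / after_us.
rows = [  # (m, n, k, before_us, after_us)
    (128, 16, 2880, 4.55, 4.56),
    (640, 16, 2880, 4.80, 4.83),
    (128, 32, 2880, 3.91, 3.21),
    (640, 32, 2880, 4.13, 4.05),
    (128, 64, 2880, 4.42, 3.23),
    (640, 64, 2880, 4.88, 4.43),
    (128, 128, 2880, 4.51, 3.98),
    (640, 128, 2880, 5.89, 5.92),
]
for m, n, k, before, after in rows:
    print(f"M={m:<4} N={n:<4} K={k}: {before / after:.2f}x")
```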