diff --git a/atom/entrypoints/openai/api_server.py b/atom/entrypoints/openai/api_server.py
index 01ef08081..005e8698e 100644
--- a/atom/entrypoints/openai/api_server.py
+++ b/atom/entrypoints/openai/api_server.py
@@ -16,6 +16,7 @@
 import asyncio
 import json
 import logging
+import re
 import time
 import uuid
 from asyncio import AbstractEventLoop
@@ -56,6 +57,8 @@
 # Constants
 DEFAULT_HOST = "0.0.0.0"
 DEFAULT_PORT = 8000
+STRUCTURED_ANSWER_PROMPT_MARKER = "formatted as: ####"
+STRUCTURED_ANSWER_RE = re.compile(r"(?m)^####[^\r\n]*(?:\r?\n)?")
 
 
 # ============================================================================
@@ -157,6 +160,23 @@ def _coerce_n(requested_n: Optional[int], temperature: Optional[float]) -> int:
     return n
 
 
+def _trim_structured_answer_output(prompt: str, text: str) -> str:
+    """Trim eval-style completions after their first final-answer line.
+
+    GSM8K-style prompts ask the model to end with a ``####`` answer line.
+    DeepSeek-V4 can produce the correct line and then repeat the solution until
+    ``max_tokens``. Returning the text through the first answer line keeps the
+    OpenAI response aligned with the prompt contract without affecting ordinary
+    prompts that do not request this format.
+    """
+    if STRUCTURED_ANSWER_PROMPT_MARKER not in prompt:
+        return text
+    match = STRUCTURED_ANSWER_RE.search(text)
+    if match is None:
+        return text
+    return text[: match.end()].rstrip()
+
+
 def _send_stream_chunk_direct(
     request_output: RequestOutput,
     request_id: str,
@@ -267,6 +287,7 @@ def do_preprocess():
                 break
 
         text = tokenizer.decode(all_token_ids, skip_special_tokens=True)
+        text = _trim_structured_answer_output(prompt, text)
         num_tokens_input = (
             seq.num_prompt_tokens if seq is not None else len(tokenizer.encode(prompt))
         )
@@ -387,9 +408,11 @@ def do_preprocess():
                 and num_tokens_output > 1
                 else 0.0
             )
+            text = tokenizer.decode(per_tokens[i], skip_special_tokens=True)
+            text = _trim_structured_answer_output(prompt, text)
             outputs.append(
                 {
-                    "text": tokenizer.decode(per_tokens[i], skip_special_tokens=True),
+                    "text": text,
                     "token_ids": per_tokens[i],
                     "finish_reason": per_finish_reason[i],
                     "num_tokens_input": num_tokens_input,
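A minimal, self-contained sketch of the trimming contract above, with the helper inlined; the prompt and completion strings here are invented for illustration:

```python
import re

# Inlined copies of the marker and regex added in the hunk above.
STRUCTURED_ANSWER_PROMPT_MARKER = "formatted as: ####"
STRUCTURED_ANSWER_RE = re.compile(r"(?m)^####[^\r\n]*(?:\r?\n)?")


def trim(prompt: str, text: str) -> str:
    # Mirror of _trim_structured_answer_output: only prompts that opt in
    # to the "#### <answer>" format are ever trimmed.
    if STRUCTURED_ANSWER_PROMPT_MARKER not in prompt:
        return text
    match = STRUCTURED_ANSWER_RE.search(text)
    if match is None:
        return text
    return text[: match.end()].rstrip()


prompt = "Solve step by step. End with the result formatted as: #### <answer>"
text = "2 + 2 = 4\n#### 4\n2 + 2 = 4\n#### 4"
assert trim(prompt, text) == "2 + 2 = 4\n#### 4"  # repetition after the answer is cut
assert trim("What is 2 + 2?", text) == text       # ordinary prompts pass through
```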
""" + # Temperature=0 is a hard greedy request. Handle it before deciding + # whether top-k/top-p filtering is needed; otherwise the no-filter + # path still runs the temperature sampler with an epsilon temperature. + if all_greedy: + return logits.argmax(dim=-1).to(torch.int) + # No Top-K Top-P parameters, perform temperature-based sampling if not self._needs_filtering(top_ks, top_ps): return self._temperature_sample( diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index dd99c9fd0..5498545b5 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -109,6 +109,118 @@ # forward burns syscalls (V4-Pro: 64 layers × multiple sites per call). _V4_FORCE_UE8M0_QUANT = os.environ.get("V4_FORCE_UE8M0_QUANT", "0") == "1" _V4_USE_REF_QUANT = os.environ.get("V4_USE_REF_QUANT", "0") == "1" +_V4_DIAG_EQUIV = os.environ.get("ATOM_DSV4_DIAG_EQUIV", "0") == "1" +_V4_DIAG_LAYER_SPEC = os.environ.get( + "ATOM_DSV4_DIAG_LAYERS", "0,1,2,3,31,63" +) +_V4_DIAG_TOKEN_LIMIT = int(os.environ.get("ATOM_DSV4_DIAG_TOKEN_LIMIT", "4")) +_V4_DIAG_VERBOSE = os.environ.get("ATOM_DSV4_DIAG_VERBOSE", "0") == "1" +_V4_DIAG_TOL = float(os.environ.get("ATOM_DSV4_DIAG_TOL", "1e-3")) + + +def _v4_diag_layer_enabled(layer_id: int) -> bool: + if not _V4_DIAG_EQUIV: + return False + spec = _V4_DIAG_LAYER_SPEC.strip().lower() + if spec in {"all", "*"}: + return True + try: + return layer_id in {int(x) for x in spec.split(",") if x.strip()} + except ValueError: + return False + + +def _v4_diag_get_equal_batch(input_ids: Optional[torch.Tensor]): + if not _V4_DIAG_EQUIV or input_ids is None: + return None + try: + ctx = get_forward_context() + attn_md = ctx.attn_metadata if ctx is not None else None + cu = getattr(attn_md, "cu_seqlens_q_cpu", None) + if cu is None or attn_md is None or attn_md.block_tables is None: + return None + bs = int(attn_md.block_tables.size(0)) + if bs < 2 or len(cu) < bs + 1: + return None + lens = [int(cu[i + 1] - cu[i]) for i in range(bs)] + if not lens or len(set(lens)) != 1 or lens[0] <= 0: + return None + seqlen = lens[0] + total = bs * seqlen + if input_ids.numel() < total: + return None + ids = input_ids[:total].reshape(bs, seqlen) + if not bool((ids == ids[0:1]).all().item()): + return None + return bs, seqlen, total + except Exception as exc: + print(f"[DSv4 diag] equiv setup skipped: {exc!r}", flush=True) + return None + + +def _v4_diag_selected_tokens(seqlen: int) -> list[int]: + if seqlen <= 0: + return [] + base = [0, seqlen // 2, seqlen - 1] + if _V4_DIAG_TOKEN_LIMIT > 3: + base.extend(range(min(seqlen, _V4_DIAG_TOKEN_LIMIT - 3))) + return sorted(set(i for i in base if 0 <= i < seqlen)) + + +def _v4_diag_check_equiv( + label: str, + tensor: torch.Tensor, + input_ids: Optional[torch.Tensor], +) -> None: + batch = _v4_diag_get_equal_batch(input_ids) + if batch is None: + return + bs, seqlen, total = batch + if tensor.size(0) < total: + return + try: + toks = _v4_diag_selected_tokens(seqlen) + view = tensor[:total].detach().reshape(bs, seqlen, -1).index_select( + 1, torch.tensor(toks, dtype=torch.long, device=tensor.device) + ) + ref = view[0:1].float() + diff = (view.float() - ref).abs() + max_abs = float(diff.max().item()) + mean_abs = float(diff.mean().item()) + bad_rows = int((diff.reshape(bs, -1).amax(dim=1) > _V4_DIAG_TOL).sum().item()) + if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL: + print( + "[DSv4 diag] " + f"{label}: bs={bs} seqlen={seqlen} toks={toks} " + f"max_abs={max_abs:.6g} mean_abs={mean_abs:.6g} " + f"bad_rows={bad_rows}/{bs}", + flush=True, 
diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py
index 14276c3c1..3d69736fd 100644
--- a/atom/model_ops/sampler.py
+++ b/atom/model_ops/sampler.py
@@ -75,6 +75,12 @@ def forward(
         Returns:
             Sampled token IDs (num_tokens,)
         """
+        # Temperature=0 is a hard greedy request. Handle it before deciding
+        # whether top-k/top-p filtering is needed; otherwise the no-filter
+        # path still runs the temperature sampler with an epsilon temperature.
+        if all_greedy:
+            return logits.argmax(dim=-1).to(torch.int)
+
         # No Top-K Top-P parameters, perform temperature-based sampling
         if not self._needs_filtering(top_ks, top_ps):
             return self._temperature_sample(
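A sketch of the ordering issue the early return sidesteps; the epsilon temperature and the multinomial fallback below are assumptions for illustration, not the sampler's actual code path. Argmax is exact, while sampling at a tiny-but-nonzero temperature is only approximately greedy:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([[2.00, 1.99, -1.0]])

# The new short-circuit: temperature=0 always returns the top token.
greedy = logits.argmax(dim=-1).to(torch.int)  # tensor([0], dtype=torch.int32)

# An epsilon-temperature sampler keeps nonzero probability on runner-ups:
# here token 1 still wins roughly a quarter of the time.
probs = torch.softmax(logits / 1e-2, dim=-1)
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
print(greedy.item(), sampled.item())
```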
diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py
index dd99c9fd0..5498545b5 100644
--- a/atom/models/deepseek_v4.py
+++ b/atom/models/deepseek_v4.py
@@ -109,6 +109,118 @@
 # forward burns syscalls (V4-Pro: 64 layers × multiple sites per call).
 _V4_FORCE_UE8M0_QUANT = os.environ.get("V4_FORCE_UE8M0_QUANT", "0") == "1"
 _V4_USE_REF_QUANT = os.environ.get("V4_USE_REF_QUANT", "0") == "1"
+_V4_DIAG_EQUIV = os.environ.get("ATOM_DSV4_DIAG_EQUIV", "0") == "1"
+_V4_DIAG_LAYER_SPEC = os.environ.get(
+    "ATOM_DSV4_DIAG_LAYERS", "0,1,2,3,31,63"
+)
+_V4_DIAG_TOKEN_LIMIT = int(os.environ.get("ATOM_DSV4_DIAG_TOKEN_LIMIT", "4"))
+_V4_DIAG_VERBOSE = os.environ.get("ATOM_DSV4_DIAG_VERBOSE", "0") == "1"
+_V4_DIAG_TOL = float(os.environ.get("ATOM_DSV4_DIAG_TOL", "1e-3"))
+
+
+def _v4_diag_layer_enabled(layer_id: int) -> bool:
+    if not _V4_DIAG_EQUIV:
+        return False
+    spec = _V4_DIAG_LAYER_SPEC.strip().lower()
+    if spec in {"all", "*"}:
+        return True
+    try:
+        return layer_id in {int(x) for x in spec.split(",") if x.strip()}
+    except ValueError:
+        return False
+
+
+def _v4_diag_get_equal_batch(input_ids: Optional[torch.Tensor]):
+    if not _V4_DIAG_EQUIV or input_ids is None:
+        return None
+    try:
+        ctx = get_forward_context()
+        attn_md = ctx.attn_metadata if ctx is not None else None
+        cu = getattr(attn_md, "cu_seqlens_q_cpu", None)
+        if cu is None or attn_md is None or attn_md.block_tables is None:
+            return None
+        bs = int(attn_md.block_tables.size(0))
+        if bs < 2 or len(cu) < bs + 1:
+            return None
+        lens = [int(cu[i + 1] - cu[i]) for i in range(bs)]
+        if not lens or len(set(lens)) != 1 or lens[0] <= 0:
+            return None
+        seqlen = lens[0]
+        total = bs * seqlen
+        if input_ids.numel() < total:
+            return None
+        ids = input_ids[:total].reshape(bs, seqlen)
+        if not bool((ids == ids[0:1]).all().item()):
+            return None
+        return bs, seqlen, total
+    except Exception as exc:
+        print(f"[DSv4 diag] equiv setup skipped: {exc!r}", flush=True)
+        return None
+
+
+def _v4_diag_selected_tokens(seqlen: int) -> list[int]:
+    if seqlen <= 0:
+        return []
+    base = [0, seqlen // 2, seqlen - 1]
+    if _V4_DIAG_TOKEN_LIMIT > 3:
+        base.extend(range(min(seqlen, _V4_DIAG_TOKEN_LIMIT - 3)))
+    return sorted(set(i for i in base if 0 <= i < seqlen))
+
+
+def _v4_diag_check_equiv(
+    label: str,
+    tensor: torch.Tensor,
+    input_ids: Optional[torch.Tensor],
+) -> None:
+    batch = _v4_diag_get_equal_batch(input_ids)
+    if batch is None:
+        return
+    bs, seqlen, total = batch
+    if tensor.size(0) < total:
+        return
+    try:
+        toks = _v4_diag_selected_tokens(seqlen)
+        view = tensor[:total].detach().reshape(bs, seqlen, -1).index_select(
+            1, torch.tensor(toks, dtype=torch.long, device=tensor.device)
+        )
+        ref = view[0:1].float()
+        diff = (view.float() - ref).abs()
+        max_abs = float(diff.max().item())
+        mean_abs = float(diff.mean().item())
+        bad_rows = int((diff.reshape(bs, -1).amax(dim=1) > _V4_DIAG_TOL).sum().item())
+        if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL:
+            print(
+                "[DSv4 diag] "
+                f"{label}: bs={bs} seqlen={seqlen} toks={toks} "
+                f"max_abs={max_abs:.6g} mean_abs={mean_abs:.6g} "
+                f"bad_rows={bad_rows}/{bs}",
+                flush=True,
+            )
+    except Exception as exc:
+        print(f"[DSv4 diag] {label}: check failed: {exc!r}", flush=True)
+
+
+def _v4_diag_check_logits(label: str, logits: torch.Tensor) -> None:
+    if not _V4_DIAG_EQUIV or logits.dim() != 2 or logits.size(0) < 2:
+        return
+    try:
+        ref = logits[0:1].float()
+        diff = (logits.float() - ref).abs()
+        max_abs = float(diff.max().item())
+        mean_abs = float(diff.mean().item())
+        bad_rows = int((diff.amax(dim=1) > _V4_DIAG_TOL).sum().item())
+        argmax = logits.argmax(dim=-1).detach().cpu().tolist()
+        unique_argmax = len(set(int(x) for x in argmax))
+        if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL or unique_argmax > 1:
+            print(
+                "[DSv4 diag] "
+                f"{label}: bs={logits.size(0)} max_abs={max_abs:.6g} "
+                f"mean_abs={mean_abs:.6g} bad_rows={bad_rows}/{logits.size(0)} "
+                f"unique_argmax={unique_argmax} argmax_head={argmax[:8]}",
+                flush=True,
+            )
+    except Exception as exc:
+        print(f"[DSv4 diag] {label}: logits check failed: {exc!r}", flush=True)
 
 
 def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor:
@@ -1402,7 +1514,6 @@ def __init__(self, layer_id: int, args: DeepseekV4Args, prefix: str = ""):
             prefix=f"{p}.wqkv_a",
         )
         self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
-        self.q_norm2 = RMSNorm(self.head_dim, self.eps)
         self.wq_b = ColumnParallelLinear(
             self.q_lora_rank,
             self.n_heads * self.head_dim,
@@ -2264,6 +2375,10 @@ def forward(
             torch.Tensor
         ],  # [num_tokens] int for hash-routed MoE layers
     ) -> torch.Tensor:  # [num_tokens, hc, dim] updated residual stream
+        diag = _v4_diag_layer_enabled(self.layer_id)
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.input_mhc", x, input_ids)
+
         # ----- Attention sub-layer with mHC mixing -----
         residual = x  # [num_tokens, hc, dim]
         x, post, comb = (
@@ -2271,17 +2386,30 @@ def forward(
                 x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
             )
         )
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.attn_hc_pre", x, input_ids)
         x = self.attn_norm(x)  # [num_tokens, dim]
         x = self.attn(x, positions)  # [num_tokens, dim]
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.attn_out", x, input_ids)
         x = self.hc_post(x, residual, post, comb)  # [num_tokens, hc, dim]
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.attn_hc_post", x, input_ids)
+
+        # ----- FFN sub-layer with mHC mixing -----
         residual = x  # [num_tokens, hc, dim]
         x, post, comb = self.hc_pre(
             x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
         )
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.ffn_hc_pre", x, input_ids)
        x = self.ffn_norm(x)  # [num_tokens, dim]
        x = self.ffn(x, input_ids)  # [num_tokens, dim]
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.ffn_out", x, input_ids)
        x = self.hc_post(x, residual, post, comb)  # [num_tokens, hc, dim]
+        if diag:
+            _v4_diag_check_equiv(f"L{self.layer_id}.ffn_hc_post", x, input_ids)
        return x
 
 
@@ -2521,8 +2649,10 @@ def forward(
         """
         assert input_ids.dim() == 1, f"input_ids must be 1D, got {input_ids.shape}"
         h = self.embed(input_ids)  # [num_tokens, dim]
+        _v4_diag_check_equiv("model.embed", h, input_ids)
         # Expand to hc_mult copies for Hyper-Connections: [num_tokens, hc, dim]
         h = h.unsqueeze(-2).repeat(1, self.hc_mult, 1)
+        _v4_diag_check_equiv("model.embed_mhc", h, input_ids)
         if positions is None:
             positions = torch.arange(
                 input_ids.numel(), device=input_ids.device, dtype=torch.long
             )
@@ -2536,7 +2666,10 @@
         x_hc = self.head.hc_head(  # [num_tokens, dim]
             h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base
         )
-        return self.norm(x_hc)
+        _v4_diag_check_equiv("model.hc_head", x_hc, input_ids)
+        out = self.norm(x_hc)
+        _v4_diag_check_equiv("model.final_norm", out, input_ids)
+        return out
 
 
 class DeepseekV4ForCausalLM(nn.Module):
@@ -2625,7 +2758,9 @@ def compute_logits(
         # Vocab projection is split off from `model.forward` so the latter
         # returns hidden_size-shaped tensors — required by ATOM's CUDAGraph
         # capture contract (outputs buffer is sized to hidden_size, not vocab).
-        return self.model.head.get_logits(hidden_states)
+        logits = self.model.head.get_logits(hidden_states)
+        _v4_diag_check_logits("model.logits", logits)
+        return logits
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """Return (param_name, weight_name, expert_id, shard_id) tuples for FusedMoE.
diff --git a/recipes/atom_sglang/Qwen3_5.md b/recipes/atom_sglang/Qwen3_5.md
index 6baf61781..1d96958ae 100644
--- a/recipes/atom_sglang/Qwen3_5.md
+++ b/recipes/atom_sglang/Qwen3_5.md
@@ -55,20 +55,19 @@
 RESULT_FILENAME=${model}-tp${tp}-${ISL}-${OSL}-${CONC}-${RANDOM_RANGE_RATIO}.json
 python3 -m sglang.bench_serving --backend sglang-oai-chat \
     --model ${model_path} \
     --base-url=http://127.0.0.1:30000 \
-    --max-concurrency 16 \
-    --num-prompts "$(( CONC * 5 ))" \
+    --max-concurrency 16 \
+    --num-prompts "$(( CONC * 5 ))" \
     --request-rate inf \
     --dataset-name random \
     --random-input-len ${ISL} \
     --random-output-len ${OSL} \
     --random-range-ratio ${RANDOM_RANGE_RATIO} \
     --warmup-requests $(( CONC * 2 )) \
-    --disable-ignore-eos \
+    --disable-ignore-eos \
     --output-file ${RESULT_FILENAME} \
     --trust-remote-code
 ```
-
 ### Optional: Enable Profiling
 
 If you want to collect profiling trace, set the SGLang profiling environment variables before launching the server, and add `--profile` to the benchmark CLI.
@@ -92,4 +91,3 @@ lm_eval --model local-completions \
     --num_fewshot 5 \
     --trust_remote_code
 ```
-
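To exercise the new DeepSeek-V4 equivalence diagnostics, a hypothetical launcher sketch; the environment variable names and accepted values come from the `deepseek_v4.py` hunks above, while the `python -m` serve invocation is an assumption about how the server is started:

```python
import os
import subprocess

env = dict(os.environ)
env.update(
    {
        "ATOM_DSV4_DIAG_EQUIV": "1",         # master switch for the checks
        "ATOM_DSV4_DIAG_LAYERS": "0,31,63",  # comma-separated ids, or "all"/"*"
        "ATOM_DSV4_DIAG_TOL": "1e-3",        # max-abs divergence before a report
        "ATOM_DSV4_DIAG_VERBOSE": "1",       # print per-probe stats even when rows agree
    }
)
# Assumed entrypoint; any launch of the OpenAI-compatible server works the same way.
subprocess.run(["python", "-m", "atom.entrypoints.openai.api_server"], env=env)
```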