25 changes: 24 additions & 1 deletion atom/entrypoints/openai/api_server.py
@@ -16,6 +16,7 @@
import asyncio
import json
import logging
import re
import time
import uuid
from asyncio import AbstractEventLoop
@@ -56,6 +57,8 @@
# Constants
DEFAULT_HOST = "0.0.0.0"
DEFAULT_PORT = 8000
STRUCTURED_ANSWER_PROMPT_MARKER = "formatted as: ####"
STRUCTURED_ANSWER_RE = re.compile(r"(?m)^####[^\r\n]*(?:\r?\n)?")


# ============================================================================
@@ -157,6 +160,23 @@ def _coerce_n(requested_n: Optional[int], temperature: Optional[float]) -> int:
return n


def _trim_structured_answer_output(prompt: str, text: str) -> str:
"""Trim eval-style completions after their first final-answer line.

GSM8K-style prompts ask the model to end with a ``####`` answer line.
DeepSeek-V4 can produce the correct line and then keep repeating the solution
until ``max_tokens`` is exhausted. Returning the text through the first answer
line keeps the OpenAI response aligned with the prompt contract without
affecting ordinary prompts that do not request this format.
"""
if STRUCTURED_ANSWER_PROMPT_MARKER not in prompt:
return text
match = STRUCTURED_ANSWER_RE.search(text)
if match is None:
return text
return text[: match.end()].rstrip()


def _send_stream_chunk_direct(
request_output: RequestOutput,
request_id: str,
@@ -267,6 +287,7 @@ def do_preprocess():
break

text = tokenizer.decode(all_token_ids, skip_special_tokens=True)
text = _trim_structured_answer_output(prompt, text)
num_tokens_input = (
seq.num_prompt_tokens if seq is not None else len(tokenizer.encode(prompt))
)
@@ -387,9 +408,11 @@ def do_preprocess():
and num_tokens_output > 1
else 0.0
)
text = tokenizer.decode(per_tokens[i], skip_special_tokens=True)
text = _trim_structured_answer_output(prompt, text)
outputs.append(
{
"text": tokenizer.decode(per_tokens[i], skip_special_tokens=True),
"text": text,
"token_ids": per_tokens[i],
"finish_reason": per_finish_reason[i],
"num_tokens_input": num_tokens_input,
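To see what the new trimming helper does in practice, here is a minimal self-contained sketch: the marker, regex, and trimming logic mirror `_trim_structured_answer_output` from the diff above, while the function name without the underscore and the sample prompt and completion are invented for illustration.

```python
import re

STRUCTURED_ANSWER_PROMPT_MARKER = "formatted as: ####"
STRUCTURED_ANSWER_RE = re.compile(r"(?m)^####[^\r\n]*(?:\r?\n)?")


def trim_structured_answer_output(prompt: str, text: str) -> str:
    # Only prompts that ask for a "#### <answer>" line are affected.
    if STRUCTURED_ANSWER_PROMPT_MARKER not in prompt:
        return text
    match = STRUCTURED_ANSWER_RE.search(text)
    if match is None:
        return text
    # Keep everything through the first final-answer line; drop any repeats.
    return text[: match.end()].rstrip()


prompt = "Solve the problem. End with the result formatted as: #### <answer>"
completion = "Tom has 3 apples and buys 5 more.\n3 + 5 = 8\n#### 8\nTom has 3 apples and buys 5 more..."
print(trim_structured_answer_output(prompt, completion))
# Tom has 3 apples and buys 5 more.
# 3 + 5 = 8
# #### 8
```

A prompt without the marker, or a completion with no `####` line, is returned unchanged, so ordinary chat traffic is unaffected.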
11 changes: 10 additions & 1 deletion atom/model_ops/attentions/deepseek_v4_attn.py
@@ -1551,6 +1551,13 @@ def _attach_v4_per_fwd_meta(
cu_seqlens_q_np[: scheduled_bs + 1], dtype=np.int64
)
token_num_per_seq = cu_seqlens_q_arr[1:] - cu_seqlens_q_arr[:scheduled_bs]
repeat_output_size = int(token_num_per_seq.sum())
if repeat_output_size != total_tokens:
raise ValueError(
"DeepSeek-V4 metadata token count mismatch: "
f"sum(cu_seqlens_q diff)={repeat_output_size}, "
f"total_tokens={total_tokens}"
)
start_pos_per_seq_np = np.asarray(
start_pos_per_seq_cpu[:scheduled_bs], dtype=np.int64
)
@@ -1582,7 +1589,9 @@
start_pos_per_token = start_pos_per_seq_gpu
else:
start_pos_per_token = torch.repeat_interleave(
start_pos_per_seq_gpu, token_num_per_seq_gpu
start_pos_per_seq_gpu,
token_num_per_seq_gpu,
output_size=repeat_output_size,
)
attn_metadata.window_topk_batched = _build_window_topk_batched(
positions[:total_tokens].to(torch.long),
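As a rough illustration of what the new consistency check and the `output_size` hint in `_attach_v4_per_fwd_meta` protect, here is a toy sketch with made-up sequence lengths. `torch.repeat_interleave` accepts an `output_size` keyword, so the output length does not have to be recomputed from the repeats tensor (which, for GPU repeats, forces a device sync).

```python
import numpy as np
import torch

# Toy metadata: 3 scheduled sequences contributing 4, 1, and 2 tokens.
cu_seqlens_q = np.asarray([0, 4, 5, 7], dtype=np.int64)
token_num_per_seq = cu_seqlens_q[1:] - cu_seqlens_q[:-1]   # [4, 1, 2]
total_tokens = 7

repeat_output_size = int(token_num_per_seq.sum())
# The diff raises ValueError here if this ever disagrees with total_tokens.
assert repeat_output_size == total_tokens

start_pos_per_seq = torch.tensor([100, 8, 37])
start_pos_per_token = torch.repeat_interleave(
    start_pos_per_seq,
    torch.from_numpy(token_num_per_seq),
    output_size=repeat_output_size,
)
print(start_pos_per_token)  # tensor([100, 100, 100, 100,   8,  37,  37])
```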
6 changes: 6 additions & 0 deletions atom/model_ops/sampler.py
@@ -75,6 +75,12 @@ def forward(
Returns:
Sampled token IDs (num_tokens,)
"""
# Temperature=0 is a hard greedy request. Handle it before deciding
# whether top-k/top-p filtering is needed; otherwise the no-filter
# path still runs the temperature sampler with an epsilon temperature.
if all_greedy:
return logits.argmax(dim=-1).to(torch.int)

# No Top-K Top-P parameters, perform temperature-based sampling
if not self._needs_filtering(top_ks, top_ps):
return self._temperature_sample(
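A tiny sketch, with made-up logits, of the ordering the new comment describes: when every request in the batch is greedy (temperature 0), the sampler now returns a plain argmax before the top-k/top-p decision is reached, instead of falling into the no-filter path that samples with an epsilon temperature.

```python
import torch

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 3.0,  0.2]])

# The all_greedy short-circuit added in the diff reduces to this:
greedy_ids = logits.argmax(dim=-1).to(torch.int)
print(greedy_ids)  # tensor([0, 1], dtype=torch.int32)
```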
141 changes: 138 additions & 3 deletions atom/models/deepseek_v4.py
@@ -109,6 +109,118 @@
# forward burns syscalls (V4-Pro: 64 layers × multiple sites per call).
_V4_FORCE_UE8M0_QUANT = os.environ.get("V4_FORCE_UE8M0_QUANT", "0") == "1"
_V4_USE_REF_QUANT = os.environ.get("V4_USE_REF_QUANT", "0") == "1"
_V4_DIAG_EQUIV = os.environ.get("ATOM_DSV4_DIAG_EQUIV", "0") == "1"
_V4_DIAG_LAYER_SPEC = os.environ.get(
"ATOM_DSV4_DIAG_LAYERS", "0,1,2,3,31,63"
)
_V4_DIAG_TOKEN_LIMIT = int(os.environ.get("ATOM_DSV4_DIAG_TOKEN_LIMIT", "4"))
_V4_DIAG_VERBOSE = os.environ.get("ATOM_DSV4_DIAG_VERBOSE", "0") == "1"
_V4_DIAG_TOL = float(os.environ.get("ATOM_DSV4_DIAG_TOL", "1e-3"))


def _v4_diag_layer_enabled(layer_id: int) -> bool:
if not _V4_DIAG_EQUIV:
return False
spec = _V4_DIAG_LAYER_SPEC.strip().lower()
if spec in {"all", "*"}:
return True
try:
return layer_id in {int(x) for x in spec.split(",") if x.strip()}
except ValueError:
return False


def _v4_diag_get_equal_batch(input_ids: Optional[torch.Tensor]):
if not _V4_DIAG_EQUIV or input_ids is None:
return None
try:
ctx = get_forward_context()
attn_md = ctx.attn_metadata if ctx is not None else None
cu = getattr(attn_md, "cu_seqlens_q_cpu", None)
if cu is None or attn_md is None or attn_md.block_tables is None:
return None
bs = int(attn_md.block_tables.size(0))
if bs < 2 or len(cu) < bs + 1:
return None
lens = [int(cu[i + 1] - cu[i]) for i in range(bs)]
if not lens or len(set(lens)) != 1 or lens[0] <= 0:
return None
seqlen = lens[0]
total = bs * seqlen
if input_ids.numel() < total:
return None
ids = input_ids[:total].reshape(bs, seqlen)
if not bool((ids == ids[0:1]).all().item()):
return None
return bs, seqlen, total
except Exception as exc:
print(f"[DSv4 diag] equiv setup skipped: {exc!r}", flush=True)
return None


def _v4_diag_selected_tokens(seqlen: int) -> list[int]:
if seqlen <= 0:
return []
base = [0, seqlen // 2, seqlen - 1]
if _V4_DIAG_TOKEN_LIMIT > 3:
base.extend(range(min(seqlen, _V4_DIAG_TOKEN_LIMIT - 3)))
return sorted(set(i for i in base if 0 <= i < seqlen))


def _v4_diag_check_equiv(
label: str,
tensor: torch.Tensor,
input_ids: Optional[torch.Tensor],
) -> None:
batch = _v4_diag_get_equal_batch(input_ids)
if batch is None:
return
bs, seqlen, total = batch
if tensor.size(0) < total:
return
try:
toks = _v4_diag_selected_tokens(seqlen)
view = tensor[:total].detach().reshape(bs, seqlen, -1).index_select(
1, torch.tensor(toks, dtype=torch.long, device=tensor.device)
)
ref = view[0:1].float()
diff = (view.float() - ref).abs()
max_abs = float(diff.max().item())
mean_abs = float(diff.mean().item())
bad_rows = int((diff.reshape(bs, -1).amax(dim=1) > _V4_DIAG_TOL).sum().item())
if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL:
print(
"[DSv4 diag] "
f"{label}: bs={bs} seqlen={seqlen} toks={toks} "
f"max_abs={max_abs:.6g} mean_abs={mean_abs:.6g} "
f"bad_rows={bad_rows}/{bs}",
flush=True,
)
except Exception as exc:
print(f"[DSv4 diag] {label}: check failed: {exc!r}", flush=True)


def _v4_diag_check_logits(label: str, logits: torch.Tensor) -> None:
if not _V4_DIAG_EQUIV or logits.dim() != 2 or logits.size(0) < 2:
return
try:
ref = logits[0:1].float()
diff = (logits.float() - ref).abs()
max_abs = float(diff.max().item())
mean_abs = float(diff.mean().item())
bad_rows = int((diff.amax(dim=1) > _V4_DIAG_TOL).sum().item())
argmax = logits.argmax(dim=-1).detach().cpu().tolist()
unique_argmax = len(set(int(x) for x in argmax))
if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL or unique_argmax > 1:
print(
"[DSv4 diag] "
f"{label}: bs={logits.size(0)} max_abs={max_abs:.6g} "
f"mean_abs={mean_abs:.6g} bad_rows={bad_rows}/{logits.size(0)} "
f"unique_argmax={unique_argmax} argmax_head={argmax[:8]}",
flush=True,
)
except Exception as exc:
print(f"[DSv4 diag] {label}: logits check failed: {exc!r}", flush=True)


def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor:
@@ -1402,7 +1514,6 @@ def __init__(self, layer_id: int, args: DeepseekV4Args, prefix: str = ""):
prefix=f"{p}.wqkv_a",
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.q_norm2 = RMSNorm(self.head_dim, self.eps)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
@@ -2264,24 +2375,41 @@ def forward(
torch.Tensor
], # [num_tokens] int for hash-routed MoE layers
) -> torch.Tensor: # [num_tokens, hc, dim] updated residual stream
diag = _v4_diag_layer_enabled(self.layer_id)
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.input_mhc", x, input_ids)

# ----- Attention sub-layer with mHC mixing -----
residual = x # [num_tokens, hc, dim]
x, post, comb = (
self.hc_pre( # [num_tokens, dim], [num_tokens, hc], [num_tokens, hc, hc]
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
)
)
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.attn_hc_pre", x, input_ids)
x = self.attn_norm(x) # [num_tokens, dim]
x = self.attn(x, positions) # [num_tokens, dim]
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.attn_out", x, input_ids)
x = self.hc_post(x, residual, post, comb) # [num_tokens, hc, dim]
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.attn_hc_post", x, input_ids)

# ----- FFN sub-layer with mHC mixing -----
residual = x # [num_tokens, hc, dim]
x, post, comb = self.hc_pre(
x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
)
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.ffn_hc_pre", x, input_ids)
x = self.ffn_norm(x) # [num_tokens, dim]
x = self.ffn(x, input_ids) # [num_tokens, dim]
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.ffn_out", x, input_ids)
x = self.hc_post(x, residual, post, comb) # [num_tokens, hc, dim]
if diag:
_v4_diag_check_equiv(f"L{self.layer_id}.ffn_hc_post", x, input_ids)
return x


@@ -2521,8 +2649,10 @@ def forward(
"""
assert input_ids.dim() == 1, f"input_ids must be 1D, got {input_ids.shape}"
h = self.embed(input_ids) # [num_tokens, dim]
_v4_diag_check_equiv("model.embed", h, input_ids)
# Expand to hc_mult copies for Hyper-Connections: [num_tokens, hc, dim]
h = h.unsqueeze(-2).repeat(1, self.hc_mult, 1)
_v4_diag_check_equiv("model.embed_mhc", h, input_ids)
if positions is None:
positions = torch.arange(
input_ids.numel(), device=input_ids.device, dtype=torch.long
@@ -2536,7 +2666,10 @@ def forward(
x_hc = self.head.hc_head( # [num_tokens, dim]
h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base
)
return self.norm(x_hc)
_v4_diag_check_equiv("model.hc_head", x_hc, input_ids)
out = self.norm(x_hc)
_v4_diag_check_equiv("model.final_norm", out, input_ids)
return out


class DeepseekV4ForCausalLM(nn.Module):
@@ -2625,7 +2758,9 @@ def compute_logits(
# Vocab projection is split off from `model.forward` so the latter
# returns hidden_size-shaped tensors — required by ATOM's CUDAGraph
# capture contract (outputs buffer is sized to hidden_size, not vocab).
return self.model.head.get_logits(hidden_states)
logits = self.model.head.get_logits(hidden_states)
_v4_diag_check_logits("model.logits", logits)
return logits

def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
"""Return (param_name, weight_name, expert_id, shard_id) tuples for FusedMoE.
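The equivalence diagnostics are driven entirely by environment variables introduced in this diff. A hedged usage sketch follows: the variable names and the layer-spec parsing mirror the code above, while setting them from Python is only for illustration (in practice they would be exported before launching the server).

```python
import os

# Variable names come from the diff; the values here are examples.
os.environ["ATOM_DSV4_DIAG_EQUIV"] = "1"          # enable the equivalence checks
os.environ["ATOM_DSV4_DIAG_LAYERS"] = "0,31,63"   # or "all" / "*" for every layer
os.environ["ATOM_DSV4_DIAG_TOL"] = "1e-3"         # max-abs divergence tolerance per row
os.environ["ATOM_DSV4_DIAG_VERBOSE"] = "1"        # print every checkpoint, not only failures


def layer_enabled(layer_id: int, spec: str) -> bool:
    # Mirrors _v4_diag_layer_enabled: "all"/"*" selects every layer,
    # otherwise a comma-separated set of layer ids.
    spec = spec.strip().lower()
    if spec in {"all", "*"}:
        return True
    try:
        return layer_id in {int(x) for x in spec.split(",") if x.strip()}
    except ValueError:
        return False


print(layer_enabled(31, os.environ["ATOM_DSV4_DIAG_LAYERS"]))  # True
print(layer_enabled(7, os.environ["ATOM_DSV4_DIAG_LAYERS"]))   # False
```

With the checks enabled and a batch of identical prompts, each selected layer prints a `[DSv4 diag]` line whenever any row diverges from row 0 by more than the tolerance, and the logits check additionally reports how many distinct argmax tokens appear across the batch.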
8 changes: 3 additions & 5 deletions recipes/atom_sglang/Qwen3_5.md
@@ -55,20 +55,19 @@ RESULT_FILENAME=${model}-tp${tp}-${ISL}-${OSL}-${CONC}-${RANDOM_RANGE_RATIO}.jso
python3 -m sglang.bench_serving --backend sglang-oai-chat \
--model ${model_path} \
--base-url=http://127.0.0.1:30000 \
--max-concurrency 16 \
--num-prompts "$(( CONC * 5 ))" \
--max-concurrency 16 \
--num-prompts "$(( CONC * 5 ))" \
--request-rate inf \
--dataset-name random \
--random-input-len ${ISL} \
--random-output-len ${OSL} \
--random-range-ratio ${RANDOM_RANGE_RATIO} \
--warmup-requests $(( CONC * 2 )) \
--disable-ignore-eos \
--disable-ignore-eos \
--output-file ${RESULT_FILENAME} \
--trust-remote-code
```


### Optional: Enable Profiling

If you want to collect a profiling trace, set the SGLang profiling environment variables before launching the server, and add `--profile` to the benchmark CLI.
@@ -92,4 +91,3 @@ lm_eval --model local-completions \
--num_fewshot 5 \
--trust_remote_code
```