
Batch generation refactoring and various fixes#1072

Merged
angeloskath merged 34 commits into main from step-by-step-gen
Apr 1, 2026
Conversation

@angeloskath
Member

This PR refactors the batch generator.

  • It is now simpler (or at least spread out to individual objects with more distinct responsibilities)
  • Allows for arbitrary segments in the prompt to facilitate checkpointing
  • Allows for stepping the prompt processing as well without the need for callbacks
  • All prompt processing is now done with right padding and decoding with left padding. As a result we can stop processing sequences as soon as they are done. Previously inserting a sequence of 100 tokens and 10,000 tokens would process ~20,000 tokens, now we will process ~12,000.
  • It introduces a StateMachine to handle transitions to thinking, tool calling, or stopping with arbitrary token sequences efficiently. This is currently only used in batch mode but should be used everywhere.
  • Enables system prompt checkpointing in the server.
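The padding claim in the fourth bullet can be checked with a back-of-the-envelope sketch, assuming the previous behavior padded all prompts to the longest in the batch (the quoted ~12,000 presumably includes chunking and decoding overhead not modeled here):

```python
def left_padded_cost(lengths):
    # With left-padded prompt processing, every prompt is padded to the
    # longest one, so the batch processes batch_size * max(lengths) tokens.
    return len(lengths) * max(lengths)

def right_padded_cost(lengths):
    # With right padding, a sequence can be dropped from the batch as soon
    # as its own prompt is consumed, so the cost approaches sum(lengths).
    return sum(lengths)

lengths = [100, 10_000]
print(left_padded_cost(lengths))    # 20000 -- matches the "~20,000" above
print(right_padded_cost(lengths))   # 10100 -- lower bound on the "~12,000"
```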

Important bug-fixes

  • Qwen 3.5 batch mode fixed
    • The conv state was grabbed incorrectly when in batch mode
    • The gated delta net kernel would leave uninitialized memory in the output which would result in NaNs in the next full attention
  • Deepseek DSA batch mode fixed (affects GLM5 and Deepseek v3.2)
    • The mask was ignored in batch mode

I will link this PR to open issues that are being fixed by this instead of the opposite as I think it will be simpler.

@Thump604

Code Review

Went through the full diff. The refactor is well-structured -- the separation into PromptProcessingBatch and GenerationBatch with distinct padding strategies is a clean design. Notes below.

Model fixes -- correct

The qwen3_5.py conv state fix is right. Per-sequence take_along_axis with clipped lengths correctly handles the right-padded batch case where different sequences have different real data lengths. The gated_delta.py zeroing of masked timesteps fixes a real NaN source from uninitialized memory in padded positions.

SequenceStateMachine -- nice upgrade

Aho-Corasick trie for multi-token stop sequence matching is a significant improvement over the flat stop_tokens set. Sharing the trie via __deepcopy__ (immutable structure) is efficient.
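The trie idea can be sketched as a classic Aho-Corasick automaton over token ids. This is a minimal illustration only: the PR's SequenceStateMachine, its value-carrying nodes, and the __deepcopy__ sharing are not reproduced here.

```python
from collections import deque

class StopMatcher:
    """Aho-Corasick matcher over token ids: feed tokens one at a time and
    learn in O(1) amortized time whether any stop sequence just completed.
    Illustrative sketch, not the PR's SequenceStateMachine."""

    def __init__(self, stop_sequences):
        self.goto = [{}]     # node -> {token: child node}
        self.fail = [0]      # failure links
        self.out = [False]   # does some stop sequence end at this node?
        for seq in stop_sequences:
            node = 0
            for tok in seq:
                if tok not in self.goto[node]:
                    self.goto.append({})
                    self.fail.append(0)
                    self.out.append(False)
                    self.goto[node][tok] = len(self.goto) - 1
                node = self.goto[node][tok]
            self.out[node] = True
        # Breadth-first pass to wire up failure links.
        q = deque(self.goto[0].values())  # depth-1 nodes fail to the root
        while q:
            u = q.popleft()
            for tok, v in self.goto[u].items():
                q.append(v)
                f = self.fail[u]
                while f and tok not in self.goto[f]:
                    f = self.fail[f]
                self.fail[v] = self.goto[f].get(tok, 0)
                self.out[v] = self.out[v] or self.out[self.fail[v]]
        self.state = 0

    def step(self, tok):
        """Advance one token; True iff a stop sequence ends here."""
        while self.state and tok not in self.goto[self.state]:
            self.state = self.fail[self.state]
        self.state = self.goto[self.state].get(tok, 0)
        return self.out[self.state]
```

The failure links are what make this better than a flat set: a partial match of one stop sequence can hand off into another without rescanning, so overlapping multi-token stops (e.g. `[5, 6, 7]` and `[6, 8]` fed `5 6 8`) are still caught.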

Potential bugs

1. PromptTrie.search() -- best can be None (medium)

In the longer detection branch, best is initialized to None and only set if a trie node has __value__. If nodes exist past the match point but none have values (e.g., after partial cleanup via pop_prefixes), best stays None:

longer = tokens[:index] + best  # TypeError: can only concatenate list (not "NoneType") to list

Suggest guarding: if best is not None: longer = tokens[:index] + best
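A minimal reduction of that failure mode with the suggested guard in place (the helper name and call shape are hypothetical, not the PR's code):

```python
def build_longer(tokens, index, best):
    # `best` may legitimately be None after partial trie cleanup
    # (e.g. via pop_prefixes); guard before concatenating.
    if best is None:
        return None
    return tokens[:index] + best

assert build_longer([1, 2, 3], 2, [9, 9]) == [1, 2, 9, 9]
assert build_longer([1, 2, 3], 2, None) is None  # unguarded: TypeError
```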

2. CacheOrder.pop() -- no empty check (low)

def pop(self):
    i = 0
    while i + 1 < len(self._ordering):
        ...
    return lru_b.popleft()  # IndexError if all tiers empty

Called from trim_to which checks self._n_bytes > n_bytes, so it should never fire with empty tiers. But if _n_bytes drifts out of sync with actual cache contents, this crashes without a useful error.
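One defensive option, sketched with hypothetical names modeled on the excerpt above (the LRU tier-walking logic is elided; only the failure path is shown):

```python
from collections import deque

class CacheOrderSketch:
    """Sketch of the empty-tier guard suggested above."""

    def __init__(self, tiers):
        self._ordering = [deque(t) for t in tiers]

    def pop(self):
        for tier in self._ordering:
            if tier:
                return tier.popleft()
        # Fail loudly: an opaque IndexError here would hide the real
        # problem (byte accounting drifting out of sync with contents).
        raise RuntimeError("CacheOrder.pop() called with all tiers empty")
```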

3. Server tokenizer variable scoping (low)

In _generate(), the batch-extension branch (adding to existing batch) uses tokenizer in _make_state_machine() but the correct local is current_tokenizer. They're the same object in practice since both come from the same model load, but it's fragile -- if model hot-swapping is ever added, this becomes a real bug.

4. ArraysCache.extend() doesn't propagate lengths (low)

The cat() helper handles None entries, but self.lengths is not concatenated with other.lengths. If both caches have lengths set during batched prompt processing, the extended cache loses other's length information.
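The suggested fix can be sketched as follows (class and helper names are hypothetical reductions; NumPy stands in for mx arrays, and only the lengths handling is modeled):

```python
import numpy as np

def cat(a, b):
    # None-safe concatenation, mirroring the cat() helper described above.
    if a is None and b is None:
        return None
    if a is None or b is None:
        raise ValueError("cannot extend: only one cache has lengths set")
    return np.concatenate([a, b])

class ArraysCacheSketch:
    """Hypothetical reduction of ArraysCache.extend()."""

    def __init__(self, lengths=None):
        self.lengths = lengths

    def extend(self, other):
        # Proposed fix: propagate lengths from both caches instead of
        # silently dropping other.lengths.
        self.lengths = cat(self.lengths, other.lengths)
```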

Design question -- thinking checkpoint detection

The _tokenize method scans backwards for think_start_id / think_end_id in the last 11 tokens. If a thinking block is longer than 11 tokens from the end, it won't be detected as a separate segment. Is this intentional? Thinking blocks in multi-turn conversations can be hundreds of tokens. The 11-token window seems like it's targeting "just started thinking" but would miss "mid-think with long reasoning."
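If the window is not intentional, the fix is cheap: scanning the whole suffix is still linear in prompt length. A hypothetical helper (not the PR's code) showing both behaviors:

```python
def find_last_think_marker(tokens, think_start_id, think_end_id, window=None):
    """Scan backwards for the most recent thinking marker. With
    window=None the whole prompt is scanned, avoiding the fixed
    11-token cutoff discussed above."""
    start = 0 if window is None else max(0, len(tokens) - window)
    for i in range(len(tokens) - 1, start - 1, -1):
        if tokens[i] in (think_start_id, think_end_id):
            return i, tokens[i]
    return None
```

With a long thinking block, `window=11` misses a marker that an unbounded scan finds, which is exactly the "mid-think with long reasoning" case.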

BatchRotatingKVCache.merge() fix

The switch from c._idx to lengths (via c.size()) for sizing the destination slice is a real bug fix. _idx doesn't account for rotation past max_size, so the old code could write past allocated bounds. Good catch.

Overall this is solid. The right-pad prompt / left-pad decode split is a meaningful compute savings for mixed-length batches, and the StateMachine is the right abstraction for token sequence detection.

@angeloskath angeloskath merged commit 3f9d179 into main Apr 1, 2026
2 checks passed
@angeloskath angeloskath deleted the step-by-step-gen branch April 1, 2026 22:07
@Thump604

Thump604 commented Apr 1, 2026

Reviewed the latest commits (efb30e2..ee44cd2). The Qwen3.5 and GDN batching fixes look correct:

  • qwen3_5.py: per-sequence conv cache extraction via take_along_axis with cache.lengths handles variable-length batches properly. The mx.contiguous fallback matches the #1077 fix (break shared-buffer memory leak in GatedDeltaNet cache).
  • gated_delta.py: the else { y[dv_idx] = 0 } branch for inactive mask positions prevents undefined output in batched GDN, which was a correctness bug.
  • cache.py: None-safe filter/extend/merge across ArraysCache, BatchKVCache, and RotatingKVCache is thorough. Empty cache early-returns prevent unnecessary work.
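The per-sequence extraction pattern from the first bullet can be sketched with NumPy (the actual code uses mx.take_along_axis on the conv cache; the shapes here are illustrative, not the model's real dimensions):

```python
import numpy as np

# Right-padded batch of per-timestep conv states: (batch, time, features).
B, T, F = 2, 5, 3
states = np.arange(B * T * F).reshape(B, T, F)

# Real (unpadded) lengths per sequence; the padded tail holds garbage.
lengths = np.array([2, 5])

# Index of each sequence's last real timestep, clipped into range.
idx = np.clip(lengths - 1, 0, T - 1).reshape(B, 1, 1)

# Gather the last valid state per sequence along the time axis.
last = np.take_along_axis(states, np.broadcast_to(idx, (B, 1, F)), axis=1)
print(last.squeeze(1))  # one (F,) state per sequence, ignoring padding
```

Without the per-sequence indexing, a right-padded batch would read the conv state from the padded tail of the shorter sequences, which is the bug the PR fixes.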

I run Qwen3.5-122B-A10B (12 attention + 36 GDN layers) in continuous batching on M2 Ultra. These fixes address the exact hybrid cache issues I've been patching around in vllm-mlx. Looking forward to rebasing on this when it lands.
