Walkthrough
Adds a KCP design doc and updates CP internals and tests: FP32 intermediate accumulation and explicit cross-rank all_gather in chunk_delta_h, plus distributed, gate-aware CP tests and a revised test harness and assertions.
Sequence Diagram(s)
sequenceDiagram
participant RankA as Rank A
participant RankB as Rank B
participant AllGather as All-Gather Sync
participant PreProc as Pre-process Kernel
participant MainKernel as Main Kernel
RankA->>PreProc: compute local m, hm (fp32)
RankB->>PreProc: compute local m, hm (fp32)
PreProc->>AllGather: all_gather_into_tensor(ag_hm, ag_dhm)
AllGather->>RankA: ag_hm, ag_dhm
AllGather->>RankB: ag_hm, ag_dhm
RankA->>MainKernel: main_kernel(ag_hm, fp32 state, gated inputs)
RankB->>MainKernel: main_kernel(ag_hm, fp32 state, gated inputs + inter-rank decay)
MainKernel->>RankA: outputs, grads
MainKernel->>RankB: outputs, grads
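The flow in the diagram can be sketched in plain Python. This is only an illustration under strong assumptions: a scalar recurrence `h_t = a_t * h_{t-1} + b_t` stands in for the matrix-valued state update, the all-gather is simulated with a list, and the names `local_summary`, `all_gather`, `initial_state`, and `scan` are hypothetical rather than taken from the PR.

```python
# Scalar stand-in for the per-rank state recurrence h_t = a_t * h_{t-1} + b_t.
# All function names here are hypothetical and exist only for illustration.

def local_summary(decays, updates):
    """Fold one rank's chunk into (m, hm) so that h_out = m * h_in + hm."""
    m, hm = 1.0, 0.0
    for a, b in zip(decays, updates):
        m, hm = a * m, a * hm + b
    return m, hm

def all_gather(per_rank_values):
    """Simulated all-gather: every rank receives the full list of summaries."""
    return [list(per_rank_values) for _ in per_rank_values]

def initial_state(rank, gathered):
    """Compose the summaries of all earlier ranks, starting from h = 0."""
    h = 0.0
    for m, hm in gathered[:rank]:
        h = m * h + hm
    return h

def scan(decays, updates, h=0.0):
    """Sequential reference: run the recurrence token by token."""
    for a, b in zip(decays, updates):
        h = a * h + b
    return h

# Two ranks (A and B), each holding half of an 8-token sequence.
decays = [0.5, 0.25, 0.5, 1.0, 0.5, 0.5, 0.25, 1.0]
updates = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 1.0, 4.0]
chunks = [(decays[:4], updates[:4]), (decays[4:], updates[4:])]

summaries = [local_summary(a, b) for a, b in chunks]  # pre-process step
gathered = all_gather(summaries)                      # cross-rank sync

final_states = []
for rank, (a, b) in enumerate(chunks):
    h0 = initial_state(rank, gathered[rank])          # state entering this rank
    final_states.append(scan(a, b, h=h0))             # main step on local tokens
```

The point of the design is that each rank only needs the compact `(m, hm)` pair from earlier ranks, not their tokens, to recover the state it should start from.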
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~55 minutes
Summary of Changes
Hello @zhiyuan1i, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces comprehensive documentation for the Kimi Context Parallel (KCP) feature, clarifying its architecture and the operational differences between GDN and KDA. A critical precision fix ensures the stability of the M matrix chain multiplication by enforcing
Code Review
This pull request introduces comprehensive documentation for Kimi Context Parallel (KCP), fixes a critical precision issue in the M matrix chain computation, and performs a major cleanup and enhancement of the context parallel tests. The new KCP.md file provides an excellent overview of the architecture. The precision fix in chunk_delta_h.py to enforce fp32 accumulation is crucial for numerical stability. The test refactoring in test_cp_conv.py, test_cp_gdn.py, and test_cp_kda.py significantly improves their robustness, coverage, and maintainability by using a centralized assert_close utility, more rigorous reference implementations, and better test configurations. Overall, this is a high-quality contribution that improves correctness, documentation, and testing. I have one minor suggestion for a typo in the documentation.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
fla/ops/cp/chunk_delta_h.py (1)
268-272: ⚠️ Potential issue | 🟠 Major
fp32 precision fix not applied to the non-merged stage2 kernel, which is used in benchmarks.
The merged kernels (`pre_process_fwd_kernel_merged` at line 508 and `pre_process_bwd_kernel_merged` at line 1010) correctly cast `b_m_i` and `b_m` to `tl.float32` for the M matrix chain multiply. However, the non-merged `pre_process_fwd_bwd_kernel_stage2` still uses `b_w.dtype`:
b_m = tl.dot(b_m_i.to(b_w.dtype), b_m.to(b_w.dtype))
While the production code paths use the merged kernels (lines 1050, 1120), the non-merged stage2 kernel is explicitly called from `benchmarks/cp/benchmark_chunk_delta_h_kernels.py` (lines 142, 174, 304) for performance and correctness benchmarking. This inconsistency means the benchmarks measure the unfixed kernel while production uses the fixed one, so their precision characteristics differ.
Proposed fix
- b_m = tl.dot(b_m_i.to(b_w.dtype), b_m.to(b_w.dtype))
+ b_m = tl.dot(b_m_i.to(tl.float32), b_m.to(tl.float32))
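To see why the cast matters, here is a minimal sketch of a chained product where the running result is rounded after every step. It assumes the low-precision dtype is bfloat16 and uses scalars in place of the M matrices; `to_bf16`, `to_fp32`, and `chain_product` are hypothetical helpers written for this illustration, not code from the kernels.

```python
import struct

def to_bf16(x: float) -> float:
    """Round a value to the nearest bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

def to_fp32(x: float) -> float:
    """Round a value to the nearest float32."""
    return struct.unpack(">f", struct.pack(">f", x))[0]

def chain_product(factors, cast):
    """Multiply factors left to right, rounding the accumulator each step."""
    acc = 1.0
    for f in factors:
        acc = cast(acc * f)
    return acc

factors = [255.0 / 256.0] * 64   # each factor is exactly representable in bfloat16
exact = (255.0 / 256.0) ** 64    # float64 reference value
low = chain_product(factors, to_bf16)    # accumulator kept in bfloat16
high = chain_product(factors, to_fp32)   # accumulator kept in float32
```

In this toy setup the bfloat16 chain lands on 0.75 against a reference of roughly 0.778, because every step loses a fraction of a bfloat16 ulp in the same direction, while the float32 chain stays within about 1e-5 of the reference.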
🤖 Fix all issues with AI agents
In `@fla/ops/cp/KCP.md`:
- Line 5: Fix the typo and sentence punctuation in the CP description: change
"introduce" to "introduced" in the sentence "CP was first introduce in PR
`#691`...", split the clause after the PR link into a new sentence by replacing
the comma before "Special thanks to" with a period, and capitalize "Special" so
it reads "Special thanks to [mdy666]...".
🧹 Nitpick comments (7)
tests/context_parallel/test_cp_conv.py (2)
113-113: Scaling input by 100: intentional but worth a comment.
`x_global` is multiplied by 100, which increases the dynamic range and stress-tests numerical precision. A brief inline comment explaining why (e.g., "amplify values to stress-test conv precision under CP splitting") would help future readers.
201-204: Ratio of 0.001 is quite tight; verify this passes reliably in CI.
Using `ratio=0.001` for all four checks (output, dx, dw, db) is a strict tolerance. The `assert_close` helper in `fla/utils.py` will warn (instead of fail) in CI when the ratio is under 0.01, but outside CI this will hard-assert. If convolution tests are flaky, consider whether gradient checks (dw, db) need a slightly relaxed ratio.
fla/ops/cp/KCP.md (1)
11-14: Add language identifiers to pseudocode fenced blocks.
Markdownlint flags multiple bare fenced code blocks (MD040). For pseudocode/math blocks, use a language identifier like `text` or `math` to silence the lint warnings and improve rendering in some markdown processors. This applies to all ~11 unlabeled code blocks in this file.
Example fix
-```
+```text
 S_t = decay(g_t) * S_{t-1} + beta_t * k_t (x) (v_t - S_{t-1} @ k_t)
 o_t = q_t^T @ S_t
-```
+```
tests/context_parallel/test_cp_gdn.py (1)
354-365: CP8 test requires 8 GPUs and likely won't run in most CI environments.
This test with `T=65536` and 8 GPUs is a comprehensive stress test but will be skipped in most CI setups. Consider documenting this as a manual/nightly-only test, or adding a marker (e.g., `@pytest.mark.slow`).
tests/context_parallel/test_cp_kda.py (3)
216-236: Imports inside worker function are fine for spawn, but `triton` and `math` could be top-level.
`import triton` and `import math` are placed inside the worker function body. While this works correctly with `spawn` (fresh process re-imports), these are stable stdlib/dependency imports that could live at the module level for clarity. This is a minor style point only.
427-442: Ratio of 5e-2 is relatively relaxed; consider documenting the expected tolerance.
The KDA tests use `ratio=5e-2` compared to GDN's `2e-3` and conv's `0.001`. This is 25-50× more relaxed. Given the complexity of KDA (per-dim gating, L2 normalization, gate computation), a wider tolerance may be justified, but a brief comment explaining why would help future maintainers understand this isn't just a "make-the-test-pass" number.
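To make the gap between these tolerances concrete, the sketch below uses a relative RMS error ratio as the comparison metric. This is an assumption chosen for illustration, not the actual `assert_close` implementation in `fla/utils.py`; the helpers `err_ratio` and `check_close` are hypothetical.

```python
import math

def err_ratio(ref, out, eps=1e-8):
    """Hypothetical metric: RMS of the difference divided by RMS of the reference."""
    diff = math.sqrt(sum((r - o) ** 2 for r, o in zip(ref, out)) / len(ref))
    base = math.sqrt(sum(r ** 2 for r in ref) / len(ref))
    return diff / (base + eps)

def check_close(name, ref, out, ratio):
    """Hypothetical check: raise AssertionError if the error ratio exceeds `ratio`."""
    err = err_ratio(ref, out)
    assert err < ratio, f"{name}: error ratio {err:.6f} exceeds {ratio}"
    return err

ref = [1.0, -2.0, 3.0, -4.0]
out = [x * 1.01 for x in ref]   # uniform 1% relative perturbation

# A 1% error passes under the KDA-style tolerance ...
kda_err = check_close("o", ref, out, ratio=5e-2)

# ... but would be rejected under the GDN-style tolerance.
try:
    check_close("o", ref, out, ratio=2e-3)
    gdn_passed = True
except AssertionError:
    gdn_passed = False
```

Under this metric a uniform 1% deviation is accepted at `5e-2` and rejected at `2e-3`, which is the kind of difference a short comment next to the tolerance could spell out.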
549-561: CP8 test with T=65536 is a heavy test; same observation as GDN CP8.
Requires 8 GPUs and processes 65K tokens. Consider adding a `@pytest.mark.slow` marker.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Summary by CodeRabbit
Documentation
Improvements
Tests