Rtp llm refactor #462

Merged
valarLip merged 20 commits into main from rtp_llm_refactor on May 29, 2025
Conversation

@fsx950223
Contributor

No description provided.

@valarLip valarLip merged commit a43663f into main May 29, 2025
9 of 10 checks passed
@valarLip valarLip deleted the rtp_llm_refactor branch May 29, 2025 02:50
valarLip pushed a commit that referenced this pull request Mar 18, 2026
* update branch

* use __half replace _Float16

* use __half

* update cpp api

* add more conditions

* fix comment

* fix comment

* fix bugs

* update triton compiler

* format code

* add lint approach

* fix ruff

* fix ruff

* format code

* fix format

* format code
sunway513 added a commit to sunway513/aiter that referenced this pull request Apr 30, 2026
…R#462 cleanup)

Apply the same internal-types cleanup pattern as upstream FlyDSL PR ROCm#462
(coderfeli) to the gfx1201 flash_attn_func kernel. Replaces 110 of 116
raw MLIR dialect call sites (94.8% reduction) with FlyDSL public Numeric
wrappers and Vector helpers.

Changes by category:
- 11 arith.constant + 18 arith.index sites -> fx.Int32/fx.Float32/fx.Index
- 11 arith.index_cast sites -> fx.Index/fx.Int32/fx.Int64 wrappers
- 14 arith.AddFOp/SubFOp/MulFOp/MaxNumFOp + 8 arith.AddIOp/MulIOp etc
  -> local _fadd/_fsub/_fmul/_fmax helpers that preserve fastmath flag
  via lowercase arith.{addf,subf,mulf} + arith.MaxNumFOp
- 14 vector.load/store/extract/from_elements/bitcast/broadcast sites
  -> Vec.load / Vec(x).store / Vec(x)[i] / Vec.from_elements(...,dtype)
- 4 _llvm.GEPOp/LoadOp/StoreOp sites -> buffer_ops.get_element_ptr +
  _pointer_load/_pointer_store helpers
- 4 scf.IfOp + 1 scf.for_ + 5 scf.YieldOp sites -> Python `if cond:` and
  `for ... in range(0, upper, step, init=...)` natural form
- 4 _fly.extract_aligned_pointer_as_index sites -> _extract_aligned_pointer
  local helper
- arith.constant_vector x2 -> Vec.filled
- arith.trunc_f / arith.truncf -> Vec(x).to(elem_dtype) + fx.Float32(x).to(...)
- math_dialect.fma -> fmath.fma

Worked around the 4 skill-v1 gaps documented in P7-D's audit:
1. fx wrapper-vs-raw mismatch: every wrapper that flows into raw MLIR ops
   is unwrapped via PR462's _to_raw helper (imported as _raw)
2. fx.Vec does not exist: use `from flydsl.expr.typing import Vector as Vec`
3. fastmath cannot be dropped: helpers preserve fastmath=fm_fast everywhere
4. scf.IfOp restructure is manual: the 16-yield CAUSAL mask block is
   unfolded SSA-style (PR462 lines 700-870 pattern)

Two bytecode-preservation tricks beyond PR462's template:
- Explicit arith.cmpi(slt) for q_in_bounds — fx.Index < operator defaults
  to unsigned compare, which would emit v_cmp_gt_u64_e64 vs baseline
  v_cmp_gt_i64_e64 and break ASM equality
- aiter's dtype_to_elem_type returns raw ir.Type while Vec.make_type needs
  a Numeric class — added local _NUMERIC_MAP = {f32: fx.Float32, ...}
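
The first point can be illustrated without any FlyDSL or MLIR dependency. The sketch below is plain Python and only models 64-bit signed versus unsigned less-than on the same bit pattern; the names `q_index` and `seq_len` and their values are hypothetical and are not taken from the kernel. It shows why the two predicates are not interchangeable, which is consistent with the compiler emitting a different compare instruction for each.

```python
MASK64 = (1 << 64) - 1


def as_u64(x: int) -> int:
    """Reinterpret an integer as an unsigned 64-bit bit pattern."""
    return x & MASK64


def as_i64(x: int) -> int:
    """Reinterpret a 64-bit bit pattern as a signed two's-complement value."""
    x &= MASK64
    return x - (1 << 64) if x >= (1 << 63) else x


def lt_signed(a: int, b: int) -> bool:
    """Models a signed less-than (the `slt` predicate) on 64-bit operands."""
    return as_i64(a) < as_i64(b)


def lt_unsigned(a: int, b: int) -> bool:
    """Models an unsigned less-than (the `ult` predicate) on 64-bit operands."""
    return as_u64(a) < as_u64(b)


# Hypothetical values, chosen only to expose the difference.
q_index = -1   # all 64 bits set when viewed as a bit pattern
seq_len = 128

signed_result = lt_signed(q_index, seq_len)      # -1 < 128 as signed
unsigned_result = lt_unsigned(q_index, seq_len)  # 2**64 - 1 < 128 as unsigned

print(signed_result, unsigned_result)
```

For non-negative operands below 2**63 the two predicates agree, so a change like this can pass functional tests while still altering the emitted instruction and therefore the ASM hash.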

Final ASM verification on R9600D (gfx1201) wan-best container:
  baseline ASM SHA256 (PR head affebbe, kernel f049714d):
    4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
  refactor ASM SHA256:
    4b3c45f65556324e86d8182613efa7cc9fb164adbfcc9eb0bf17ac208f775997
  -> BYTE-EQUAL. All 22 IR pipeline stages produce identical final ISA.
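
A byte-equality check of this kind can be reproduced with the Python standard library. This is a minimal sketch, assuming the baseline and refactored ASM have each been dumped to a file; the helper names and any file paths are hypothetical and do not come from the aiter or FlyDSL tooling.

```python
import hashlib


def sha256_of_file(path, chunk_size=1 << 20) -> str:
    """Return the hex SHA256 digest of a file, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def asm_byte_equal(baseline_path, refactor_path) -> bool:
    """True when both ASM dumps hash to the same SHA256 digest."""
    return sha256_of_file(baseline_path) == sha256_of_file(refactor_path)
```

Matching digests indicate the two dumps are byte-identical for practical purposes; a mismatch only says that something differs, so a textual diff of the two files is still needed to locate the change.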

Perf: 49.68 TFLOPS mid3 mean (5 runs, S=32768 H=12 D=128 bf16 noncausal),
baseline 49.84, ratio 99.7% (within bf16 noise floor).
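
The ratio above follows directly from the two reported means. The sketch below recomputes it and also shows one plausible reading of "mid3 mean", namely the mean of the middle three of five sorted runs; that interpretation is an assumption, and the commit does not list the individual run values.

```python
def mid3_mean(samples):
    """Mean of the middle three values of exactly five samples.

    Assumed meaning of "mid3 mean": drop the fastest and slowest run,
    then average the remaining three.
    """
    if len(samples) != 5:
        raise ValueError("mid3_mean expects exactly five samples")
    middle = sorted(samples)[1:4]
    return sum(middle) / len(middle)


# Figures reported in the commit message (TFLOPS).
refactor_tflops = 49.68
baseline_tflops = 49.84

ratio_pct = round(refactor_tflops / baseline_tflops * 100, 1)
print(ratio_pct)
```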

Tests: op_tests/flydsl_tests/test_flydsl_fmha.py — 10 passed, 2 skipped
(multi-GPU only, same as baseline). Includes both causal=False (Wan2.1
production hot path) and causal=True coverage.

Lint: black + ruff both clean.

Diffstat: 447 lines refactored (+255 / -192). Remaining 6 raw MLIR call
sites are intentional and isolated to helper functions that map 1:1 to
PR ROCm#462 upstream:
  arith.cmpi(slt)  — preserve signed compare for ISA hash equality
  arith.MaxNumFOp  — inside _fmax helper to preserve fastmath flag
  _llvm.LoadOp/StoreOp — inside _pointer_{load,store} helpers
  _memref.load     — scalar element load with no Vec equivalent
  _fly.extract_aligned_pointer_as_index — inside _extract_aligned_pointer

2 participants