Question about the cost function in persistent kernel scheduler #1175

@Edenzzzz

I noticed that the current cost function models the computation load as linear. Since attention has quadratic complexity, I tried the quadratic workload estimate below and benchmarked it, but it gave lower throughput in all cases.
Could you help me understand why this might be the case? @yzh119
Thanks.

inline float cost_function(int qo_len, int kv_len) {
  // Dense estimate: every kv is multiplied by every query, then the mask is applied.
  return float(qo_len) * float(kv_len);
  // Alternative that counts only the unmasked entries:
  // if (kv_len >= qo_len) {
  //   // Right-aligned case: prefix + triangular
  //   int kv_init = kv_len - qo_len;
  //   float shared = float(qo_len) * float(kv_init);
  //   float tri    = float(qo_len) * float(qo_len + 1) / 2.0f;
  //   return shared + tri;
  // } else {
  //   // kv_len < qo_len: classic causal sum
  //   float tri = float(kv_len) * float(kv_len + 1) / 2.0f;
  //   float rem = float(qo_len - kv_len) * float(kv_len);
  //   return tri + rem;
  // }
}
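
To make the difference between the dense estimate and the masked one concrete, here is a minimal sketch that counts unmasked (query, key) pairs. It assumes a right-aligned causal mask, where query i sees keys j <= i + kv_len - qo_len. The names `causal_cost` and `causal_cost_bruteforce` are hypothetical and not part of the scheduler. Under that assumption, when kv_len < qo_len only the last kv_len queries see any key, so this sketch drops the remainder term from the commented-out branch above; if the mask is aligned differently, that term may still apply.

```cpp
#include <algorithm>

// Hypothetical helper: closed-form count of unmasked (query, key) pairs,
// ASSUMING a right-aligned causal mask (query i sees keys
// j <= i + kv_len - qo_len).
inline float causal_cost(int qo_len, int kv_len) {
  if (kv_len >= qo_len) {
    // Every query sees the shared prefix, plus a triangle over the last
    // qo_len keys.
    float prefix = float(qo_len) * float(kv_len - qo_len);
    float tri = float(qo_len) * float(qo_len + 1) / 2.0f;
    return prefix + tri;
  }
  // Only the last kv_len queries see any key; they form a triangle.
  return float(kv_len) * float(kv_len + 1) / 2.0f;
}

// Hypothetical brute-force reference used to check the closed form.
inline long long causal_cost_bruteforce(int qo_len, int kv_len) {
  long long count = 0;
  for (int i = 0; i < qo_len; ++i) {
    // Number of keys visible to query i, clamped to [0, kv_len].
    int visible = i + kv_len - qo_len + 1;
    count += std::max(0, std::min(visible, kv_len));
  }
  return count;
}
```

For example, with qo_len = 4 and kv_len = 10 the dense estimate is 40 while both functions above give 34; with qo_len = 4 and kv_len = 2 the dense estimate is 8 while both give 3.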
