I noticed that the current cost function models the computation load as linear. Since attention has quadratic complexity, I tried the quadratic workload estimate below and benchmarked it, but it led to lower throughput in all cases.
Could you help me understand why this may be the case? @yzh119
Thanks.
inline float cost_function(int qo_len, int kv_len) {
  // Dense estimate: every KV entry is multiplied by every query, then the mask is applied.
  return float(qo_len) * kv_len;
  // Alternative causal estimate (disabled):
  // if (kv_len >= qo_len) {
  //   // Right-aligned case: prefix + triangular
  //   int kv_init = kv_len - qo_len;
  //   float shared = float(qo_len) * float(kv_init);
  //   float tri = float(qo_len) * float(qo_len + 1) / 2.0f;
  //   return shared + tri;
  // } else {
  //   // kv_len < qo_len: classic causal sum
  //   float tri = float(kv_len) * float(kv_len + 1) / 2.0f;
  //   float rem = float(qo_len - kv_len) * float(kv_len);
  //   return tri + rem;
  // }
}
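
For reference, here is a minimal standalone sketch of the two estimates from the snippet above, so they can be compared on concrete shapes. The names `dense_cost` and `causal_cost` are hypothetical and are not part of the library; the causal variant simply un-comments the branch logic above and assumes the same masking convention.

```cpp
#include <cassert>

// Hypothetical helper (not a library function): dense estimate.
// Counts every (query, key) pair, ignoring the causal mask.
inline float dense_cost(int qo_len, int kv_len) {
  assert(qo_len >= 0 && kv_len >= 0);
  return float(qo_len) * float(kv_len);
}

// Hypothetical helper (not a library function): causal estimate.
// Counts only the (query, key) pairs that survive the mask, using the
// same two cases as the commented-out code in the original snippet.
inline float causal_cost(int qo_len, int kv_len) {
  assert(qo_len >= 0 && kv_len >= 0);
  if (kv_len >= qo_len) {
    // Right-aligned case: each query sees the shared prefix plus a triangle.
    int kv_init = kv_len - qo_len;
    float shared = float(qo_len) * float(kv_init);
    float tri = float(qo_len) * float(qo_len + 1) / 2.0f;
    return shared + tri;
  }
  // kv_len < qo_len: triangle over the first kv_len queries,
  // then the remaining queries each see all kv_len keys.
  float tri = float(kv_len) * float(kv_len + 1) / 2.0f;
  float rem = float(qo_len - kv_len) * float(kv_len);
  return tri + rem;
}
```

With these definitions, the two estimates coincide for decode-style requests (`qo_len == 1`) and differ most when `qo_len == kv_len`, where the causal count is roughly half of the dense one (for example 36 versus 64 at length 8).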