Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -21790,7 +21790,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const struct ggml_tensor * k = node->src[1];
if (q->ne[1] == 1 && q->ne[3] == 1 && q->ne[2]/k->ne[2] > 1 && n_tasks > 1 && k->ne[1]/32 > 1) {
if (k->ne[2] > 1) {
- int nk = 32 * (k->ne[2]*k->ne[1]/(32*n_tasks));
+ int nk = MAX(1, 32 * (k->ne[2]*k->ne[1]/(32*n_tasks)));
int nstep_k = k->ne[2]*k->ne[1]/nk;
size_t result_size = (Dv + 16)*q->ne[2]/k->ne[2]*sizeof(float);
size_t size = nstep_k*result_size;
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/iqk/iqk_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
}

if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
- int nk = 32 * (nek2*nek1/(32*nth));
+ int nk = std::max(1, 32 * (nek2*nek1/(32*nth)));
int nkk = (nek1 + nk - 1)/nk;
int nstep_k = nek2*nkk;
auto result_size = (Dv + 16)*rk2*sizeof(float);
Expand Down