Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-hexagon/htp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx)

if (_hmx_idx GREATER_EQUAL 0)
target_sources(${HTP_LIB} PRIVATE
hmx-queue.c
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
)

# -mhmx enables HMX instruction set (needed by files that include hmx-utils.h)
set_source_files_properties(
hmx-flash-attn-ops.c
hmx-matmul-ops.c
hmx-queue.c
PROPERTIES COMPILE_OPTIONS "-mhmx"
)

Expand Down
157 changes: 81 additions & 76 deletions ggml/src/ggml-hexagon/htp/flash-attn-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@
// Must be multiple of 32
#define FLASH_ATTN_BLOCK_SIZE (32 * 2)

#if __HVX_ARCH__ < 79
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b))
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b))
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b)
#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b)
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif

// This is a bit of a hack because the compiler is strugling to properly inline
// the default hvx_vec_f32_to_f16 with output into the local array.
static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1)
Expand Down Expand Up @@ -54,8 +64,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf);
}

HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)));
rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)));
HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p));
rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
hvx_vec_store_u(r, 4, rsum);
}

Expand Down Expand Up @@ -105,10 +115,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y,
rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf);
}

HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)));
HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)));
HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)));
HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)));
HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p));
HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p));
HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p));
HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p));

HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } };
return hvx_vec_reduce_sum_f32x4(rsum0123);
Expand All @@ -123,7 +133,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
const size_t nloe = n % VLEN_FP16; // leftover elements

HVX_Vector sums; // initialize at j = 0
HVX_Vector sums = Q6_V_vzero();
const size_t stride_x_4 = stride_x * 4;
for (uint32_t j = 0; j < VLEN_FP32; j += 4) {
HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe);
Expand All @@ -132,8 +142,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y,
x += stride_x_4;
}

sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums);
return Q6_Vsf_equals_Vqf32(sums);
return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums);
}

// MAD: y (F32) += x (F16) * s (F16)
Expand Down Expand Up @@ -268,11 +277,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t *
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; ++i) {
vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs));
vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v));
hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs));
}
}

Expand Down Expand Up @@ -438,25 +446,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
// Process in sub-blocks of 32 (VLEN_FP32)
HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32];
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale);

// 2. Softcap
if (factx->logit_softcap != 0.0f) {
scores = hvx_vec_tanh_f32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap);
scores = Q6_Vsf_equals_Vqf32(scores);
scores = HVX_OP_MUL_F32(scores, logit_cap);
}

// 3. Mask
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp;
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores);
scores = Q6_Vsf_equals_Vqf32(scores);

// Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79.
// Clamp -INFINITY to the max negative fp16 finite value (-65504.0f).
HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00);
HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF);
HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf);
m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16);

#if __HVX_ARCH__ >= 79
HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vsf_vadd_VsfVsf(add_val, scores);
#else
HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec);
HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair);
scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores));
#endif
}

// Mask out invalid lanes for leftover handling
uint32_t valid_lanes = current_block_size - ic;
if (valid_lanes < VLEN_FP32) {
HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane
scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY));
}

sb_scores[iv] = scores;
Expand All @@ -466,78 +493,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
{
// 4. Online Softmax Update
HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec);
HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec));
HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
M_vec = M_new_vec;

hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);

HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = sb_scores[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));
HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec);
HVX_Vector P = hvx_vec_exp_f32(scores_shifted);

p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));
p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P);

// 5. Accumulate V
__fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16];
hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0));

float __attribute__((aligned(128))) P_arr[VLEN_FP32];
hvx_vec_store_a(P_arr, 128, P);

for (uint32_t j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
const uint32_t cur_ic = ic2 + j;
if (cur_ic >= current_block_size) {
break;
}

if (cur_ic + 1 == current_block_size) {
// Odd leftover, process single row
if (P_arr[j] != 0.0f) {
const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV);
}
break;
}

// Avoid NaN * 0.0 = NaN for uninitialized V cache rows.
// Check the f32 values to safely avoid strict aliasing violations.
if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) {
continue;
}

const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded;
hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV);
}
}

p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec));
}

if (ic < current_block_size) {
// Sync scalars for leftover/next block if needed
float M = hvx_vec_get_f32(M_vec);
float S = hvx_vec_get_f32(S_vec);

// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded;
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale);
if (factx->logit_softcap != 0.0f) {
s_val = factx->logit_softcap * tanhf(s_val);
}

if (mask) {
const float m_val = m_base[ic];
s_val += slope * m_val;
}

const float Mold = M;
__fp16 vs = 1.0f;

if (s_val > M) {
M = s_val;
HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M);
HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec);
hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec);

float ms = hvx_vec_get_f32(ms_vec);
S = S * ms + vs;
} else {
HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M);
vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec));
S += vs;
}

const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded;

hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV);
}

M_vec = hvx_vec_splat_f32(M);
S_vec = hvx_vec_splat_f32(S);
S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec);
}

// Issue DMA for next+1 block (if exists)
Expand Down Expand Up @@ -599,8 +603,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
const int i2 = iq2;
const int i3 = iq3;

// dst is permuted
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
// dst is permuted: [DV, n_heads, n_tokens, n_seq]
// head stride is nb[1], token stride is nb[2], batch stride is nb[3]
uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3];

if (dst->type == HTP_TYPE_F32) {
hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
Expand All @@ -623,8 +628,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
}

#ifdef HTP_HAS_HMX
// HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) {
// HMX path: head_dim multiple of 32, F16 KV
if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) {
int ret = hmx_flash_attn_ext(octx);
if (ret == HTP_STATUS_OK) {
return ret;
Expand Down
3 changes: 0 additions & 3 deletions ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -1248,9 +1248,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) {
if (DK % 32 != 0 || DV % 32 != 0) {
return HTP_STATUS_NO_SUPPORT;
}
if (neq1 < 32) {
return HTP_STATUS_NO_SUPPORT;
}

// GQA factor
const uint32_t n_kv_heads = k->ne[2];
Expand Down
Loading
Loading