Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ggml/src/ggml-hexagon/htp/flash-attn-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *

const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap);

dma_cache m_cache;
dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE);

for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &factx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1);
Expand Down Expand Up @@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}

// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
Expand Down Expand Up @@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void *
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}

// FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u",
Expand Down Expand Up @@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = factx.size_k_block * 2;
octx->src2_spad.size_per_thread = factx.size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0;
octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;

octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
Expand All @@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) {
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;

// FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread);

if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads);
}
Expand Down
75 changes: 64 additions & 11 deletions ggml/src/ggml-hexagon/htp/hex-dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,20 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t
desc->desc_size = 0; // 1D mode
desc->src_bypass = dma_src_l2_bypass_on;
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->order = 1;
desc->order = 0;
desc->done = 0;
desc->src = (void *) dptr.src;
desc->dst = (void *) dptr.dst;
desc->size = size;

q->dptr[q->push_idx] = dptr;

dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
if (size) {
dmlink(q->tail, desc);
q->tail = (dma_descriptor_2d *) desc;
} else {
desc->done = 1;
}

// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
Expand All @@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t
desc->dst_bypass = dma_dst_l2_bypass_on;
desc->src_comp = 0;
desc->dst_comp = 0;
desc->order = 1;
desc->order = 0;
desc->done = 0;
desc->src_stride = src_stride;
desc->dst_stride = dst_stride;
Expand All @@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t

q->dptr[q->push_idx] = dptr;

dmlink(q->tail, desc);
q->tail = desc;
if (nrows) {
dmlink(q->tail, desc);
q->tail = desc;
} else {
desc->done = 1;
}

// FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src);
q->push_idx = (q->push_idx + 1) & q->idx_mask;
Expand All @@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) {
dma_descriptor_2d * desc = &q->desc[q->pop_idx];

// Wait for desc to complete
while (1) {
dmpoll();
if (desc->done) {
break;
}
while (!desc->done) {
// FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx);
dmpoll();
}

dptr = q->dptr[q->pop_idx];
Expand Down Expand Up @@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_
return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows);
}

#define DMA_CACHE_MAX_SIZE 64U

typedef struct {
uint8_t *base;
uint32_t line_size;
uint32_t capacity;
uint32_t src[DMA_CACHE_MAX_SIZE];
uint16_t age[DMA_CACHE_MAX_SIZE];
} dma_cache;

static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity)
{
c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity;
c->base = base;
c->line_size = line_size;

for (unsigned i=0; i < c->capacity; i++) {
c->src[i] = 0;
c->age[i] = 0;
}
}

static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows)
{
uint32_t o_idx = 0;
uint16_t o_age = 0;
uint8_t * dst = 0;

for (unsigned i=0; i < c->capacity; i++) {
if (c->src[i] == (uint32_t) src) {
c->age[i] = 0;
dst = c->base + (i * c->line_size); nrows = 0; // dummy dma
// FARF(ERROR, "dma-cache: found %p", src);
} else {
c->age[i]++;
if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; }
}
}
if (!dst) {
// FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src);
c->age[o_idx] = 0;
c->src[o_idx] = (uint32_t) src;
dst = c->base + o_idx * c->line_size; // normal nrows dma
}

return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows);
}

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
4 changes: 2 additions & 2 deletions ggml/src/ggml-hexagon/htp/rope-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) {
// (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start));
}

// Skip DMA transactions from prev block (if any)
// No need to wait for these since the DMA is setup for in-order processing
// Skip output DMA transactions from prev block (if any)
// No need to wait for those here since we're explicitly waiting for the latest prefecthes below.
for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); }

// Compute loop
Expand Down
Loading