Skip to content
Merged
Show file tree
Hide file tree
Changes from 73 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
8b7ce0b
Add power management utilities to NPU device context and update DCVS …
chraac Aug 2, 2025
9a3cf62
Update DCVS settings in power_utils to use v3 API and enhance power m…
chraac Aug 2, 2025
5be7b0a
wip
chraac Aug 3, 2025
39d0b70
Enhance dequantization functions by adding load_dequant_table support…
chraac Aug 4, 2025
9063fd3
use lut
chraac Aug 4, 2025
9cf2c43
wip
chraac Aug 4, 2025
ccdf858
fix test failure
chraac Aug 4, 2025
94f7022
wip
chraac Aug 4, 2025
50add7e
Refactor load_qual_block_generic to improve block handling and optimi…
chraac Aug 5, 2025
df55391
Enhance load_dual_block_generic and load_qual_block_generic to accept…
chraac Aug 5, 2025
c60433f
Refactor flash_attn_impl to optimize mask l2 prefetch
chraac Aug 6, 2025
0ad08cc
wip
chraac Aug 6, 2025
c502b4c
wip
chraac Aug 6, 2025
c8985f5
wip
chraac Aug 6, 2025
a600676
wip
chraac Aug 7, 2025
3048e3e
add log
chraac Aug 7, 2025
669faa0
link against shared libraries instead of static ones
chraac Aug 7, 2025
85a082f
fix swiglu
chraac Aug 7, 2025
e7ceb25
wip
chraac Aug 7, 2025
4601d7f
refactor expf_fix to handle overflow for different data types
chraac Aug 7, 2025
20a4ed2
enhance is_glu_op_supported to validate shapes for multiple sources
chraac Aug 7, 2025
e56b2c1
wip
chraac Aug 7, 2025
723e04a
refactor logging macros to use hexagon namespace and improve formatting
chraac Aug 8, 2025
409cb28
fix printf format error
chraac Aug 8, 2025
3802041
wip
chraac Aug 8, 2025
c3771a8
refactor: update static_assert messages for block size validation and…
chraac Aug 8, 2025
e469b94
rename
chraac Aug 8, 2025
0722d20
Merge branch 'dev-refactoring' into dev-quant-lut
chraac Aug 9, 2025
ba8c044
feat: enhance fa with mask
chraac Aug 9, 2025
6d1f5a8
wip
chraac Aug 9, 2025
be4202c
wip
chraac Aug 9, 2025
d7dc3df
refactor: replace instances of Q6_V_vzero() with kZeroV for consistency
chraac Aug 9, 2025
5b27dc6
wip
chraac Aug 9, 2025
09fda2c
wip
chraac Aug 9, 2025
bd2089b
wip
chraac Aug 9, 2025
08be69d
fix: improve address alignment check in HVX_Vector handling
chraac Aug 10, 2025
eff7ce7
refactor: streamline vector dot product implementations for improved …
chraac Aug 10, 2025
9f94164
refactor: q4k add hvx intrinsic impl
chraac Aug 10, 2025
788eb85
refactor: enhance dequantize_row_q4_K for clarity and performance
chraac Aug 11, 2025
f8897b6
refactor: optimize scale mask usage in dequantization functions for i…
chraac Aug 11, 2025
082d666
refactor: optimize dequantize_row_q4_K for intrinsic usage and perfor…
chraac Aug 11, 2025
31c53c4
refactor: move GLU operation implementation into separated file
chraac Aug 11, 2025
764273c
sync after swiglu
chraac Aug 11, 2025
0040711
wip
chraac Aug 12, 2025
380bd8f
wip
chraac Aug 12, 2025
00b9da9
wip
chraac Aug 12, 2025
4f8cf2b
feat: increase prc main thread stack size
chraac Aug 12, 2025
c3ba43b
fix: replace hardcoded stack size with NPU_THREAD_STACK_SIZE constant
chraac Aug 12, 2025
86d9c93
wip
chraac Aug 13, 2025
d33587e
feat: add optimized vector operations for exponential and division wi…
chraac Aug 13, 2025
dd58a98
wip
chraac Aug 13, 2025
c279a3d
feat: refactor exponential function to handle overflow and underflow …
chraac Aug 13, 2025
8c9b5ef
wip
chraac Aug 13, 2025
bce6fd4
wip
chraac Aug 15, 2025
1d0bca6
feat: add vector loading and scaling functions for improved performan…
chraac Aug 15, 2025
7a0cd2f
wip
chraac Aug 15, 2025
a318bba
feat: optimize block loading by refactoring scale index handling for …
chraac Aug 15, 2025
818baa5
use Q6_Vb_vlut32_VbVbR_nomatch instead
chraac Aug 15, 2025
f7c1b7c
feat: enhance scale loading by adding static assertion and restructur…
chraac Aug 15, 2025
cd349ce
wip
chraac Aug 16, 2025
027a933
feat: refactor vec_dot_product_mixed_impl for improved clarity and pe…
chraac Aug 16, 2025
eeb4606
wip
chraac Aug 16, 2025
20fb6c5
feat: simplify vector loading functions and improve alignment handling
chraac Aug 17, 2025
6c3bc2d
wip
chraac Aug 17, 2025
3694d50
feat: enhance scale loading mask with quantization block size validation
chraac Aug 17, 2025
bdbf172
wip
chraac Aug 17, 2025
423acb7
feat: implement make_scale_load_mask function and refactor vector han…
chraac Aug 17, 2025
f9cc060
feat: enhance load_dual_block_generic to include scale indices for im…
chraac Aug 17, 2025
36f1870
revert q8 dequant
chraac Aug 17, 2025
9bba483
wip
chraac Aug 17, 2025
f0ca3e7
feat: optimize dequantization functions by removing unnecessary maski…
chraac Aug 17, 2025
9901ca0
wip
chraac Aug 18, 2025
38935b6
Merge branch 'dev-refactoring' into dev-quant-lut
chraac Aug 25, 2025
e97b3c0
wip
chraac Aug 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ggml/src/ggml-qnn/npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ else()
target_compile_options(hexagon_npu_skel_OBJS PUBLIC
-fsanitize=address -fno-omit-frame-pointer
)
target_link_libraries(hexagon_npu_skel_OBJS PUBLIC
target_link_options(hexagon_npu_skel_OBJS PUBLIC
-fsanitize=address
)
endif()
Expand Down Expand Up @@ -248,9 +248,9 @@ else()

add_library(hexagon_npu_skel SHARED $<TARGET_OBJECTS:hexagon_npu_skel_OBJS>)
target_link_libraries(hexagon_npu_skel
${HEXAGON_LIB_DIR}/${HEXAGON_ARCH}/G0/pic/libc++abi.a
${HEXAGON_LIB_DIR}/${HEXAGON_ARCH}/G0/pic/libc++.a
${HEXAGON_LIB_DIR}/${HEXAGON_ARCH}/G0/pic/libc.a
${HEXAGON_LIB_DIR}/${HEXAGON_ARCH}/G0/pic/libc++abi.so.1
${HEXAGON_LIB_DIR}/${HEXAGON_ARCH}/G0/pic/libc++.so.1
${HEXAGON_LIB_DIR}/${HEXAGON_ARCH}/G0/pic/libc.so
)
set_target_properties(hexagon_npu_skel PROPERTIES OUTPUT_NAME "hexagon_npu_skel_${HEXAGON_ARCH}")
target_link_libraries(hexagon_npu_skel qprintf_static)
Expand Down
43 changes: 26 additions & 17 deletions ggml/src/ggml-qnn/npu/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,30 @@
namespace {

struct npu_device_context {
std::unique_ptr<hexagon::power_utils> power_utils; // Power management utilities
std::unique_ptr<hexagon::default_thread_pool> thread_pool;
std::unique_ptr<float[]> f16_to_f32_table; // TODO: store vtcm?

bool init() {
if (!init_ltu()) {
DEVICE_LOG_ERROR("Failed to initialize LTU");
DEVICE_LOG_ERROR("Failed to initialize LTU\n");
return false;
}

if (!init_thread_pool()) {
DEVICE_LOG_ERROR("Failed to initialize thread pool");
DEVICE_LOG_ERROR("Failed to initialize thread pool\n");
return false;
}

DEVICE_LOG_DEBUG("NPU device context initialized");
power_utils = std::make_unique<hexagon::power_utils>();
if (power_utils && power_utils->is_valid()) {
power_utils->set_dvcs_performance_mode(true);
DEVICE_LOG_DEBUG("Power utilities initialized with DVCS performance mode enabled\n");
} else {
DEVICE_LOG_ERROR("Failed to initialize power utilities\n");
}

DEVICE_LOG_DEBUG("NPU device context initialized\n");
return true;
}

Expand All @@ -41,29 +50,29 @@ struct npu_device_context {

f16_to_f32_table = std::make_unique<float[]>(kLtuCount);
if (!f16_to_f32_table) {
DEVICE_LOG_ERROR("Failed to allocate memory for f16_to_f32 table");
DEVICE_LOG_ERROR("Failed to allocate memory for f16_to_f32 table\n");
return false;
}

hexagon::init_f16_f32_table(f16_to_f32_table.get(), kLtuCount);
DEVICE_LOG_DEBUG("f16_to_f32 table initialized");
DEVICE_LOG_DEBUG("f16_to_f32 table initialized\n");
return true;
}

bool init_thread_pool() {
if (thread_pool) {
DEVICE_LOG_DEBUG("Thread pool already initialized");
DEVICE_LOG_DEBUG("Thread pool already initialized\n");
return true;
}

auto pool = std::make_unique<hexagon::default_thread_pool>();
if (!pool) {
DEVICE_LOG_ERROR("Failed to create thread pool");
DEVICE_LOG_ERROR("Failed to create thread pool\n");
return false;
}

thread_pool = std::move(pool);
DEVICE_LOG_DEBUG("Thread pool initialized");
DEVICE_LOG_DEBUG("Thread pool initialized\n");
return true;
}
};
Expand Down Expand Up @@ -102,25 +111,25 @@ int npu_device_open(const char * uri, remote_handle64 * h) {
// TODO: should we have a device context here?
auto * context = new npu_device_context();
if (!context->init()) {
DEVICE_LOG_ERROR("Failed to initialize npu_device_context");
DEVICE_LOG_ERROR("Failed to initialize npu_device_context\n");
delete context;
return AEE_EFAILED;
}

*h = reinterpret_cast<remote_handle64>(context);
DEVICE_LOG_INFO("NPU device context created: %p", (void *) *h);
DEVICE_LOG_INFO("NPU device context created: %p\n", (void *) *h);
return AEE_SUCCESS;
}

int npu_device_close(remote_handle64 h) {
auto * context = device_context_from_handle(h);
if (!context) {
DEVICE_LOG_ERROR("Invalid npu_device_context handle");
DEVICE_LOG_ERROR("Invalid npu_device_context handle\n");
return AEE_EINVHANDLE;
}

delete context;
DEVICE_LOG_INFO("NPU device context destroyed: %p", (void *) h);
DEVICE_LOG_INFO("NPU device context destroyed: %p\n", (void *) h);
return AEE_SUCCESS;
}

Expand All @@ -139,7 +148,7 @@ AEEResult npu_device_device_support_op(remote_handle64 _h,
NPU_UNUSED(_h);

if (!srcs || srcsLen <= 0 || !dst || !is_supported) {
DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments");
DEVICE_LOG_ERROR("npu_device_device_support_op: Invalid arguments\n");
return AEE_EINVARGS;
}

Expand Down Expand Up @@ -185,7 +194,7 @@ AEEResult npu_device_tensors_free(remote_handle64 _h,
int tensor_handlesLen) {
NPU_UNUSED(_h);
if (!tensor_handles || tensor_handlesLen < 0) {
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments");
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid arguments\n");
return AEE_EINVARGS;
}

Expand All @@ -194,7 +203,7 @@ AEEResult npu_device_tensors_free(remote_handle64 _h,
if (tensor) {
delete tensor;
} else {
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid tensor handle at index %d", i);
DEVICE_LOG_ERROR("npu_device_tensors_free: Invalid tensor handle at index %d\n", i);
}
}

Expand Down Expand Up @@ -250,13 +259,13 @@ AEEResult npu_device_graph_set_tensor_with_param(remote_handle64
AEEResult npu_device_graph_compute(remote_handle64 _h, npu_device_graph_handle_t graph_handle) {
auto dev_ctx = device_context_from_handle(_h);
if (!dev_ctx) {
DEVICE_LOG_DEBUG("Invalid npu_device_context handle");
DEVICE_LOG_DEBUG("Invalid npu_device_context handle\n");
return AEE_EINVHANDLE;
}

auto * graph = graph_from_handle(graph_handle);
if (!graph) {
DEVICE_LOG_ERROR("Invalid graph handle");
DEVICE_LOG_ERROR("Invalid graph handle\n");
return AEE_EINVHANDLE;
}

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-qnn/npu/device/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void graph::compute_impl(default_thread_pool * pool, default_thread_pool::thread

const bool should_sync = requires_thread_barrier(op);
if (pool && should_sync && i < _tensor_count - 1) {
// For the last tensor, the thread pool will handle synchronization
DEVICE_SCOPED_PERFORMANCE_TRACKER("[%p]sync_thread, tidx: %zu, tensor[%zu/%zu]",
(void *) this,
params.get_thread_index(),
Expand Down
67 changes: 43 additions & 24 deletions ggml/src/ggml-qnn/npu/device/op_flash_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ inline float f16_to_f32(const npu_device_fp16_t src) {
}

// From: ggml/src/ggml-cpu/ops.cpp
template <bool _IsKvF16>
template <bool _IsKvF16, bool _HasMask>
void flash_attn_impl(hexagon::tensor * out,
const hexagon::tensor * q,
const hexagon::tensor * k,
Expand All @@ -24,6 +24,7 @@ void flash_attn_impl(hexagon::tensor * out,
static_assert(3 <= hexagon::kMaxParamsCount, "flash_attn op params count exceeds max params count");

constexpr const npu_device_tensor_data_type kKvDataType = _IsKvF16 ? NPU_DATA_TYPE_F16 : NPU_DATA_TYPE_F32;
constexpr const bool kHasMask = _HasMask;

if (k->get_type() != kKvDataType || v->get_type() != k->get_type()) {
DEVICE_LOG_ERROR("flash_attn_impl: k and v must have same type, got k: %s, v: %s\n",
Expand All @@ -32,6 +33,11 @@ void flash_attn_impl(hexagon::tensor * out,
return;
}

if (kHasMask != (mask != nullptr)) {
DEVICE_LOG_ERROR("flash_attn_impl: mask is required when kHasMask is true\n");
return;
}

float scale = out->get_op_param<float>(0);
const float max_bias = out->get_op_param<float>(1);
const float logit_softcap = out->get_op_param<float>(2);
Expand Down Expand Up @@ -96,7 +102,7 @@ void flash_attn_impl(hexagon::tensor * out,
const uint8_t * q_ptr = q->get_read_buffer();
const uint8_t * k_ptr = k->get_read_buffer();
const uint8_t * v_ptr = v->get_read_buffer();
const uint8_t * mask_ptr = mask ? mask->get_read_buffer() : nullptr;
const uint8_t * mask_ptr = kHasMask ? mask->get_read_buffer() : nullptr;
const uint8_t * sinks_ptr = sinks ? sinks->get_read_buffer() : nullptr;
float * VKQ32 = reinterpret_cast<float *>(cache_ptr); // FP32 VKQ accumulator
auto * VKQ16 = reinterpret_cast<npu_device_fp16_t *>(VKQ32 + aligned_dv); // (temporary) FP16 VKQ accumulator
Expand Down Expand Up @@ -125,11 +131,17 @@ void flash_attn_impl(hexagon::tensor * out,
}

const npu_device_fp16_t * mp =
mask_ptr ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
kHasMask ? reinterpret_cast<const npu_device_fp16_t *>(mask_ptr + iq1 * mask->get_nb(1) +
(iq2 % mask->get_ne(2)) * mask->get_nb(2) +
(iq3 % mask->get_ne(3)) * mask->get_nb(3)) :
nullptr;

q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);

if (kHasMask) {
hexagon::l2fetch_row(reinterpret_cast<const uint8_t *>(mp), mask->get_nb(1));
}

// k indices
const int ik3 = iq3 / rk3;
const int ik2 = iq2 / rk2;
Expand All @@ -138,16 +150,14 @@ void flash_attn_impl(hexagon::tensor * out,
const int iv3 = iq3 / rv3;
const int iv2 = iq2 / rv2;

q_to_vec_dot(reinterpret_cast<const float *>(q_data), Q_q, DK);

// online softmax / attention
// loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf
const auto * k_plane_ptr = k_ptr + ik2 * k->get_nb(2) + ik3 * k->get_nb(3);
const auto * v_plane_ptr = v_ptr + iv2 * v->get_nb(2) + iv3 * v->get_nb(3);
for (int64_t ic = 0; ic < k->get_ne(1); ++ic) {
DEVICE_SCOPED_OP_PERFORMANCE_TRACKER_ADD_ONE_SUB_PROC(flash_attn, 0, loop);
float mv = mp ? (slope * f16_to_f32(mp[ic])) : 0.0f;
float mv = kHasMask ? (slope * f16_to_f32(mp[ic])) : 0.0f;
if (mv == -INFINITY) {
continue;
}
Expand Down Expand Up @@ -282,9 +292,17 @@ bool flash_attn_f32(tensor * out, compute_params * params) {
const auto * mask = out->get_src(3);
const auto * sinks = out->get_src(4);
if (k->get_type() == NPU_DATA_TYPE_F16) {
flash_attn_impl<true>(out, q, k, v, mask, sinks, params);
if (mask) {
flash_attn_impl<true, true>(out, q, k, v, mask, sinks, params);
} else {
flash_attn_impl<true, false>(out, q, k, v, mask, sinks, params);
}
} else {
flash_attn_impl<false>(out, q, k, v, mask, sinks, params);
if (mask) {
flash_attn_impl<false, true>(out, q, k, v, mask, sinks, params);
} else {
flash_attn_impl<false, false>(out, q, k, v, mask, sinks, params);
}
}
return true;
}
Expand Down Expand Up @@ -338,8 +356,8 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,

if (dst->ne[0] != v->ne[0] || dst->ne[2] != q->ne[1]) {
DEVICE_LOG_DEBUG(
"[%s]dst shape does not match q and v: dst ne: %ld, %ld, %ld, %ld, q ne: %ld, %ld, %ld, %ld, "
"v ne: %ld, %ld, %ld, %ld\n",
"[%s]dst shape does not match q and v: dst ne: %lld, %lld, %lld, %lld, q ne: %lld, %lld, %lld, %lld, "
"v ne: %lld, %lld, %lld, %lld\n",
op_get_name(op),
dst->ne[0],
dst->ne[1],
Expand All @@ -359,24 +377,25 @@ bool is_flash_attn_supported(const npu_device_tensor_op_spec * op_spec,
if (is_transposed_or_permuted(dst->nb)) {
DEVICE_LOG_DEBUG("[%s]dst cannot be transposed or permuted, nb: %zu, %zu, %zu, %zu\n",
op_get_name(op),
dst->nb[0],
dst->nb[1],
dst->nb[2],
dst->nb[3]);
(size_t) dst->nb[0],
(size_t) dst->nb[1],
(size_t) dst->nb[2],
(size_t) dst->nb[3]);
return false;
}

if (q->ne[0] != k->ne[0]) {
DEVICE_LOG_DEBUG("[%s]q and k shapes do not match: q ne: %ld, %ld, %ld, %ld, k ne: %ld, %ld, %ld, %ld\n",
op_get_name(op),
q->ne[0],
q->ne[1],
q->ne[2],
q->ne[3],
k->ne[0],
k->ne[1],
k->ne[2],
k->ne[3]);
DEVICE_LOG_DEBUG(
"[%s]q and k shapes do not match: q ne: %lld, %lld, %lld, %lld, k ne: %lld, %lld, %lld, %lld\n",
op_get_name(op),
q->ne[0],
q->ne[1],
q->ne[2],
q->ne[3],
k->ne[0],
k->ne[1],
k->ne[2],
k->ne[3]);
return false;
}

Expand Down
Loading
Loading