Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 29 additions & 28 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <algorithm>
#include <cstddef>
#include <functional>
#include <array>
#include <vector>

#include <unordered_map>

Expand Down Expand Up @@ -467,29 +469,15 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s

// pad_ptr must be at least 'ci' floats for padding pixels.
// Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct.
//
// The pad buffer contents are always zero. Since the buffer is grow-only and never written with non-zero data,
// we only need to zero-initialize newly-grown elements.
thread_local std::vector<float> pad_ptr;
const float* old_pad_ptr = pad_ptr.data();
bool has_pad_ptr_changed = false;

if (pad_ptr.size() < padsize) {
pad_ptr.resize(padsize, 0.f);
if (pad_ptr.data() != old_pad_ptr) {
has_pad_ptr_changed = true;
}
} else {
// Ensure any previously-used region remains zeroed (grow-only means it should already be zeros,
// but keep this explicit for safety).
std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f);
}

LhsCacheKey key = {
ci, ih, iw,
padding, sh, sw,
kh, kw,
1, 1,
HashWeights(in)
};

//create lhs in format required for imatmul
const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);

Expand All @@ -498,18 +486,31 @@ static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const s

auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool);

// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
thread_local std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;

if (has_pad_ptr_changed)
{
// If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they
// would be referencing the old pad buffer.
// See discussion in https://github.com/microsoft/onnxruntime/pull/27214.
// TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key
// or any other approach that would reduce unnecessary cache invalidations.
lhs_ptrs_cache.clear();
// Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions.
//
// Entries include pointers to the pad buffer for out-of-bounds pixels, so we must not reuse entries after the
// pad buffer is reallocated. To avoid clearing the entire cache, we group caches by pad buffer identity and
// invalidate only the old group when the pad buffer moves.
using LhsPtrsCache = std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>>;
thread_local std::unordered_map<const float*, LhsPtrsCache> lhs_ptrs_cache_by_pad;

// If pad_ptr moved (vector reallocation), drop only the old group to avoid accumulating unreachable entries.
thread_local const float* last_pad_ptr = nullptr;
const float* cur_pad_ptr = pad_ptr.data();
if (last_pad_ptr != nullptr && last_pad_ptr != cur_pad_ptr) {
lhs_ptrs_cache_by_pad.erase(last_pad_ptr);
}
last_pad_ptr = cur_pad_ptr;

LhsCacheKey key = {
ci, ih, iw,
padding, sh, sw,
kh, kw,
1, 1,
HashWeights(in)
};

auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr];

std::shared_ptr<const void*[]> lhs_ptrs;
if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@



#include <array>
#include <cstddef>
#include <vector>

#include "mlasi_kleidiai.h"

#include "kai_ukernel_interface.h"
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
// SPDX-License-Identifier: MIT
//

#include <vector>
#include <cstring>
#include <cstddef>

#include "mlas.h"

#include "mlasi_kleidiai.h"

#include "kai_ukernel_interface.h"


#include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
Expand Down
19 changes: 19 additions & 0 deletions onnxruntime/test/mlas/unittest/test_conv2d.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,5 +331,24 @@ class MlasConv2DTest : public MlasTestBase {
}
}
}

//
// Regression test: exercise a KleidiAI Conv2D path when KleidiAI is enabled.
// See https://github.com/microsoft/onnxruntime/issues/26669.
//
// The KleidiAI implementation uses an internal per-thread padding buffer for out-of-bounds pixels
// when constructing the LHS indirection table. Historically, if the buffer was too small for a later
// convolution (larger CI), resizing could invalidate cached indirection pointers and lead to
// non-deterministic corruption.
//
// This sequence forces pad-buffer growth by running a smaller-CI convolution followed by a larger-CI
// convolution (with padding to ensure pad pointers are used), then runs the smaller-CI convolution again.
// Repeat a few times to increase the likelihood of triggering a reallocation and verify the path.
//
for (int i = 0; i < 4; ++i) {
Test(1, 1, 64, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); // smaller CI
Test(1, 1, 320, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); // larger CI forces pad buffer growth
Test(1, 1, 64, 11, 11, 32, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1); // sanity: back to smaller CI after growth
}
}
};
Loading