From 0ce8b1e9da82f839ae95bfa89427af6d9dd557d6 Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Fri, 19 Sep 2025 16:46:17 -0700
Subject: [PATCH] Use weight cache for quantized tensor scale data (#14448)

Summary:
When the XNNPACK weight cache is enabled and a model with qb4- or qc8-quantized
linear weights is run, it triggers an assertion intended to make sure all data
is in the weight cache. This can be reproduced by running the XNNPACK backend
linear op tests with the weight cache enabled.

The root cause appears to be that tensor scale data was bypassing the weight
cache - likely an oversight in the initial implementation. This isn't a
correctness issue, but it does cause the aforementioned assert to fail and uses
marginally more memory than it otherwise needs to.

This PR updates the XNNPACK compileModel call to use the weight cache for scale
data (instead of putting it in the unpacked_buffers list). With this change,
the linear op tests pass with the weight cache enabled.

Test Plan:
```
buck test -c executorch.xnnpack_weights_cache=1 fbcode//executorch/backends/xnnpack/test:test_xnnpack_ops -- linear
```

Reviewed By: lucylq, digantdesai

Differential Revision: D82862629

Pulled By: GregoryComer
---
 backends/xnnpack/runtime/XNNCompiler.cpp | 65 ++++++++++++------------
 1 file changed, 33 insertions(+), 32 deletions(-)

diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 78eaaf6d039..1ed7db80d84 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr.
 If no constant data associated with the tensor value, then returns nullptr.
 */
 const uint8_t* getConstantDataPtr(
-    const fb_xnnpack::XNNTensorValue* tensor_value,
+    uint32_t buffer_idx,
     GraphPtr flatbuffer_graph,
     const uint8_t* constant_data_ptr,
     const NamedDataMap* named_data_map,
     std::vector<FreeableBuffer>& freeable_buffers,
     XNNWeightsCache* weights_cache) {
-  auto buffer_idx = tensor_value->constant_buffer_idx();
   if (buffer_idx) {
     if (!constant_data_ptr) {
       // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr(
   return nullptr;
 }
 
+const uint8_t* getConstantDataPtr(
+    const fb_xnnpack::XNNTensorValue* tensor_value,
+    GraphPtr flatbuffer_graph,
+    const uint8_t* constant_data_ptr,
+    const NamedDataMap* named_data_map,
+    std::vector<FreeableBuffer>& freeable_buffers,
+    XNNWeightsCache* weights_cache) {
+  return getConstantDataPtr(
+      tensor_value->constant_buffer_idx(),
+      flatbuffer_graph,
+      constant_data_ptr,
+      named_data_map,
+      freeable_buffers,
+      weights_cache);
+}
+
 /**
 Define serialized tensor value into
 the subgraph. While also keeping track of the remapped ids from
@@ -434,22 +449,15 @@ Error defineTensor(
           const float* scale = qparams->scale()->data();
 
           if (qparams->scale_buffer_idx() != 0) {
-            // if scales are stored in named data, then retrieve it
-            ConstantDataOffsetPtr scale_buffer_offset =
-                flatbuffer_graph->constant_data()->Get(
-                    qparams->scale_buffer_idx());
-            const std::string& data_name =
-                scale_buffer_offset->named_key()->str();
-            Result<FreeableBuffer> scale_buffer =
-                named_data_map->get_data(data_name.c_str());
+            scale = reinterpret_cast<const float*>(getConstantDataPtr(
+                qparams->scale_buffer_idx(),
+                flatbuffer_graph,
+                constant_data_ptr,
+                named_data_map,
+                freeable_buffers,
+                weights_cache));
             ET_CHECK_OR_RETURN_ERROR(
-                scale_buffer.ok(),
-                Internal,
-                "Failed to get constant data for key %s from named_data_map. Error code: %u",
-                data_name.c_str(),
-                static_cast<uint32_t>(scale_buffer.error()));
-            scale = reinterpret_cast<const float*>(scale_buffer.get().data());
-            freeable_buffers.push_back(std::move(scale_buffer.get()));
+                scale != nullptr, Internal, "Failed to load scale data.");
           }
           status = xnn_define_channelwise_quantized_tensor_value_v2(
               /*subgraph=*/subgraph_ptr,
@@ -483,22 +491,15 @@ Error defineTensor(
           // Block scales are preferably serialized as bf16 but can also be
           // serialized as fp32 for backwards compatability.
           if (qparams->scale_buffer_idx() != 0) {
-            ConstantDataOffsetPtr scale_buffer_offset =
-                flatbuffer_graph->constant_data()->Get(
-                    qparams->scale_buffer_idx());
-            const std::string& data_name =
-                scale_buffer_offset->named_key()->str();
-            Result<FreeableBuffer> scale_buffer =
-                named_data_map->get_data(data_name.c_str());
+            scale_data = reinterpret_cast<const uint8_t*>(getConstantDataPtr(
+                qparams->scale_buffer_idx(),
+                flatbuffer_graph,
+                constant_data_ptr,
+                named_data_map,
+                freeable_buffers,
+                weights_cache));
             ET_CHECK_OR_RETURN_ERROR(
-                scale_buffer.ok(),
-                Internal,
-                "Failed to get constant data for key %s from named_data_map. Error code: %u",
-                data_name.c_str(),
-                static_cast<uint32_t>(scale_buffer.error()));
-            scale_data =
-                reinterpret_cast<const uint8_t*>(scale_buffer.get().data());
-            freeable_buffers.push_back(std::move(scale_buffer.get()));
+                scale_data != nullptr, Internal, "Failed to load scale data.");
             scale_numel = qparams->num_scales();
           } else {
            // Read fp32 scales, convert to bf16.