@@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant
174174data associated with the tensor value, then returns nullptr.
175175*/
176176const uint8_t * getConstantDataPtr (
177- const fb_xnnpack::XNNTensorValue* tensor_value ,
177+ uint32_t buffer_idx ,
178178 GraphPtr flatbuffer_graph,
179179 const uint8_t * constant_data_ptr,
180180 const NamedDataMap* named_data_map,
181181 std::vector<FreeableBuffer>& freeable_buffers,
182182 XNNWeightsCache* weights_cache) {
183- auto buffer_idx = tensor_value->constant_buffer_idx ();
184183 if (buffer_idx) {
185184 if (!constant_data_ptr) {
186185 // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
@@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr(
230229 return nullptr ;
231230}
232231
232+ const uint8_t * getConstantDataPtr (
233+ const fb_xnnpack::XNNTensorValue* tensor_value,
234+ GraphPtr flatbuffer_graph,
235+ const uint8_t * constant_data_ptr,
236+ const NamedDataMap* named_data_map,
237+ std::vector<FreeableBuffer>& freeable_buffers,
238+ XNNWeightsCache* weights_cache) {
239+ return getConstantDataPtr (
240+ tensor_value->constant_buffer_idx (),
241+ flatbuffer_graph,
242+ constant_data_ptr,
243+ named_data_map,
244+ freeable_buffers,
245+ weights_cache);
246+ }
247+
233248/* *
234249Define serialized tensor value into
235250the subgraph. While also keeping track of the remapped ids from
@@ -434,22 +449,15 @@ Error defineTensor(
434449 const float * scale = qparams->scale ()->data ();
435450
436451 if (qparams->scale_buffer_idx () != 0 ) {
437- // if scales are stored in named data, then retrieve it
438- ConstantDataOffsetPtr scale_buffer_offset =
439- flatbuffer_graph->constant_data ()->Get (
440- qparams->scale_buffer_idx ());
441- const std::string& data_name =
442- scale_buffer_offset->named_key ()->str ();
443- Result<FreeableBuffer> scale_buffer =
444- named_data_map->get_data (data_name.c_str ());
452+ scale = reinterpret_cast <const float *>(getConstantDataPtr (
453+ qparams->scale_buffer_idx (),
454+ flatbuffer_graph,
455+ constant_data_ptr,
456+ named_data_map,
457+ freeable_buffers,
458+ weights_cache));
445459 ET_CHECK_OR_RETURN_ERROR (
446- scale_buffer.ok (),
447- Internal,
448- " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
449- data_name.c_str (),
450- static_cast <uint32_t >(scale_buffer.error ()));
451- scale = reinterpret_cast <const float *>(scale_buffer.get ().data ());
452- freeable_buffers.push_back (std::move (scale_buffer.get ()));
460+ scale != nullptr , Internal, " Failed to load scale data." );
453461 }
454462 status = xnn_define_channelwise_quantized_tensor_value_v2 (
455463 /* subgraph=*/ subgraph_ptr,
@@ -483,22 +491,15 @@ Error defineTensor(
483491 // Block scales are preferably serialized as bf16 but can also be
484492 // serialized as fp32 for backwards compatability.
485493 if (qparams->scale_buffer_idx () != 0 ) {
486- ConstantDataOffsetPtr scale_buffer_offset =
487- flatbuffer_graph-> constant_data ()-> Get (
488- qparams-> scale_buffer_idx ());
489- const std::string& data_name =
490- scale_buffer_offset-> named_key ()-> str ();
491- Result<FreeableBuffer> scale_buffer =
492- named_data_map-> get_data (data_name. c_str ( ));
494+ scale_data = reinterpret_cast < const uint16_t *>( getConstantDataPtr (
495+ qparams-> scale_buffer_idx (),
496+ flatbuffer_graph,
497+ constant_data_ptr,
498+ named_data_map,
499+ freeable_buffers,
500+ weights_cache ));
493501 ET_CHECK_OR_RETURN_ERROR (
494- scale_buffer.ok (),
495- Internal,
496- " Failed to get constant data for key %s from named_data_map. Error code: %u" ,
497- data_name.c_str (),
498- static_cast <uint32_t >(scale_buffer.error ()));
499- scale_data =
500- reinterpret_cast <const uint16_t *>(scale_buffer.get ().data ());
501- freeable_buffers.push_back (std::move (scale_buffer.get ()));
502+ scale_data != nullptr , Internal, " Failed to load scale data." );
502503 scale_numel = qparams->num_scales ();
503504 } else {
504505 // Read fp32 scales, convert to bf16.
0 commit comments