From 9a32964be99131002a83e494299fc44a3d39fdff Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 16 Dec 2024 09:48:04 -0700 Subject: [PATCH] feat(compression): allocate resource variables in persistent buffer (#3013) Allocate resource variables in a persistent buffer when the input tensor is compressed. Extend tests to validate operation. BUG=part of #2636 --- .../lite/micro/kernels/assign_variable.cc | 51 ++++++++++++++++++- .../lite/micro/micro_resource_variable.cc | 11 ++-- .../lite/micro/micro_resource_variable.h | 6 +-- .../micro/micro_resource_variable_test.cc | 7 ++- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/micro/kernels/assign_variable.cc b/tensorflow/lite/micro/kernels/assign_variable.cc index bd99bd1aa0c..9374279e9af 100644 --- a/tensorflow/lite/micro/kernels/assign_variable.cc +++ b/tensorflow/lite/micro/kernels/assign_variable.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/lite/micro/micro_graph.h" #include "tensorflow/lite/micro/micro_log.h" #include "tensorflow/lite/micro/micro_resource_variable.h" +#include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -35,6 +36,20 @@ namespace { constexpr int kInputVariableId = 0; constexpr int kInputValue = 1; +#ifdef USE_TFLM_COMPRESSION + +struct OpData { + // scratch buffer for compressed input tensor + int scratch_index; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr); + return context->AllocatePersistentBuffer(context, sizeof(OpData)); +} + +#endif // USE_TFLM_COMPRESSION + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0); @@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, input_value)); } +#ifdef USE_TFLM_COMPRESSION + + TFLITE_DCHECK(node->user_data != nullptr); + OpData* data = static_cast(node->user_data); + // Compression scratch buffers. + // These will only be allocated if the tensor is compressed. + data->scratch_index = + micro_context->AllocateDecompressionScratchBuffer(node, kInputValue); + +#endif // USE_TFLM_COMPRESSION + micro_context->DeallocateTempTfLiteTensor(input_value); return kTfLiteOk; } @@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { "ResourceVariables and pass it to the interpreter."); return kTfLiteError; } + +#ifdef USE_TFLM_COMPRESSION + OpData* data = static_cast(node->user_data); + const CompressionTensorData* comp_td = + micro_context->GetTensorCompressionData(node, kInputValue); + const void* buffer = tflite::micro::GetTensorData( + micro_context, input_value, comp_td, data->scratch_index); +#else // USE_TFLM_COMPRESSION + const void* buffer = tflite::micro::GetTensorData(input_value); +#endif // USE_TFLM_COMPRESSION + TF_LITE_ENSURE_OK(context, - resources->Assign(input_id->data.i32[0], input_value)); + resources->Assign(input_id->data.i32[0], + EvalTensorBytes(input_value), buffer)); return kTfLiteOk; } } // namespace. +#ifdef USE_TFLM_COMPRESSION + +TFLMRegistration Register_ASSIGN_VARIABLE() { + return tflite::micro::RegisterOp(Init, Prepare, Eval); + +#else // USE_TFLM_COMPRESSION + TFLMRegistration Register_ASSIGN_VARIABLE() { return tflite::micro::RegisterOp(nullptr, Prepare, Eval); + +#endif // USE_TFLM_COMPRESSION } } // namespace tflite diff --git a/tensorflow/lite/micro/micro_resource_variable.cc b/tensorflow/lite/micro/micro_resource_variable.cc index 767e7d17d6f..843aac664bc 100644 --- a/tensorflow/lite/micro/micro_resource_variable.cc +++ b/tensorflow/lite/micro/micro_resource_variable.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -113,8 +113,8 @@ TfLiteStatus MicroResourceVariables::Allocate(int id, TfLiteContext* context, return kTfLiteOk; } -TfLiteStatus MicroResourceVariables::Assign(int id, - const TfLiteEvalTensor* tensor) { +TfLiteStatus MicroResourceVariables::Assign(int id, size_t count_bytes, + const void* input_buffer) { if (id < 0 || id >= num_resource_variables_) { MicroPrintf("Attempting to read non-existent resource variable %d", id); return kTfLiteError; @@ -128,8 +128,9 @@ TfLiteStatus MicroResourceVariables::Assign(int id, "with a TfLiteTensor first."); return kTfLiteError; } - TFLITE_DCHECK(EvalTensorBytes(tensor) == variable.bytes); - memcpy(variable.resource_buffer, tensor->data.raw, variable.bytes); + TFLITE_DCHECK(count_bytes == variable.bytes); + TFLITE_DCHECK(input_buffer != nullptr); + memcpy(variable.resource_buffer, input_buffer, variable.bytes); return kTfLiteOk; } diff --git a/tensorflow/lite/micro/micro_resource_variable.h b/tensorflow/lite/micro/micro_resource_variable.h index fb9917d4784..57da6497b3a 100644 --- a/tensorflow/lite/micro/micro_resource_variable.h +++ b/tensorflow/lite/micro/micro_resource_variable.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -46,10 +46,10 @@ class MicroResourceVariables { TfLiteStatus Allocate(int id, TfLiteContext* context, const TfLiteTensor* tensor); - // Copies input tensor contents to the resource buffer. + // Copies input_buffer contents to the resource buffer. // AllocateResourceVariable with a TFLite tensor must have been called first // in order to allocate the resource buffer. - TfLiteStatus Assign(int id, const TfLiteEvalTensor* tensor); + TfLiteStatus Assign(int id, size_t count_bytes, const void* input_buffer); // Zeros out all resource buffers. TfLiteStatus ResetAll(); diff --git a/tensorflow/lite/micro/micro_resource_variable_test.cc b/tensorflow/lite/micro/micro_resource_variable_test.cc index 13868bb440d..a30718cb994 100644 --- a/tensorflow/lite/micro/micro_resource_variable_test.cc +++ b/tensorflow/lite/micro/micro_resource_variable_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/micro/micro_resource_variable.h" #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/micro_utils.h" #include "tensorflow/lite/micro/test_helpers.h" #include "tensorflow/lite/micro/testing/micro_test.h" @@ -120,7 +121,9 @@ TF_LITE_MICRO_TEST(VerifyAssignAndReadResourceBuffer) { .type = kTfLiteFloat32, }; - resource_variables->Assign(id, &assign_tensor); + resource_variables->Assign( + id, tflite::EvalTensorBytes(&assign_tensor), + tflite::micro::GetTensorData(&assign_tensor)); int32_t buffer[32]; TfLiteEvalTensor read_tensor = {