From e38d886935c9e2004f72522bf11573d43f46b383 Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Thu, 4 Jan 2024 17:33:11 -0800 Subject: [PATCH] Fix structural pruning sparsity notebook - Enabled tensor preservation option explicitly when creating TFLite interpreter. - The convolution weight search is now done via the operator lookup. PiperOrigin-RevId: 595845109 --- .../guide/pruning/pruning_with_sparsity_2_by_4.ipynb | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb index c8be15cdf..962d0baad 100644 --- a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_sparsity_2_by_4.ipynb @@ -74,7 +74,7 @@ "id": "FbORZA_bQx1G" }, "source": [ - "Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports. \n", + "Structural pruning weights from your model to make it sparse in specific pattern can accelerate model inference time with appropriate HW supports.\n", "\n", "This tutorial shows you how to:\n", "* Define and train a model on the mnist dataset with a specific structural sparsity\n", @@ -459,7 +459,7 @@ "outputs": [], "source": [ "# Load tflite file with the created pruned model\n", - "interpreter = tf.lite.Interpreter(model_path=tflite_file)\n", + "interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)\n", "interpreter.allocate_tensors()\n", "\n", "details = interpreter.get_tensor_details()\n", @@ -630,9 +630,10 @@ "outputs": [], "source": [ "# Get weights of the convolutional layer that has been pruned with 2 by 4 sparsity.\n", - "tensor_name = 'structural_pruning/Conv2D'\n", - "detail = [x for x in details if tensor_name in x[\"name\"]]\n", - "tensor_data = interpreter.tensor(detail[1][\"index\"])()\n", + "op_details = interpreter._get_ops_details()\n", + "op_name = 'CONV_2D'\n", + "op_detail = [x for x in op_details if op_name in x[\"op_name\"]]\n", + "tensor_data = interpreter.tensor(op_detail[1][\"inputs\"][1])()\n", "print(f\"Shape of the weight tensor is {tensor_data.shape}\")" ] },