Raspberry Pi Pico 2 demo: trying to use int8 quantization instead, and memory utilization doubled from 120KB to 220KB. Why is the allocator's memory consumption increasing?
#15339
Hi guys
I am trying to follow along with the Raspberry Pi Pico 2 demo from last week (https://github.com/pytorch/executorch/tree/main/examples/raspberry_pi/pico2), but with the modification that I want to quantize the model to int8.
First, below you can see that I build libquantized_ops_aot_lib.so using the steps provided in the links in the README (https://github.com/pytorch/executorch/blob/main/examples/raspberry_pi/pico2/README.md) for quantization.
Then I create a new script, export_mlp_mnist_int8.py, that exports a file called balanced_tiny_mlp_mnist_quantized.pte, which I then compile with build_firmware_pico.sh. This is the quantized model compiled into a Pico 2 .uf2 file.
Below you will also see that I can build the original demo and run it on my Pico 2. However, to prove my point, I changed the allocator size to 120KB instead of the 200KB shown in the main.cpp on GitHub (https://github.com/pytorch/executorch/blob/main/examples/raspberry_pi/pico2/main.cpp#L358).
I run the demo with the unquantized model balanced_tiny_mlp_mnist.pte and everything works, proving that a 120KB allocator is enough.
Then I try to compile balanced_tiny_mlp_mnist_quantized.pte with the same 120KB allocator size, and the model does not load because of an insufficient-memory error: 33 in decimal, 0x21 in hex.
/// Could not allocate the requested memory.
MemoryAllocationFailed = 0x21,
Then I increase the allocator size to 280KB, recompile main.cpp, and now it works.
Why did the memory allocator utilization double? I thought that with quantization both the ROM size and the RAM size would shrink.
Below are all my steps. Thanks for helping me understand why the memory usage increased, and whether I am doing something wrong in the quantization steps below.
cd pico/executorch
python3 -m venv .pico3 && source .pico3/bin/activate
./install_executorch.sh
examples/arm/setup.sh --i-agree-to-the-contained-eula
source examples/arm/ethos-u-scratch/setup_path.sh
export PATH="~/pico/executorch/examples/arm/ethos-u-scratch/arm-gnu-toolchain-13.3.rel1-x86_64-arm-none-eabi/bin/:$PATH"
TO BUILD the .so library needed to obtain the quantized model, I run: examples/xnnpack/quantization/test_quantize.sh cmake add
This produces cmake-out/kernels/quantized/libquantized_ops_aot_lib.so
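Before moving on, a quick sanity check I like to run from the repo root (just a sketch of my own, not part of the demo) to confirm the freshly built library actually loads:
import torch
# Try loading the AOT library produced by the step above; torch raises an
# error if the path or the build is wrong, so this fails fast before the
# longer export step.
torch.ops.load_library("cmake-out/kernels/quantized/libquantized_ops_aot_lib.so")
print("quantized AOT ops library loaded")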
Get the .pte file for the unquantized model with: python examples/raspberry_pi/pico2/export_mlp_mnist.py
Then get the .pte file for the quantized model.
THE MODEL IS QUANTIZED with this new script I created, export_mlp_mnist_int8.py.
It uses torch.ops.load_library("cmake-out/kernels/quantized/libquantized_ops_aot_lib.so")
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from executorch.exir import EdgeCompileConfig, to_edge
import logging
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export
from enum import Enum
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from executorch.exir.capture._config import ExecutorchBackendConfig
class QuantType(Enum):
NONE = 1
# Used for Operations that don't have weights
STATIC_PER_TENSOR = 2
# Used best for CNN/RNN Models with Conv layers
STATIC_PER_CHANNEL = 3
# Used for Linear Layers and Transformer Based Models
DYNAMIC_PER_CHANNEL = 4
# Constants
INPUT_SIZE = 784 # 28*28 for MNIST
HIDDEN1_SIZE = 32
HIDDEN2_SIZE = 16
OUTPUT_SIZE = 10
IMAGE_SIZE = 28
class TinyMLPMNIST(nn.Module):
"""A small MLP for MNIST digit classification."""
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(INPUT_SIZE, HIDDEN1_SIZE)
self.fc2 = nn.Linear(HIDDEN1_SIZE, HIDDEN2_SIZE)
self.fc3 = nn.Linear(HIDDEN2_SIZE, OUTPUT_SIZE)
def forward(self, x):
"""Forward pass through the network."""
x = x.reshape(x.size(0), -1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def create_balanced_model():
"""
Create a balanced MLP model for MNIST digit classification.
The model is pre-initialized with specific weights to recognize
digits 0, 1, 4, and 7 through hand-crafted feature detectors.
Returns:
torch.nn.Module: A TinyMLPMNIST model with balanced weights
"""
model = TinyMLPMNIST()
with torch.no_grad():
# Zero everything first
for param in model.parameters():
param.fill_(0.0)
# Feature 0: Vertical lines (for 1, 4, 7)
for row in range(IMAGE_SIZE):
# Middle column
model.fc1.weight[0, row * IMAGE_SIZE + 14] = 2.0
# Feature 1: Top horizontal (for 7, 4)
model.fc1.weight[1, 0:84] = 2.0 # Top 3 rows
# Feature 2: Bottom horizontal (for 1, 4)
model.fc1.weight[2, 25 * IMAGE_SIZE :] = 2.0 # Bottom 3 rows
# Feature 3: STRONGER Oval detector for 0
# Top and bottom curves
model.fc1.weight[3, 1 * IMAGE_SIZE + 8 : 1 * IMAGE_SIZE + 20] = 2.0
model.fc1.weight[3, 26 * IMAGE_SIZE + 8 : 26 * IMAGE_SIZE + 20] = 2.0
# Left and right sides
for row in range(4, 24):
model.fc1.weight[3, row * IMAGE_SIZE + 7] = 2.0 # Left
model.fc1.weight[3, row * IMAGE_SIZE + 20] = 2.0 # Right
# Anti-middle (hollow center)
for row in range(10, 18):
model.fc1.weight[3, row * IMAGE_SIZE + 14] = -1.5
# Feature 4: Middle horizontal (for 4) - make it STRONGER
model.fc1.weight[4, 13 * IMAGE_SIZE : 15 * IMAGE_SIZE] = 3.0
# Second layer: More decisive detection
# Digit 0 detector: STRONG oval requirement
model.fc2.weight[0, 3] = 5.0 # Strong oval requirement
model.fc2.weight[0, 0] = -2.0 # Anti-vertical
model.fc2.weight[0, 4] = -3.0 # Anti-middle horizontal
# Digit 1 detector: vertical + bottom - others
model.fc2.weight[1, 0] = 3.0 # Vertical
model.fc2.weight[1, 2] = 2.0 # Bottom
model.fc2.weight[1, 1] = -1.0 # Anti-top
model.fc2.weight[1, 3] = -2.0 # Anti-oval
# Digit 4 detector: REQUIRES middle horizontal
model.fc2.weight[2, 0] = 2.0 # Vertical
model.fc2.weight[2, 1] = 1.0 # Top
model.fc2.weight[2, 4] = 4.0 # STRONG middle requirement
model.fc2.weight[2, 3] = -2.0 # Anti-oval
# Digit 7 detector: top + some vertical - bottom
model.fc2.weight[3, 1] = 3.0 # Top
model.fc2.weight[3, 0] = 1.0 # Some vertical
model.fc2.weight[3, 2] = -2.0 # Anti-bottom
# Output layer
model.fc3.weight[0, 0] = 5.0 # Digit 0
model.fc3.weight[1, 1] = 5.0 # Digit 1
model.fc3.weight[4, 2] = 5.0 # Digit 4
model.fc3.weight[7, 3] = 5.0 # Digit 7
# Bias against other digits
for digit in [2, 3, 5, 6, 8, 9]:
model.fc3.bias[digit] = -3.0
return model
def test_comprehensive(model):
"""
Test model with clear digit patterns.
Args:
model: The PyTorch model to test
"""
# Create clearer test patterns
def create_digit_1():
"""Create a test pattern for digit 1."""
digit = torch.zeros(1, IMAGE_SIZE, IMAGE_SIZE)
# Thick vertical line in middle
digit[0, 2:26, 13:16] = 1.0 # Thick vertical line
# Top part (like handwritten 1)
digit[0, 2:5, 11:14] = 1.0
# Bottom base
digit[0, 24:27, 10:19] = 1.0
return digit
def create_digit_7():
"""Create a test pattern for digit 7."""
digit = torch.zeros(1, IMAGE_SIZE, IMAGE_SIZE)
# Top horizontal line
digit[0, 1:4, 3:26] = 1.0
# Diagonal line
for i in range(23):
row = 4 + i
col = 23 - i
if 0 <= row < IMAGE_SIZE and 0 <= col < IMAGE_SIZE:
digit[0, row, col - 1 : col + 2] = 1.0 # Thick diagonal
return digit
def create_digit_0():
"""Create a test pattern for digit 0."""
digit = torch.zeros(1, IMAGE_SIZE, IMAGE_SIZE)
# Oval shape
for row in range(3, 25):
for col in range(8, 20):
condition1 = ((row - 14) ** 2 / 11**2 + (col - 14) ** 2 / 6**2) <= 1
condition2 = ((row - 14) ** 2 / 8**2 + (col - 14) ** 2 / 3**2) > 1
if condition1 and condition2:
digit[0, row, col] = 1.0
return digit
patterns = {
"Digit 1": create_digit_1(),
"Digit 7": create_digit_7(),
"Digit 0": create_digit_0(),
}
print("🧪 Testing with clear patterns:")
model.eval()
with torch.no_grad():
for name, pattern in patterns.items():
output = model(pattern)
pred = output.argmax().item()
confidence = F.softmax(output, dim=1)[0, pred].item()
print(f" {name} → predicted: {pred} (confidence: {confidence:.3f})")
# Show top 3 predictions
top3 = output.topk(3, dim=1)
predictions = [
(top3.indices[0, i].item(), top3.values[0, i].item()) for i in range(3)
]
print(f" Top 3: {predictions}")
def quantize(
model, example_inputs, quant_type: QuantType = QuantType.STATIC_PER_TENSOR
):
"""This is the official recommended flow for quantization in pytorch 2.0 export"""
logging.info(f"Original model: {model}")
quantizer = XNNPACKQuantizer()
# if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
is_per_channel = (
quant_type == QuantType.STATIC_PER_CHANNEL
or quant_type == QuantType.DYNAMIC_PER_CHANNEL
)
is_dynamic = quant_type == QuantType.DYNAMIC_PER_CHANNEL
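# NOTE: the is_per_channel flag computed above is not actually used below; per-channel is hard-coded to True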
operator_config = get_symmetric_quantization_config(
is_per_channel=True,
is_dynamic=is_dynamic,
)
quantizer.set_global(operator_config)
m = prepare_pt2e(model, quantizer)
# calibration
m(*example_inputs)
m = convert_pt2e(m)
logging.info(f"Quantized model: {m}")
# make sure we can export to flat buffer
return m
def main():
"""Main function to create, test, and export the model."""
print("🔥 Creating balanced MLP MNIST model...")
model = create_balanced_model()
# Test the model
test_comprehensive(model)
# Export
example_inputs = torch.randn(1, IMAGE_SIZE, IMAGE_SIZE)
param_count = sum(p.numel() for p in model.parameters())
print(f"\n📊 Model parameters: {param_count:,}")
print("📦 Exporting...")
# pre-autograd export. eventually this will become torch.export
with torch.no_grad():
model = torch.export.export(model, (example_inputs,), strict=True).module()
print("📦 Quantizing...")
start = time.perf_counter()
quantized_model = quantize(model, (example_inputs,))  # pass inputs as a tuple so m(*example_inputs) calibrates correctly
end = time.perf_counter()
print(f"Quantize time: {end - start}s")
print("📦 Exporting Quantized...")
start = time.perf_counter()
edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
edge_m = export_to_edge(
quantized_model, (example_inputs,), edge_compile_config=edge_compile_config
)
end = time.perf_counter()
print(f"Export time: {end - start}s")
#torch.ops.load_library("cmake-out/kernels/quantized/libquantized_ops_lib.so")
torch.ops.load_library("cmake-out/kernels/quantized/libquantized_ops_aot_lib.so")
start = time.perf_counter()
prog = edge_m.to_executorch(
config=ExecutorchBackendConfig(extract_delegate_segments=False)
)
save_pte_program(prog, "balanced_tiny_mlp_mnist_quantized")
end = time.perf_counter()
print(f"Save time: {end - start}s")
print("finished")
#print("📦 Exporting...")
#with torch.no_grad():
# exported_program = export(model, (example_input,))
#print("⚙️ Converting to ExecuTorch...")
#edge_config = EdgeCompileConfig(_check_ir_validity=False)
#edge_manager = to_edge(exported_program, compile_config=edge_config)
#et_program = edge_manager.to_executorch()
# Save with error handling
#filename = "balanced_tiny_mlp_mnist.pte"
#print(f"💾 Saving {filename}...")
#try:
# with open(filename, "wb") as f:
# f.write(et_program.buffer)
# model_size_kb = len(et_program.buffer) / 1024
# print("✅ Export complete!")
# print(f"📁 Model size: {model_size_kb:.1f} KB")
#except IOError as e:
# print(f"❌ Failed to save model: {e}")
if __name__ == "__main__":
main()
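(Side note: to see which operators actually ended up in the quantized graph, and therefore what the Pico runtime has to execute, I use a small helper of my own, sketched below. It assumes it is called on the module returned by quantize() in the script above; nothing here is part of the demo.)
from collections import Counter

def summarize_ops(gm):
    """Count the call_function targets in a torch.fx GraphModule."""
    counts = Counter()
    for node in gm.graph.nodes:
        if node.op == "call_function":
            counts[str(node.target)] += 1
    return counts

# Example usage inside main(), right after quantized_model = quantize(...):
# for op_name, n in sorted(summarize_ops(quantized_model).items()):
#     print(f"{op_name}: {n}")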
NOW we have a quantized model, balanced_tiny_mlp_mnist_quantized.pte. Look at the size difference:
-rw-r--r-- 1 106216 Oct 19 10:33 balanced_tiny_mlp_mnist.pte
-rw-r--r-- 1 33104 Oct 21 20:09 balanced_tiny_mlp_mnist_quantized.pte
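That shrink matches a rough back-of-the-envelope estimate (my own arithmetic, assuming the weights dominate the .pte size):
params = 784 * 32 + 32 + 32 * 16 + 16 + 16 * 10 + 10  # 25,818 parameters in the MLP above
print(params * 4)  # ~103 KB as float32 weights -> close to the 106,216-byte .pte
print(params * 1)  # ~26 KB as int8 weights -> close to the 33,104-byte .pte (plus scales/metadata)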
COMPILE THE DEMO with balanced_tiny_mlp_mnist.pte and change the memory allocator in the main.cpp to 120KB
static uint8_t method_allocator_pool[120 * 1024]; // it used to be 200KB - plenty for method metadata
static uint8_t activation_pool[120 * 1024]; // it used to be 200KB - plenty for activations
NOW WE BUILD THE RASPBERRY PI PICO FIRMWARE from the demo, but with 120KB instead of 200KB. Notice that it works.
NOW FLASH
NOW build the quantized version with the same 120KB allocator size.
In build_firmware_pico.sh we add:
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT=ON \
Then we add the quantized libs in CMakeLists.txt:
-Wl,--whole-archive
${BAREMETAL_BUILD_DIR}/lib/libportable_ops_lib.a
${BAREMETAL_BUILD_DIR}/lib/libquantized_ops_lib.a
-Wl,--no-whole-archive
${BAREMETAL_BUILD_DIR}/lib/libportable_kernels.a
${BAREMETAL_BUILD_DIR}/lib/libquantized_kernels.a
COMPILE
examples/raspberry_pi/pico2/build_firmware_pico.sh --model=balanced_tiny_mlp_mnist_quantized.pte
[SERIAL/DIRECT] CONNECTED TO PORT COM15 (115200-8N1)
Loading model data (33104 bytes)...
✅ Program loaded successfully
📊 Program info:
Method count: 1
Method 0 name: forward
🔄 Loading method 'forward'...
offset_bytes (0) + size_bytes (3136) >= allocator size (122880) for memory_id 0
offset_bytes (100352) + size_bytes (100352) >= allocator size (122880) for memory_id 0
❌ Failed to load method: error 33
→ Not enough memory to load method
Failed to load and prepare model
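Just from the two rejected lines above, the memory plan already needs more than the 120KB arena (my own arithmetic from that log):
# The planner wants a 100,352-byte buffer at offset 100,352 in memory_id 0,
# so the activation arena must hold at least:
print(100352 + 100352)  # 200,704 bytes (~196 KB), more than 122,880 (120 KB)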
CHANGE the main.cpp allocator size to 280KB, and now it WORKS. BUT why did the memory needs go up so much from the unquantized model to the quantized one?
RECOMPILE with the quantized model and now it runs:
static uint8_t method_allocator_pool[120 * 1024];
static uint8_t activation_pool[280 * 1024];
Loading model data (33104 bytes)...
✅ Program loaded successfully
Program info:
Method count: 1
Method 0 name: forward
Loading method 'forward'...
offset_bytes (0) + size_bytes (3136) >= allocator size (286720) for memory_id 0
offset_bytes (100352) + size_bytes (100352) >= allocator size (286720) for memory_id 0
offset_bytes (203840) + size_bytes (2048) >= allocator size (286720) for memory_id 0
offset_bytes (205888) + size_bytes (640) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (3136) >= allocator size (286720) for memory_id 0
offset_bytes (3136) + size_bytes (784) >= allocator size (286720) for memory_id 0
offset_bytes (200704) + size_bytes (3136) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (100352) >= allocator size (286720) for memory_id 0
offset_bytes (100352) + size_bytes (128) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (128) >= allocator size (286720) for memory_id 0
offset_bytes (200704) + size_bytes (32) >= allocator size (286720) for memory_id 0
offset_bytes (100352) + size_bytes (128) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (2048) >= allocator size (286720) for memory_id 0
offset_bytes (200704) + size_bytes (64) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (64) >= allocator size (286720) for memory_id 0
offset_bytes (200704) + size_bytes (16) >= allocator size (286720) for memory_id 0
offset_bytes (100352) + size_bytes (64) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (640) >= allocator size (286720) for memory_id 0
offset_bytes (200704) + size_bytes (40) >= allocator size (286720) for memory_id 0
offset_bytes (100352) + size_bytes (10) >= allocator size (286720) for memory_id 0
offset_bytes (0) + size_bytes (40) >= allocator size (286720) for memory_id 0
✅ Method 'forward' loaded successfully
ExecuTorch MLP MNIST Demo (Neural network) on Pico2 (microcontroller)
Testing all supported digits:
Input stats: 159 white pixels out of 784 total
Running neural network inference...
Execute Start time
Execute End
✅ Neural network results:
Digit 0: 367.468
Digit 1: 0.000
Digit 2: -2.763
Digit 3: -2.763
Digit 4: 436.540 ← PREDICTED
Digit 5: -2.763
Digit 6: -2.763
Digit 7: 375.756
Digit 8: -2.763
Digit 9: -2.763
🎯 PREDICTED: 4 (Expected: 7) ❌ WRONG!
==================================================
🎉 All tests complete! ExecuTorch inference of neural network works on Pico2!
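To put a number on my question, I took the offset_bytes/size_bytes pairs the runtime printed above for memory_id 0 (deduplicated) and computed the peak arena requirement; the last comment is only a guess on my part:
# Quick check over the (offset, size) pairs logged by the runtime above.
pairs = [
    (0, 3136), (100352, 100352), (203840, 2048), (205888, 640),
    (3136, 784), (200704, 3136), (0, 100352), (100352, 128),
    (0, 128), (200704, 32), (0, 2048), (200704, 64), (0, 64),
    (200704, 16), (100352, 64), (0, 640), (200704, 40),
    (100352, 10), (0, 40),
]
peak = max(offset + size for offset, size in pairs)
print(peak)  # 206,528 bytes (~202 KB): too big for 120 KB, fits in 280 KB
# Guess (please correct me): 100,352 = 32 * 784 * 4, i.e. fc1's weight as
# float32, so it looks like the int8 weights are being dequantized into the
# activation arena at runtime instead of making it smaller.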