
Commit 7ff1e42

Quantize vit_b_16 tutorial - Part 1 (pytorch#60)
1 parent 2871d74 commit 7ff1e42

File tree

9 files changed (+4214 −8 lines)

torchao/__init__.py

+4 −1

@@ -1,5 +1,8 @@
 from . import dtypes
+from .quantization.quant_api import apply_dynamic_quant
+from .quantization.quant_api import apply_weight_only_int8_quant

 __all__ = [
-    "dtypes"
+    "dtypes",
+    "apply_dynamic_quant",
 ]
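
With this change both quantization entry points are importable directly from the top-level torchao package. A minimal usage sketch (the toy Sequential model is a stand-in for illustration; the tutorial files below apply it to a CUDA bfloat16 ViT):

import torch
import torchao

# toy model standing in for any network containing nn.Linear layers
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).eval()

# convert linear weights to dynamically quantized int8 tensor subclasses
torchao.apply_dynamic_quant(model)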

torchao/quantization/quant_api.py

+3 −7

@@ -126,14 +126,10 @@ def apply_weight_only_int8_quant(model, filter_fn=None):
 def apply_dynamic_quant(model, filter_fn=None):
     """
     Applies dynamic symmetric per-token activation and per-channel weight
-    quantization to all linear layers in the given model using
-    module swaps.
+    quantization to all linear layers by converting all linear weight
+    tensors to the `Int8DynamicallyQuantizedLinearWeight` Tensor subclass.
     """
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod),
-        _is_linear if filter_fn is None else filter_fn,
-    )
+    change_linear_weights_to_int8_dqtensors(model, filter_fn)


 def _get_subclass_inserter(cls, **kwargs):
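
For intuition about what the docstring describes, here is a self-contained sketch of dynamic symmetric per-token activation and per-channel weight int8 quantization. This is illustrative only, not the torchao implementation; the function names are invented for the example:

import torch

def quantize_per_channel_sym(w):
    # w: (out_features, in_features); one symmetric scale per output channel
    scale = (w.abs().amax(dim=1, keepdim=True) / 127.0).clamp(min=1e-8)
    return torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8), scale

def quantize_per_token_sym(x):
    # x: (tokens, in_features); one symmetric scale per token, computed
    # dynamically from the activation at call time
    scale = (x.abs().amax(dim=-1, keepdim=True) / 127.0).clamp(min=1e-8)
    return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8), scale

def dynamic_quant_linear(x, w):
    x_int8, x_scale = quantize_per_token_sym(x)
    w_int8, w_scale = quantize_per_channel_sym(w)
    # integer matmul with int32 accumulation, then rescale back to float
    acc = torch.matmul(x_int8.to(torch.int32), w_int8.to(torch.int32).t())
    return acc.to(x.dtype) * x_scale * w_scale.t()

# quick check against the float reference
x, w = torch.randn(4, 16), torch.randn(8, 16)
print((dynamic_quant_linear(x, w) - x @ w.t()).abs().max())

The point of the tensor-subclass approach in this commit is that this arithmetic is packaged inside `Int8DynamicallyQuantizedLinearWeight`, so an ordinary linear call on the converted weight takes the quantized path with no module swap required.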
tutorials/quantize_vit/bfloat16.json.gz

15.9 KB
Binary file not shown.

tutorials/quantize_vit/bfloat16_code.py

+1,682
Large diffs are not rendered by default.

tutorials/quantize_vit/quant.json.gz

15.8 KB
Binary file not shown.

tutorials/quantize_vit/quant_code.py

+2,413
Large diffs are not rendered by default.

tutorials/quantize_vit/run.sh

+13 (new file)

#!/bin/bash

# Run bfloat16 version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b.py

# Run dynamic quantized version
TORCH_LOGS='graph_breaks,recompiles' python run_vit_b_quant.py

# Store the output code for further inspection
echo "bfloat16 generated code lives in:"
TORCH_LOGS='output_code' python run_vit_b.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
echo "quantization generated code lives in:"
TORCH_LOGS='output_code' python run_vit_b_quant.py 2>&1 | grep "Output code written to: " | awk -F" " '{print $NF}'
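
The TORCH_LOGS='graph_breaks,recompiles' runs surface torch.compile graph breaks and recompilations during benchmarking, while TORCH_LOGS='output_code' makes the Inductor backend log where it wrote the code it generated; the grep/awk pipeline extracts that path from the "Output code written to:" log line. The checked-in bfloat16_code.py and quant_code.py appear to be snapshots of exactly that generated output.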

tutorials/quantize_vit/run_vit_b.py

+46 (new file)

import torch
import torchvision.models.vision_transformer as models

# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)

# Set the model to evaluation mode
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # benchmark
    for _ in range(num_runs):
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
    # warmup
    benchmark_model(model, 5, input_tensor)
    # benchmark
    print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
    # Create a trace
    profiler_runner("bfloat16.json.gz", benchmark_model, model, 5, input_tensor)
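
Two details of this harness are worth noting: the warmup calls are what trigger torch.compile's max-autotune compilation and kernel autotuning, keeping that one-time cost out of the timed loop, and timing uses CUDA events bracketed by torch.cuda.synchronize() so the reported milliseconds reflect completed GPU work rather than kernel-launch overhead.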
tutorials/quantize_vit/run_vit_b_quant.py

+53 (new file)

import torch
import torchao
import torchvision.models.vision_transformer as models

# Load Vision Transformer model
model = models.vit_b_16(pretrained=True)

# Set the model to evaluation mode
model.eval().cuda().to(torch.bfloat16)

# Input tensor (batch_size, channels, height, width)
input_tensor = torch.randn(1, 3, 224, 224, dtype=torch.bfloat16, device='cuda')

## Quantization code - start
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
inductorconfig.force_fuse_int_mm_with_mul = True
## Quantization code - end

model = torch.compile(model, mode='max-autotune')

def benchmark_model(model, num_runs, input_tensor):
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    # benchmark
    for _ in range(num_runs):
        with torch.autograd.profiler.record_function("timed region"):
            model(input_tensor)

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs

def profiler_runner(path, fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return result

# Must run with no_grad when optimizing for inference
with torch.no_grad():
    # warmup
    benchmark_model(model, 5, input_tensor)
    # benchmark
    print("elapsed_time: ", benchmark_model(model, 100, input_tensor), " milliseconds")
    # Create a trace
    profiler_runner("quant.json.gz", benchmark_model, model, 5, input_tensor)
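
The only changes relative to run_vit_b.py are the three lines between the quantization markers and the trace filename. Setting force_fuse_int_mm_with_mul asks Inductor to fuse the int8 matmul with the scale multiplication that follows it, so the dequantizing epilogue does not become a separate kernel. Roughly the pattern involved, shown un-fused (a sketch with illustrative shapes; torch._int_mm is a private op and requires a CUDA device):

import torch

# int8 GEMM with int32 accumulation, then the rescaling epilogue that
# force_fuse_int_mm_with_mul lets Inductor fold into the GEMM kernel
x_int8 = torch.randint(-127, 128, (196, 768), dtype=torch.int8, device='cuda')
w_int8 = torch.randint(-127, 128, (768, 768), dtype=torch.int8, device='cuda')
scales = torch.rand(196, 1, device='cuda', dtype=torch.bfloat16)

out = torch._int_mm(x_int8, w_int8)        # (196, 768) int32 accumulator
out = out.to(torch.bfloat16) * scales      # epilogue: rescale back to bf16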
