Updates
jainapurva committed Feb 18, 2025
1 parent e880d25 commit ba3390a
Showing 4 changed files with 375 additions and 108 deletions.
185 changes: 84 additions & 101 deletions benchmarks/microbenchmarks/bench_inference_kernels.py
@@ -1,78 +1,61 @@
"""Script to compare multiple quantization techniques for inference, for a particular matrix shape, and model type"""

import argparse
import re
from copy import deepcopy
from typing import Callable, List, Optional
from typing import List

import torch

from utils import (
ToyLinearModel,
benchmark_model_inference_time,
benchmark_model_inference_time_with_profiler,
create_model_and_input,
get_default_device,
quantize_model,
)
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_weight_only,
quantize_,
)


def parse_quantization_arg(quantization_input: List[str]):
# Define regex patterns for quantization techniques
patterns = {
r"^int4wo-(\d+)(-hqq)?$": int4_weight_only,
r"^int8wo$": int8_weight_only,
r"^float8wo$": float8_weight_only,
# Add other patterns and corresponding functions here
}
for quant_technique in quantization_input:
# Iterate over patterns and functions
for pattern, func in patterns.items():
match = re.match(pattern, quant_technique)
if match:
# Extract parameters from the match
groups = match.groups()
if func == int4_weight_only:
kwargs = {
"group_size": int(groups[0]),
"use_hqq": bool(groups[1]),
}
yield func, kwargs
elif func == int8_weight_only:
yield func
elif func == float8_weight_only:
yield func
# TODO: Add other function calls with parameters here

# raise ValueError(f"Unsupported quantization technique: {quant_technique}")
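
# Illustrative sketch, not part of this commit: how the int4wo pattern above
# decomposes a spec string such as "int4wo-128-hqq" into kwargs for
# int4_weight_only (uses the `re` import at the top of the file).
example = re.match(r"^int4wo-(\d+)(-hqq)?$", "int4wo-128-hqq")
if example:
    group_size, hqq_suffix = example.groups()
    kwargs = {"group_size": int(group_size), "use_hqq": bool(hqq_suffix)}
    # kwargs == {"group_size": 128, "use_hqq": True}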


def main(
quant_func: Callable,
quant_kwargs: Optional[dict],
# matrix_sizes,
# m,
# k,
# n,
# precision,
device=get_default_device(),
quantizations: List[str],
m,
k,
n,
precision,
model_type: str = "linear",
compile: bool = False,
device=get_default_device(),
) -> None:
# TODO: Add more model types here
base_model = ToyLinearModel().eval().to(device)

# Use quantize_ to apply each quantization function to the model
print(f"Running benchmark for {quant_func} {quant_kwargs} quantization")
m_copy = deepcopy(base_model).eval().to(device)
quantize_(m_copy, quant_func(**quant_kwargs))
print(f"Quantized model: {m_copy}")

if compile:
print("Compiling model...")
m_copy = torch.compile(m_copy)

# TODO: Run benchmark on the quantized model
# Will add benchmarking code here
base_model, input_data = create_model_and_input(
    model_type, m, k, n,
    dtype=precision,
    device=device,
)
print(f"Starting benchmarking for model: {base_model.__class__.__name__}...")
for quant in quantizations:
# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(device)
m_copy = quantize_model(m_copy, quant)
# quantized_dtype = .....

if compile:
m_copy = torch.compile(m_copy)

# Run benchmarks
# 1. Benchmark time to run an inference call for quantized model
model_time = benchmark_model_inference_time(model=m_copy, input_data=input_data)
print(f"Time to run a {base_model.__class__.__name__}: {model_time * 1e6:.2f} microseconds quantized with {quant}")

# 2. Benchmark time using profiler

# Profile bf16 model evaluation
# prof_bf16 = benchmark_model_inference_time_with_profiler(m_copy, input_data, quantized_dtype)
# prof_bf16.export_chrome_trace(f"bf16_model_{input_data[0].size()[0]}.json") # Save profiling details

# Calculate and store GPU kernel times -> op time, overhead time
# dtype_gpu_op_time, dtype_gpu_overhead_time = get_gpu_kernel_times(prof_dtype, 'gemm')

# 6. Create csv file with all the results
# generate_csv()
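
# Illustrative sketch, not part of this commit: one way to collect the kernel
# trace described by the TODOs above, using torch.profiler directly. The real
# benchmark_model_inference_time_with_profiler and get_gpu_kernel_times helpers
# are not shown in this diff; m_copy and input_data are the loop variables above.
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        m_copy(input_data)
prof.export_chrome_trace("quantized_model_trace.json")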


if __name__ == "__main__":
@@ -89,54 +72,54 @@ def main(
),
)

# parser.add_argument(
# "--matrix_sizes",
# type=str,
# nargs='+',
# help=(
# "Pass all the matrix sizes for benchmarking."
# ),
# )

# parser.add_argument(
# "-m",
# type=int,
# help="M dimension of the matrix",
# )

# parser.add_argument(
# "-k",
# type=int,
# help="M dimension of the matrix",
# )

# parser.add_argument(
# "-n",
# type=int,
# help="M dimension of the matrix",
# )

# parser.add_argument(
# "--precision",
# type=str,
# choices=["float32", "float16", "bfloat16"],
# )
parser.add_argument(
"-m",
type=int,
help="M dimension of the matrix",
)

parser.add_argument(
"-k",
type=int,
help="M dimension of the matrix",
)

parser.add_argument(
"-n",
type=int,
help="M dimension of the matrix",
)

parser.add_argument(
"--precision",
type=lambda x: getattr(torch, x.split(".")[-1]),
default=torch.bfloat16,
help="dtype precision to use",
)
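# The lambda above lets --precision accept either "bfloat16" or "torch.bfloat16":
# it strips an optional "torch." prefix and resolves the name as an attribute on
# the torch module (e.g. torch.bfloat16).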

parser.add_argument(
"--compile",
action="store_true",
help="Whether to compile the model",
)

parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to run the model on",
)

args = parser.parse_args()
print(args)

# Process arguments
quantization_funcs = list(parse_quantization_arg(args.quantization))

# Run benchmarks
for func, kwargs in quantization_funcs:
main(
quant_func=func,
quant_kwargs=kwargs,
)
main(
quantizations=args.quantization,
m=args.m,
k=args.k,
n=args.n,
precision=args.precision,
compile=args.compile,
device=args.device,
)
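
# Illustrative usage sketch, not part of this commit: calling the new main()
# directly rather than through the CLI. Assumes bench_inference_kernels is
# importable from the working directory and that quantize_model (in utils, not
# shown here) accepts spec strings such as "int8wo" and "int4wo-128"; all values
# are arbitrary examples.
import torch
from bench_inference_kernels import main

main(
    quantizations=["int8wo", "int4wo-128"],
    m=1024,
    k=1024,
    n=1024,
    precision=torch.bfloat16,
    model_type="linear",
    compile=False,
    device="cuda",
)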
5 changes: 5 additions & 0 deletions benchmarks/microbenchmarks/test/test_inference.py
@@ -0,0 +1,5 @@
from utils import (
get_default_device,
)

print(get_default_device())
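
# Hypothetical sketch of get_default_device, which lives in utils.py and is not
# shown in this diff; it assumes the common cuda -> mps -> cpu fallback pattern.
import torch

def get_default_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"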
Empty file.
