
Commit 48432b9

Updates for yml parser
jainapurva committed Feb 20, 2025
1 parent 6ad94dc commit 48432b9
Showing 7 changed files with 120 additions and 36 deletions.
25 changes: 17 additions & 8 deletions benchmarks/microbenchmarks/bench_inference_quant.py
@@ -24,13 +24,15 @@ def run(
model_type: str = "linear",
compile: bool = False,
device=get_default_device(),
output_dir: str = "benchmarks/microbenchmarks/results/",
) -> None:
# TODO: Add more model types here
clean_caches()
base_model, input_data = create_model_and_input(
model_type, m, k, n,
dtype=precision,
device=device,)
device=device,
)
print(f"Starting benchmarking for model: {base_model.__class__.__name__} for quantization: {quantization}")
# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(device)
@@ -49,16 +51,15 @@ def run(
# 2. Benchmark time using profiler
# Profile dtype model evaluation
# prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
# prof_dtype.export_chrome_trace(f"dtype_model_{input_data[0].size()[0]}.json") # Save profiling details

# Calculate and store GPU kernel times -> op time, overhead time
# dtype_gpu_op_time, dtype_gpu_overhead_time = get_gpu_kernel_times(prof_dtype, 'gemm')
# prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details

# 3. Benchmark gemm time without profiler
# matmul_time (without profiler for a quantized tensor)
# 3. Benchmark gemm time using cuda graph
# gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)

# 6. Create csv file with all the results
# 4. Benchmark op with cuda graph
# time = benchmark_op_with_cuda_graph(op, args)

# Last: Create csv file with all the results
# generate_csv()
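
The plan in steps 3 and 4 above names timing helpers that are still commented out. Below is a minimal sketch of the CUDA-graph variant using only public torch.cuda APIs; the helper name benchmark_op_with_cuda_graph is taken from the comment, but this body is an assumption, not the repo's implementation:

```python
import torch

def benchmark_op_with_cuda_graph(op, args, n_warmup=3, n_iters=100):
    # Warm up on a side stream so capture sees steady-state allocations;
    # all tensors in `args` must be static CUDA tensors for graph capture.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(n_warmup):
            op(*args)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture one invocation, then time replays with CUDA events.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        op(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        graph.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / n_iters  # microseconds per call
```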


@@ -116,6 +117,13 @@ def run(
help="Device to run the model on",
)

parser.add_argument(
"--output_dir",
type=str,
default="benchmarks/microbenchmarks/results/",
help="Output directory to save results",
)

args = parser.parse_args()
print(args)

@@ -128,4 +136,5 @@ def run(
precision=args.precision,
compile=args.compile,
device=args.device,
output_dir=args.output_dir,
)
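
With the new flag wired through, the entry point can be exercised directly as below; this is a hypothetical call, and every parameter other than output_dir is assumed from the surrounding code rather than shown in this hunk:

```python
# Sketch: mirrors the argparse wiring above (quantization/m/k/n assumed).
run(
    quantization="int8wo",
    m=1024, k=1024, n=1024,
    precision=torch.bfloat16,
    model_type="linear",
    compile=False,
    device="cuda",
    output_dir="benchmarks/microbenchmarks/results/",
)
```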
7 changes: 0 additions & 7 deletions benchmarks/microbenchmarks/config.yml

This file was deleted.

30 changes: 22 additions & 8 deletions benchmarks/microbenchmarks/config_parser.py
@@ -3,15 +3,18 @@
from typing import Dict, List, Any, Tuple
from pathlib import Path
from itertools import product
from utils import get_name_to_shapes_iter # Import the shape utility

class BenchmarkConfig:
def __init__(self, quantization: str, params: Dict[str, Any], matrix_shape: List[int]):
def __init__(self, quantization: str, params: Dict[str, Any], shape_name: str, shape: List[int]):
self.quantization = quantization
self.m, self.k, self.n = matrix_shape
self.m, self.k, self.n = shape
self.shape_name = shape_name
self.precision = self._parse_precision(params['precision'])
self.compile = params.get('compile', False)
self.device = params.get('device', 'cuda')
self.name = f'benchmark_{self.quantization}_m{self.m}_k{self.k}_n{self.n}'
self.model_type = params.get('model_type', 'linear')
self.name = f'benchmark_{self.quantization}_{self.shape_name}_m{self.m}_k{self.k}_n{self.n}'

@staticmethod
def _parse_precision(precision_str: str) -> torch.dtype:
@@ -27,22 +30,33 @@ def to_dict(self) -> Dict[str, Any]:
'n': self.n,
'precision': self.precision,
'compile': self.compile,
'device': self.device
'device': self.device,
'model_type': self.model_type,
}

def get_shapes_for_config(shape_config: Dict[str, Any]) -> List[Tuple[str, List[int]]]:
"""Get shapes for a given configuration"""
name = shape_config['name']
if name == "custom":
return [(name, shape) for shape in shape_config['shapes']]
else:
return [(name, shape) for _, shape in get_name_to_shapes_iter(name, None, None, None)]

def load_benchmark_configs(config_path: str) -> List[BenchmarkConfig]:
"""Load benchmark configurations from YAML file"""
with open(config_path, 'r') as f:
config_data = yaml.safe_load(f)

quantizations = config_data['quantizations']
params = config_data['model_params']
matrix_shapes = params['matrix_shapes']

configs = []
# Generate all combinations of quantizations and matrix shapes
for quant, shape in product(quantizations, matrix_shapes):
configs.append(BenchmarkConfig(quant, params, shape))
# Process each shape configuration
for shape_config in params['matrix_shapes']:
shapes = get_shapes_for_config(shape_config)
# Generate combinations for each shape
for quant, (shape_name, shape) in product(quantizations, shapes):
configs.append(BenchmarkConfig(quant, params, shape_name, shape))

return configs
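
To make the new expansion concrete, here is a self-contained round-trip of the schema that inlines the parser's product logic; the YAML string is illustrative, not a repo file:

```python
import yaml
from itertools import product

sample = """
quantizations: ["baseline", "int8wo"]
model_params:
  matrix_shapes:
    - name: "custom"
      shapes: [[1024, 1024, 1024], [2048, 4096, 1024]]
"""
config_data = yaml.safe_load(sample)
params = config_data["model_params"]
for shape_config in params["matrix_shapes"]:
    # Mirrors get_shapes_for_config's "custom" branch.
    shapes = [(shape_config["name"], s) for s in shape_config["shapes"]]
    for quant, (shape_name, (m, k, n)) in product(config_data["quantizations"], shapes):
        print(f"benchmark_{quant}_{shape_name}_m{m}_k{k}_n{n}")
# Prints 4 names, e.g. benchmark_int8wo_custom_m2048_k4096_n1024
```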

21 changes: 14 additions & 7 deletions benchmarks/microbenchmarks/configs/benchmark_config.yml
@@ -1,16 +1,23 @@
# Default configuration for inference kernel benchmarks
# Sample configuration for inference kernel benchmarks

# For multiple quantizations and shapes
quantizations:
- "baseline"
- "int8wo"
- "int4wo-128"
- "int4wo-128-hqq"

model_params:
matrix_shapes: [
[1024, 1024, 1024], # [m, k, n]
[2048, 4096, 1024],
[4096, 4096, 1024]
]
matrix_shapes:
- name: "custom"
shapes: [
[1024, 1024, 1024], # [m, k, n]
[2048, 4096, 1024],
[4096, 4096, 1024]
]
# - name: "llama"
# shapes: [] # Will be populated from utils.get_name_to_shapes_iter
precision: "torch.bfloat16"
compile: false
compile: false
device: "cuda"
model_type: "linear"
5 changes: 0 additions & 5 deletions benchmarks/microbenchmarks/test/test_inference.py

This file was deleted.

Empty file.
68 changes: 67 additions & 1 deletion benchmarks/microbenchmarks/utils.py
@@ -1,5 +1,5 @@
import time
from typing import List
from typing import List, Optional

import torch
from torch.profiler import ProfilerActivity, profile
@@ -100,6 +100,8 @@ def quantize_model(
precision = kwargs.get("precision", None)

# Quantization techniques
if "baseline" in quantization:
return model
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
@@ -397,3 +399,67 @@ def clean_caches():

if compile:
torch._dynamo.reset()

def get_name_to_shapes_iter(
shape_gen_name: str,
M: Optional[int],
K: Optional[int],
N: Optional[int],
):
if shape_gen_name == "llama":
assert (
M is None and K is None and N is None
), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
bsz, seq_len = 4, 4096
M = bsz * seq_len
# LLaMa 2 70B single-node weight shapes
# assumes fused attn.wqkv and ffn.w13
# source: https://fburl.com/gsheet/g8onr7rh
name_to_shapes_70b = {
"attn.wqkv": (M, 8192, 1280),
"attn.w0": (M, 1024, 8192),
"ffn.w13": (M, 8192, 7168),
"ffn.w2": (M, 3584, 8192),
}
return name_to_shapes_70b.items()

elif shape_gen_name == "square":
assert (
M is None and K is None and N is None
), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
name_to_shapes = {}
min_power_of_2 = 8 # 256
max_power_of_2 = 15 # 32,768
for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
val = 2**power_of_2
name_to_shapes[idx] = val, val, val
return name_to_shapes.items()

elif shape_gen_name == "sweep":
assert (
M is None and K is None and N is None
), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}"
name_to_shapes = {}
min_p2 = 8 # 256
max_p2 = 15 # 32,768
counter = 0
for M_p2 in range(min_p2, max_p2 + 1):
M = 2**M_p2
for K_p2 in range(min_p2, max_p2 + 1):
K = 2**K_p2
for N_p2 in range(min_p2, max_p2 + 1):
N = 2**N_p2
name_to_shapes[counter] = M, K, N
counter += 1
return name_to_shapes.items()

elif shape_gen_name == "custom":
assert (
M is not None and K is not None and N is not None
), "M, K, N must be specified for custom shape_gen"
name_to_shapes = {
1: (M, K, N),
}
return name_to_shapes.items()

raise AssertionError(f"unknown shape_gen_name {shape_gen_name}")
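
For reference, the contract looks like this in use (illustrative values; the named generators require M/K/N to be passed as None, which is why the parser fix above forwards three Nones):

```python
# "llama" derives M from bsz * seq_len = 4 * 4096 = 16384.
list(get_name_to_shapes_iter("llama", None, None, None))
# -> [('attn.wqkv', (16384, 8192, 1280)), ('attn.w0', (16384, 1024, 8192)), ...]

list(get_name_to_shapes_iter("custom", 1024, 2048, 4096))
# -> [(1, (1024, 2048, 4096))]
```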
