Skip to content

Commit

Permalink
Ruff fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva committed Feb 12, 2025
1 parent b078195 commit cd33e0a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
21 changes: 10 additions & 11 deletions benchmarks/microbenchmarks/bench_inference_kernels.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import torch
import argparse
import re
from copy import deepcopy
from typing import Callable, List, Optional

import torch

from benchmarks.microbenchmarks.utils import (
ToyLinearModel,
get_default_device,
)
from torchao import quantization
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_weight_only,
float8_weight_only,
quantize_,
)
import re


def parse_quantization_arg(quantization_input: List[str]):
Expand Down Expand Up @@ -61,9 +62,9 @@ def main(
base_model = ToyLinearModel().eval().to(device)

# Use quantize_ to apply each quantization function to the model
print(f"Running benchmark for {quant_func} {kwargs} quantization")
m_copy = deepcopy(base_model).to(device)
quantize_(m_copy, quant_func(**kwargs))
print(f"Running benchmark for {quant_func} {quant_kwargs} quantization")
m_copy = deepcopy(base_model).eval().to(device)
quantize_(m_copy, quant_func(**quant_kwargs))
print(f"Quantized model: {m_copy}")

if compile:
Expand All @@ -75,15 +76,13 @@ def main(


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run microbenchmarks"
)
parser = argparse.ArgumentParser(description="Run microbenchmarks")

parser.add_argument(
"-q",
"--quantization",
type=str,
nargs='+',
nargs="+",
help=(
"Pass all the quantization techniques for benchmarking: "
+ "int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
Expand Down
12 changes: 8 additions & 4 deletions benchmarks/microbenchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ def __init__(self, m=64, n=32, k=64, dtype=torch.bfloat16):
self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)

def example_inputs(self, m=1, device="cuda"):
return (torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),)
return (
torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),
)

def forward(self, x):
x = self.linear1(x)
Expand All @@ -17,7 +19,9 @@ def forward(self, x):

def get_default_device() -> str:
return (
"cuda" if torch.cuda.is_available() else
"xpu" if torch.xpu.is_available() else
"cpu"
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "cpu"
)

0 comments on commit cd33e0a

Please sign in to comment.