Skip to content

Commit

Permalink
Ruff fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
jainapurva committed Feb 13, 2025
1 parent b078195 commit e880d25
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
Empty file removed: benchmarks/__init__.py
Empty file.
Empty file.
27 changes: 12 additions & 15 deletions benchmarks/microbenchmarks/bench_inference_kernels.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import torch
import argparse
import re
from copy import deepcopy
from typing import Callable, List, Optional
from benchmarks.microbenchmarks.utils import (

import torch

from utils import (
ToyLinearModel,
get_default_device,
)
from torchao import quantization
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
int8_weight_only,
float8_weight_only,
quantize_,
)
import re


def parse_quantization_arg(quantization_input: List[str]):
Expand Down Expand Up @@ -61,9 +62,9 @@ def main(
base_model = ToyLinearModel().eval().to(device)

# Use quantize_ to apply each quantization function to the model
print(f"Running benchmark for {quant_func} {kwargs} quantization")
m_copy = deepcopy(base_model).to(device)
quantize_(m_copy, quant_func(**kwargs))
print(f"Running benchmark for {quant_func} {quant_kwargs} quantization")
m_copy = deepcopy(base_model).eval().to(device)
quantize_(m_copy, quant_func(**quant_kwargs))
print(f"Quantized model: {m_copy}")

if compile:
Expand All @@ -75,20 +76,16 @@ def main(


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run microbenchmarks"
)
parser = argparse.ArgumentParser(description="Run microbenchmarks")

parser.add_argument(
"-q",
"--quantization",
type=str,
nargs='+',
nargs="+",
help=(
"Pass all the quantization techniques for benchmarking: "
+ "int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
+ "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
+ "int8wo, int4wo-<groupsize>, int4wo-<groupsize>-hqq, float8wo"
),
)

Expand Down
12 changes: 8 additions & 4 deletions benchmarks/microbenchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ def __init__(self, m=64, n=32, k=64, dtype=torch.bfloat16):
self.linear1 = torch.nn.Linear(k, n, bias=False).to(dtype)

def example_inputs(self, m=1, device="cuda"):
return (torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),)
return (
torch.randn(m, self.linear1.in_features, dtype=self.dtype, device=device),
)

def forward(self, x):
x = self.linear1(x)
Expand All @@ -17,7 +19,9 @@ def forward(self, x):

def get_default_device() -> str:
return (
"cuda" if torch.cuda.is_available() else
"xpu" if torch.xpu.is_available() else
"cpu"
"cuda"
if torch.cuda.is_available()
else "xpu"
if torch.xpu.is_available()
else "cpu"
)

0 comments on commit e880d25

Please sign in to comment.