Move torchao/_models to benchmarks/_models (#1784)
jainapurva authored Mar 4, 2025
1 parent 2c2a590 commit 81a2813
Showing 91 changed files with 445 additions and 425 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/dashboard_perf_test.yml
@@ -42,19 +42,19 @@ jobs:
mkdir -p ${{ runner.temp }}/benchmark-results
# llama3 - compile baseline
- ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
+ ${CONDA_RUN} python benchmarks/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
# llama3 - autoquant
- ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
+ ${CONDA_RUN} python benchmarks/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json
# skipping SAM because of https://hud.pytorch.org/pr/pytorch/ao/1407
# # SAM
# ${CONDA_RUN} pip install git+https://github.com/pytorch-labs/segment-anything-fast.git@main
# # SAM compile baselilne
- # ${CONDA_RUN} sh torchao/_models/sam/setup.sh
- # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
+ # ${CONDA_RUN} sh benchmarks/_models/sam/setup.sh
+ # ${CONDA_RUN} python benchmarks/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
- # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
+ # ${CONDA_RUN} python benchmarks/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json
# SAM 2.1
# ${CONDA_RUN} sh scripts/download_sam2_ckpts.sh ${CHECKPOINT_PATH}/sam2
4 changes: 2 additions & 2 deletions README.md
@@ -19,7 +19,7 @@ torchao just works with `torch.compile()` and `FSDP2` over most PyTorch models

### Post Training Quantization

- Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear`, including your favorite HuggingFace model. You can find more comprehensive usage instructions [here](torchao/quantization/), sparsity instructions [here](/torchao/_models/sam/README.md), and a HuggingFace inference example [here](scripts/hf_eval.py)
+ Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear`, including your favorite HuggingFace model. You can find more comprehensive usage instructions [here](torchao/quantization/), sparsity instructions [here](/benchmarks/_models/sam/README.md), and a HuggingFace inference example [here](scripts/hf_eval.py)

For inference, we have the option of
1. Quantize only the weights: works best for memory bound models

@@ -52,7 +52,7 @@ We also provide a developer facing API so you can implement your own quantization

We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference.

- In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md)
+ In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can run Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](benchmarks/_models/llama/README.md)

## Training
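To make the "1 liner" mentioned in the README hunk above concrete, here is a minimal sketch of post-training quantization with the `quantize_` API (assumes a CUDA machine with a recent torchao install; the toy module is a placeholder for any model with `nn.Linear` layers, including a HuggingFace model loaded in bfloat16):

```python
import torch
from torchao.quantization import quantize_, int4_weight_only

# Placeholder model: any torch.nn.Module containing nn.Linear layers works.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).to(device="cuda", dtype=torch.bfloat16)

# The one-liner: swap every nn.Linear to int4 weight-only quantization in place.
quantize_(model, int4_weight_only())

# Optional: compile, as the benchmark scripts touched by this commit do.
model = torch.compile(model, mode="max-autotune")
out = model(torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16))
```

int4 weight-only is the same scheme the ~55% peak-memory claim above builds on; other one-liner configs from `torchao.quantization` follow the same `quantize_(model, config)` pattern.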
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -8,7 +8,7 @@ and follow the steps to gain access.
Then from the torchao root directory use `huggingface-cli login` and follow the steps to log in, then `sh ./scripts/prepare.sh` to
download and convert the model weights

- Once done, you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can perform benchmarking or evaluation
+ Once done, you can execute benchmarks from the benchmarks/_models/llama dir with `sh benchmarks.sh`. You can perform benchmarking or evaluation
directly using `generate.py` or `eval.py`.

## KV Cache Quantization - Memory Efficient Inference
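For orientation, a small sketch of the model-loading step that `generate.py` and `eval.py` share, using the `_load_model` helper this commit relocates to `benchmarks/_models/utils.py`. It assumes the repo root is on `PYTHONPATH` so `benchmarks` is importable; the checkpoint path is a placeholder for whatever `scripts/prepare.sh` downloaded, and the helper infers the architecture from the parent directory name:

```python
from pathlib import Path

import torch
from benchmarks._models.utils import _load_model

# Placeholder path: prepare.sh lays checkpoints out as <base>/<model repo>/model.pth,
# and _load_model picks the Transformer config from checkpoint_path.parent.name.
checkpoint_path = Path("checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth")

model = _load_model(checkpoint_path, device="cuda", precision=torch.bfloat16)
print(sum(p.numel() for p in model.parameters()))
```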
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 12 additions & 9 deletions torchao/_models/llama/eval.py → benchmarks/_models/llama/eval.py
@@ -8,14 +8,13 @@
from typing import List, Optional

import torch
- from generate import (
-     _load_model,
-     device_sync,
- )
from tokenizer import get_tokenizer

import torchao
- from torchao._models.llama.model import prepare_inputs_for_model
+ from benchmarks._models.llama.model import prepare_inputs_for_model
+ from benchmarks._models.utils import (
+     _load_model,
+ )
from torchao.quantization import (
    PerRow,
    PerTensor,
@@ -28,7 +27,11 @@
    quantize_,
    uintx_weight_only,
)
- from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
+ from torchao.utils import (
+     TORCH_VERSION_AT_LEAST_2_5,
+     device_sync,
+     unwrap_tensor_subclass,
+ )


def run_evaluation(
@@ -120,7 +123,7 @@ def run_evaluation(
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
if "int4wo" in quantization and "gptq" in quantization:
# avoid circular imports
- from torchao._models._eval import MultiTensorInputRecorder
+ from benchmarks._models._eval import MultiTensorInputRecorder
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer

groupsize = int(quantization.split("-")[-2])
@@ -172,7 +175,7 @@ def run_evaluation(
if "autoround" in quantization:
from transformers import AutoTokenizer

- from torchao._models.llama.model import TransformerBlock
+ from benchmarks._models.llama.model import TransformerBlock
from torchao.prototype.autoround.autoround_llm import (
    quantize_model_with_autoround_,
)
@@ -242,7 +245,7 @@ def run_evaluation(
with torch.no_grad():
print("Running evaluation ...")
# avoid circular imports
- from torchao._models._eval import TransformerEvalWrapper
+ from benchmarks._models._eval import TransformerEvalWrapper

TransformerEvalWrapper(
model=model.to(device),
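A brief aside on the string parsed in the GPTQ hunk above: eval.py packs the scheme, group size, and method into one flag value, and the split shown there pulls the group size out. A quick sketch with a hypothetical flag value (the exact set of accepted strings is defined elsewhere in eval.py):

```python
# Hypothetical value of the quantization flag; "int4wo" and "gptq" are the
# substrings the hunk above checks for, and the middle field is the group size.
quantization = "int4wo-64-gptq"

groupsize = int(quantization.split("-")[-2])  # -> 64, exactly as in the diff
assert "int4wo" in quantization and "gptq" in quantization
print(groupsize)
```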
File renamed without changes.
torchao/_models/llama/generate.py → benchmarks/_models/llama/generate.py
@@ -7,20 +7,30 @@
import time
from datetime import datetime
from pathlib import Path
- from typing import Optional, Tuple
+ from typing import Optional

import torch
import torch._dynamo.config
import torch._inductor.config

import torchao
- from torchao._models.utils import (
+ from benchmarks._models.utils import (
+     _load_model,
+     decode_n_tokens,
+     decode_one_token,
+     encode_tokens,
    get_arch_name,
+     prefill,
    write_json_result_local,
    write_json_result_ossci,
)
from torchao.quantization.quant_primitives import MappingType
- from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes
+ from torchao.utils import (
+     TORCH_VERSION_AT_LEAST_2_5,
+     default_device,
+     device_sync,
+     get_model_size_in_bytes,
+ )

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
torch.backends.cuda.enable_cudnn_sdp(True)
@@ -49,97 +59,12 @@ def device_timer(device):
        print(f"device={device} is not yet suppported")


- def device_sync(device):
-     if "cuda" in device:
-         torch.cuda.synchronize(device)
-     elif "xpu" in device:
-         torch.xpu.synchronize(device)
-     elif ("cpu" in device) or ("mps" in device):
-         pass
-     else:
-         print(f"device={device} is not yet suppported")
-
-
- default_device = (
-     "cuda"
-     if torch.cuda.is_available()
-     else "xpu"
-     if torch.xpu.is_available()
-     else "cpu"
- )

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

- from torchao._models.llama.model import Transformer, prepare_inputs_for_model
- from torchao._models.llama.tokenizer import get_tokenizer
-
-
- def multinomial_sample_one_no_sync(
-     probs_sort,
- ):  # Does multinomial sampling without a cuda synchronization
-     q = torch.empty_like(probs_sort).exponential_(1)
-     return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
- def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-     logits = logits / max(temperature, 1e-5)
-
-     if top_k is not None:
-         v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-         pivot = v.select(-1, -1).unsqueeze(-1)
-         logits = torch.where(logits < pivot, -float("Inf"), logits)
-     probs = torch.nn.functional.softmax(logits, dim=-1)
-     return probs
-
-
- def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-     probs = logits_to_probs(logits[:, -1], temperature, top_k)
-     idx_next = multinomial_sample_one_no_sync(probs)
-     return idx_next, probs
-
-
- def prefill(
-     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
- ) -> torch.Tensor:
-     # input_pos: [B, S]
-     logits = model(x, input_pos)
-     return sample(logits, **sampling_kwargs)[0]
-
-
- def decode_one_token(
-     model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     # input_pos: [B, 1]
-     assert input_pos.shape[-1] == 1
-     logits = model(x, input_pos)
-     return sample(logits, **sampling_kwargs)
-
-
- def decode_n_tokens(
-     model: Transformer,
-     cur_token: torch.Tensor,
-     input_pos: torch.Tensor,
-     num_new_tokens: int,
-     callback=lambda _: _,
-     **sampling_kwargs,
- ):
-     new_tokens, new_probs = [], []
-     for i in range(num_new_tokens):
-         with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
-             next_token, next_prob = decode_one_token(
-                 model, cur_token, input_pos, **sampling_kwargs
-             )
-             next_token, next_prob = next_token.clone(), next_prob.clone()
-         input_pos += 1
-         # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
-         new_tokens.append(next_token.clone())
-         callback(new_tokens[-1])
-         new_probs.append(next_prob)
-         cur_token = next_token
-
-     return new_tokens, new_probs
+ from benchmarks._models.llama.model import Transformer, prepare_inputs_for_model
+ from benchmarks._models.llama.tokenizer import get_tokenizer


def model_forward(model, x, input_pos):
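The helpers deleted above are not gone; they now live in benchmarks/_models/utils.py (see the new import block at the top of this file's diff). As a rough sketch of how they compose into a generation loop, assuming the model's KV caches are already set up the way generate.py does and that the repo root is importable; `tiny_generate` and its arguments are illustrative, not part of the repo:

```python
import torch

from benchmarks._models.utils import decode_n_tokens, prefill


def tiny_generate(model, prompt: torch.Tensor, max_new_tokens: int, **sampling_kwargs):
    # prompt: 1-D tensor of token ids (torch.int, as produced by encode_tokens);
    # model: a Transformer whose caches have already been set up.
    T = prompt.size(-1)
    device = prompt.device
    input_pos = torch.arange(0, T, device=device)

    # prefill consumes the whole prompt and samples the first new token
    next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()

    # decode_n_tokens then advances one position at a time (input_pos has shape [1])
    new_tokens, _ = decode_n_tokens(
        model,
        next_token.view(1, -1),
        torch.tensor([T], device=device, dtype=torch.int),
        max_new_tokens - 1,
        **sampling_kwargs,
    )
    return torch.cat([prompt, next_token.view(-1), *[t.view(-1) for t in new_tokens]])
```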
@@ -230,25 +155,6 @@ def generate(
    return seq


- def encode_tokens(tokenizer, string, bos=True, device=default_device):
-     tokens = tokenizer.encode(string)
-     if bos:
-         tokens = [tokenizer.bos_id()] + tokens
-     return torch.tensor(tokens, dtype=torch.int, device=device)
-
-
- def _load_model(checkpoint_path, device, precision):
-     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-     if "model" in checkpoint and "stories" in str(checkpoint_path):
-         checkpoint = checkpoint["model"]
-     with torch.device("meta"):
-         model = Transformer.from_name(checkpoint_path.parent.name)
-     model.load_state_dict(checkpoint, assign=True)
-     model = model.to(device=device, dtype=precision)
-
-     return model.eval()


B_INST, E_INST = "[INST]", "[/INST]"

@@ -476,7 +382,7 @@ def ffn_or_attn_only(mod, fqn):
filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding),
)
elif quantization.startswith("awq"):
- from torchao._models._eval import TransformerEvalWrapper
+ from benchmarks._models._eval import TransformerEvalWrapper
from torchao.utils import TORCH_VERSION_AT_LEAST_2_3

if not TORCH_VERSION_AT_LEAST_2_3:
@@ -575,8 +481,8 @@ def ffn_or_attn_only(mod, fqn):
model, float8_dynamic_activation_float8_weight(granularity=granularity)
)
elif "autoquant_v2" in quantization:
- from torchao._models._eval import InputRecorder
- from torchao._models.llama.model import prepare_inputs_for_model
+ from benchmarks._models._eval import InputRecorder
+ from benchmarks._models.llama.model import prepare_inputs_for_model
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2

calibration_seq_length = 256
@@ -665,8 +571,8 @@ def ffn_or_attn_only(mod, fqn):
# do autoquantization
model.finalize_autoquant()
elif "autoquant" in quantization:
- from torchao._models._eval import InputRecorder
- from torchao._models.llama.model import prepare_inputs_for_model
+ from benchmarks._models._eval import InputRecorder
+ from benchmarks._models.llama.model import prepare_inputs_for_model

calibration_seq_length = 256
inputs = (
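The autoquant branches above calibrate with an InputRecorder before finalizing. For readers who just want the core idea, a minimal autoquant sketch without that calibration machinery (assumes a CUDA device and a recent torchao; the toy module and input are placeholders for the Llama model generate.py actually loads):

```python
import torch
import torchao

# Placeholder model standing in for the Transformer loaded by generate.py.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)

# Wrap with autoquant, then run a representative input: the first call benchmarks
# candidate quantized kernels per layer and picks the fastest. generate.py drives
# the same mechanism with recorded calibration inputs and an explicit finalize_autoquant().
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
model(example_input)
```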
File renamed without changes.
@@ -116,8 +116,8 @@
import torch
from torch.nn.attention import SDPBackend

- from torchao._models.llama.model import Transformer
- from torchao._models.llama.tokenizer import get_tokenizer
+ from benchmarks._models.llama.model import Transformer
+ from benchmarks._models.llama.tokenizer import get_tokenizer
from torchao.prototype.profiler import (
    CUDADeviceSpec,
    TransformerPerformanceCounter,
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
File renamed without changes.
File renamed without changes.
@@ -9,7 +9,7 @@
from metrics import calculate_miou, create_result_entry

import torchao
- from torchao._models.utils import (
+ from benchmarks._models.utils import (
    get_arch_name,
    write_json_result_local,
    write_json_result_ossci,
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -8,4 +8,4 @@
from hydra.core.global_hydra import GlobalHydra

if not GlobalHydra.instance().is_initialized():
-     initialize_config_module("torchao._models.sam2", version_base="1.2")
+     initialize_config_module("benchmarks._models.sam2", version_base="1.2")
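The one-line change above matters because Hydra resolves configs relative to a Python module path, so the string has to follow the package move. A hedged sketch of the consumer-side pattern, assuming hydra-core is installed and using a hypothetical config name (build_sam.py supplies the real ones):

```python
from hydra import compose, initialize_config_module

# Point Hydra at the config files packaged under the (moved) module path.
# "sam2_config" is a hypothetical config name used only for illustration.
with initialize_config_module("benchmarks._models.sam2", version_base="1.2"):
    cfg = compose(config_name="sam2_config")
    print(list(cfg.keys()))
```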
@@ -11,9 +11,9 @@
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

- from torchao._models.sam2.modeling.sam2_base import SAM2Base
- from torchao._models.sam2.sam2_image_predictor import SAM2ImagePredictor
- from torchao._models.sam2.utils.amg import (
+ from benchmarks._models.sam2.modeling.sam2_base import SAM2Base
+ from benchmarks._models.sam2.sam2_image_predictor import SAM2ImagePredictor
+ from benchmarks._models.sam2.utils.amg import (
    MaskData,
    _mask_to_rle_pytorch_2_0,
    _mask_to_rle_pytorch_2_1,
@@ -33,7 +33,7 @@
    uncrop_masks,
    uncrop_points,
)
- from torchao._models.sam2.utils.misc import (
+ from benchmarks._models.sam2.utils.misc import (
    crop_image,
    get_image_size,
)
@@ -12,7 +12,7 @@
from hydra.utils import instantiate
from omegaconf import OmegaConf

- from torchao._models import sam2
+ from benchmarks._models import sam2

# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
@@ -106,7 +106,7 @@ def build_sam2_video_predictor(
**kwargs,
):
hydra_overrides = [
- "++model._target_=torchao._models.sam2.sam2_video_predictor.SAM2VideoPredictor",
+ "++model._target_=benchmarks._models.sam2.sam2_video_predictor.SAM2VideoPredictor",
]
if apply_postprocessing:
hydra_overrides_extra = hydra_overrides_extra.copy()