What kind of layers are optimized by torchao on a RTX 4090? #1805

Open
naiveen opened this issue Mar 1, 2025 · 1 comment
Labels
performance question Further information is requested

Comments


naiveen commented Mar 1, 2025

I am trying to quantize a model and I am running it on a 4090. Since most of the available quantization benchmarks were run on higher-end GPUs, I am trying to establish a baseline for the performance gain I can expect from quantization.

I tried the tutorial at torchao_demo on my GPU and it worked great. My model has similar transformer layers with q, k, v projections, but I am not seeing the same kind of performance, and the profile log shows a large chunk of aten::_copy() operations.

To debug, I wanted to benchmark a single linear layer, since the majority of the modified layers seem to be of this type. However, I am not seeing any performance gain in this experiment. I would appreciate more context on which specific layers get optimized by torchao.

# Based on https://github.com/ethanshenley/PyTorch-Conference-Recipes/blob/main/torchao_demo.ipynb
import gc
import psutil
import torch
import torch.nn as nn
import time

from torchao.quantization import quantize_, int8_weight_only, float8_weight_only


device = "cuda:0"
def get_memory_usage():
    # NOTE: psutil reports host (CPU) RSS, not GPU memory; GPU usage would
    # need torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
    return psutil.Process().memory_info().rss / 1024 / 1024  # in MB

def run_inference(model, inputs, num_runs=10):
    torch.cuda.synchronize(device)  # don't let earlier GPU work leak into the timing
    start_time = time.time()
    for i in range(num_runs):
        with torch.no_grad():
            outputs = model(inputs[i].squeeze())
    torch.cuda.synchronize(device)
    end_time = time.time()
    return (end_time - start_time) / num_runs

# Load model and tokenizer
bsz = 16
n_runs = 100
for sz in range(1024, 20480, 1024):
    print('====================================================')
    print(f"Running with linear layer of size {sz}...")
    model = nn.Linear(sz, sz).to(device)
    inputs = torch.randn(n_runs, bsz, sz).to(device)

    print("\nRunning baseline model...")
    baseline_memory = get_memory_usage()
    baseline_time = run_inference(model, inputs, n_runs)
    print(f"Baseline - Time: {baseline_time:.4f}s, Memory: {baseline_memory:.2f}MB")


    print("\nRunning int8 weight-only quantized model...")
    model_int8 = nn.Linear(sz, sz).to(device)
    quantize_(model_int8, int8_weight_only())
    int8_memory = get_memory_usage()
    int8_time = run_inference(model_int8, inputs, n_runs)
    print(f"Int8 Weight-Only - Time: {int8_time:.4f}s, Memory: {int8_memory:.2f}MB")

    print("\nRunning fp8 weight-only quantized model...")
    model_fp8 = nn.Linear(sz, sz).to(device)
    quantize_(model_fp8, float8_weight_only())  
    fp8_memory = get_memory_usage()
    fp8_time = run_inference(model_fp8, inputs, n_runs)
    print(f"fp8 Weight-Only  - Time: {fp8_time:.4f}s, Memory: {fp8_memory:.2f}MB")


    print("\nPerformance Improvements:")
    print(f"Int8 weight-only speedup: {baseline_time / int8_time:.2f}x")
    print(f"Int8 weight-only memory reduction: {baseline_memory / int8_memory:.2f}x")
    print(f"fp8 weight-only speedup: {baseline_time / fp8_time:.2f}x")
    print(f"fp8 weight-only memory reduction: {baseline_memory / fp8_memory:.2f}x")

    del model, model_int8, model_fp8, inputs
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize(device)
supriyar commented Mar 1, 2025

torchao quantizes Linear layers. However, depending on the batch size and layer shape, you may see different levels of performance improvement from the different techniques. E.g., weight-only quantization works best for bs=1, while dynamic quantization is preferred for bs=n scenarios.
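For example, a minimal sketch of the bs=n path using torchao's int8_dynamic_activation_int8_weight config; the bfloat16 dtype, the layer size, and the use of torch.compile here are illustrative assumptions rather than anything measured in this thread:

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

device = "cuda:0"
sz, bsz = 4096, 16  # illustrative shapes, bs=n case

# dynamic activation + weight int8 quantization, the technique suggested above for bs=n
model_int8_dyn = nn.Linear(sz, sz).to(device, dtype=torch.bfloat16)
quantize_(model_int8_dyn, int8_dynamic_activation_int8_weight())

# compiling the model is usually needed to see the quantized-kernel speedups
model_int8_dyn = torch.compile(model_int8_dyn, mode="max-autotune")

x = torch.randn(bsz, sz, dtype=torch.bfloat16, device=device)
with torch.no_grad():
    y = model_int8_dyn(x)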

Most of our benchmarks are run on A100 or H100. But you can try the gemlite kernels (details at https://github.com/pytorch/ao/tree/main/torchao/quantization#gemlite-triton), which are expected to be optimized for the 4090. cc @mobicham
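A hedged sketch of what trying the gemlite path might look like, assuming the gemlite_uintx_weight_only config described in the linked README; the exact name and arguments may differ across torchao versions, the gemlite package must be installed separately, and the sizes/parameters below are only placeholders:

import torch
import torch.nn as nn
# assumes `pip install gemlite` and a recent torchao; the config name below is
# taken from the linked README and may not match every torchao release
from torchao.quantization import quantize_, gemlite_uintx_weight_only

device = "cuda:0"
model = nn.Linear(4096, 4096, dtype=torch.float16, device=device)

# weight-only quantization backed by gemlite's Triton kernels
# (group_size=64 and bit_width=4 are illustrative values, not recommendations)
quantize_(model, gemlite_uintx_weight_only(group_size=64, bit_width=4))

x = torch.randn(16, 4096, dtype=torch.float16, device=device)
with torch.no_grad():
    y = model(x)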

@supriyar added the performance and question labels Mar 1, 2025