Composing autoquant with compile #175

Merged: 17 commits, merged May 8, 2024

Changes from all commits
1 change: 1 addition & 0 deletions .github/workflows/doc_build.yml

@@ -41,6 +41,7 @@ jobs:
       run: |
         python -m pip install torch
         python -m pip install -e .
+        pip install -r dev-requirements.txt
         cd docs
         python -m pip install -r requirements.txt
     - name: Build docs
9 changes: 3 additions & 6 deletions README.md

@@ -44,12 +44,9 @@ torch._inductor.config.use_mixed_mm = True
 model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
 input = torch.randn(32,32, dtype=torch.bfloat16, device='cuda')
 
-# perform autoquantization
-torchao.autoquant(model, (input))
-
-# compile the model to recover performance
-model = torch.compile(model, mode='max-autotune')
-model(input)
+# perform autoquantization and compilation
+q_model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
+q_model(input)
 ```
 
 ### Sparsity
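For context, a minimal sketch of the new composed flow from the updated README (an illustration, not part of the diff; it assumes a CUDA device, a torchao build containing this change, and that `autoquant` selects per-layer quantization the first time data flows through the wrapped module):

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(32, 64)).cuda().to(torch.bfloat16)
x = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")

# compile first, then hand the compiled module to autoquant;
# the first call exercises and selects quantized kernels, later calls reuse them
q_model = torchao.autoquant(torch.compile(model, mode="max-autotune"))
q_model(x)  # warm-up: kernel selection happens here
q_model(x)  # steady state: quantized + compiled
```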
131 changes: 131 additions & 0 deletions benchmarks/dora/bench_utils.py
@@ -0,0 +1,131 @@
import torch
from bitsandbytes.nn import Linear4bit
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

from prototypes.dora.dora_layer import BNBDoRALinear, HQQDoRALinear
from prototypes.dora.kernels.matmul import triton_mm
from prototypes.dora.kernels.smallk import triton_mm_small_k


def make_lora_weights(ranks, in_features, out_features, dtype):
    As = [torch.randn(rank, in_features, device="cuda", dtype=dtype) for rank in ranks]
    Bs = [torch.randn(out_features, rank, device="cuda", dtype=dtype) for rank in ranks]
    return As, Bs


def make_dora_source_and_magnitude(in_features, out_features, dtype):
    source = torch.randn(out_features, in_features, device="cuda", dtype=dtype)
    magnitude = torch.randn(out_features, device="cuda", dtype=dtype)
    return source, magnitude


def make_inputs(batch_sizes, seqlen, in_features, dtype):
    xs = [
        torch.randn(bs * seqlen, in_features, device="cuda", dtype=dtype)
        for bs in batch_sizes
    ]
    return xs


def make_weights(batch_sizes, in_features, out_features, dtype):
    weights = [
        torch.randn(in_features, out_features, device="cuda", dtype=dtype)
        for _ in range(len(batch_sizes))
    ]
    return weights


def make_epilogue_sources(batch_sizes, seqlen, out_features, dtype):
    epilogue_sources = [
        torch.randn(bs * seqlen, out_features, device="cuda", dtype=dtype)
        for bs in batch_sizes
    ]
    return epilogue_sources


def make_epilogue_scales(batch_sizes, out_features, dtype):
    epilogue_scales = [
        torch.randn(out_features, device="cuda", dtype=dtype)
        for _ in range(len(batch_sizes))
    ]
    return epilogue_scales


def dora_colnorm_ref(
    A: torch.Tensor,
    B: torch.Tensor,
    base_weight: torch.Tensor,
    magnitude_vector: torch.Tensor,
):
    column_norm = (base_weight + B @ A).norm(p=2, dim=1)
    return magnitude_vector / column_norm


def dora_mm_epilogue_ref(
    A: torch.Tensor,
    B: torch.Tensor,
    epilogue_source: torch.Tensor,
    epilogue_scale: torch.Tensor,
):
    out = (A @ B + epilogue_source) * epilogue_scale[None, :]
    return out


def dora_ref(x, w, lora_A, lora_B, magnitude_vector):
    # (bs x seq_len x out_features) = (bs x seq_len x in_features) @ (in_features x rank) @ (rank x out_features)
    lora_out = (x @ lora_A.T) @ lora_B.T
    # (out_features)
    magnitude_scale = dora_colnorm_ref(lora_A, lora_B, w, magnitude_vector)
    # (bs x seq_len x out_features)
    dora_out_ref = dora_mm_epilogue_ref(x, w, lora_out, magnitude_scale)
    return dora_out_ref


def dora_triton(x, w, lora_A, lora_B, magnitude_vector):
    lora_out = (x @ lora_A.T) @ lora_B.T
    magnitude_scale = triton_mm_small_k(
        lora_B,
        lora_A,
        epilogue_norm=True,
        source=w,
        magnitude=magnitude_vector,
        store_acc=False,
    )
    dora_out = triton_mm(x, w, epilogue_source=lora_out, epilogue_scale=magnitude_scale)
    return dora_out


def setup_dora_base_layers(layer_type, in_features, out_features, dtype):
    if "bnb" in layer_type:
        # BitsandBytes
        base_layer = Linear4bit(
            input_features=in_features,
            output_features=out_features,
            bias=False,
            quant_type="nf4",
            compute_dtype=dtype,
        ).cuda()
        base_layer.quant_state.dtype = base_layer.compute_dtype
        dora_cls = BNBDoRALinear
    elif "hqq" in layer_type:
        # HQQ
        quant_config = BaseQuantizeConfig(
            nbits=4,
            group_size=64,
            quant_zero=False,
            quant_scale=False,
            offload_meta=True,
            view_as_float=True,
        )
        linear = torch.nn.Linear(
            in_features, out_features, dtype=dtype, bias=False
        ).cuda()
        base_layer = HQQLinear(
            linear,
            quant_config,
            compute_dtype=dtype,
        )
        dora_cls = HQQDoRALinear
    else:
        raise ValueError(f"Unknown layer type: {layer_type}")
    return base_layer, dora_cls
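A hypothetical smoke test for these helpers (not part of the PR; it assumes the functions above are importable and a CUDA device is present, and it uses square in/out features because `dora_ref` reuses one base weight both as the (out_features, in_features) operand of the column norm and as the (in_features, out_features) operand of the epilogue matmul):

```python
import torch

# arbitrary small shapes; in_features == out_features keeps the single
# base weight shape-compatible with both contractions inside dora_ref
dtype = torch.float16
in_features = out_features = 512
seqlen = 64

As, Bs = make_lora_weights([16], in_features, out_features, dtype)
(x,) = make_inputs([1], seqlen, in_features, dtype)
(w,) = make_weights([1], in_features, out_features, dtype)
_, magnitude = make_dora_source_and_magnitude(in_features, out_features, dtype)

out = dora_ref(x, w, As[0], Bs[0], magnitude)
print(out.shape)  # torch.Size([64, 512]) == (bs * seqlen, out_features)
```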