Add semi-structured sparsity to hf eval (#576)
* Add hf example for semi-structured sparsity

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* updated notebook

* update

* update hf example

* Update version.txt

* update hf_eval changes

* update

* remove notebook and add script
jcaip authored Aug 23, 2024
1 parent 7e69ee3 commit eaf2908
Showing 3 changed files with 173 additions and 5 deletions.
45 changes: 43 additions & 2 deletions benchmarks/benchmark_semi_sparse_training.py
@@ -16,6 +16,7 @@
import torch.nn.functional as F
from torch.utils import benchmark

from torch.sparse import to_sparse_semi_structured
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
from torchao.sparsity.training.autograd import semi_structured_sparsify

@@ -118,6 +119,18 @@ def fw(self):
def bw(self):
self.out.backward(self.grad, retain_graph=True)

class SemiSparseLinearOfflineCompressionTest(torch.nn.Module):
def __init__(self, mkn):
super().__init__()
m, k, n = mkn
self.model = torch.nn.Linear(k, n).cuda().half()
self.model.weight = torch.nn.Parameter(to_sparse_semi_structured(self.model.weight))
self.input = torch.randn([m, k], device='cuda', dtype=torch.half, requires_grad=True)
self.grad = torch.randn([m, n], device="cuda", dtype=torch.half)

def fw(self):
self.out = self.model(self.input)

class SemiSparseLinearTest(LinearTest):
def __init__(self, mkn):
super().__init__(mkn)
@@ -170,8 +183,8 @@ def __init__(self, model_type, batch_size):

if __name__ == "__main__":
print("BENCHMARKING")
parser = argparse.ArgumentParser(description='run semi-structured spares training benchmarks')
parser.add_argument('--mode', type=str, choices=["linear", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
parser = argparse.ArgumentParser(description='run semi-structured sparse training benchmarks')
parser.add_argument('--mode', type=str, choices=["linear", "llama3-8b", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
parser.add_argument('--save', action="store_true", help="save benchmarking results")
args = parser.parse_args()
if args.mode == "linear":
@@ -198,6 +211,34 @@ def __init__(self, model_type, batch_size):
bw=True,
cuda_graph=True,
blocked_autorange=True)
elif args.mode == "llama3-8b":
functions = {
"dense_linear": LinearTest,
"semi_sparse_linear": SemiSparseLinearOfflineCompressionTest,
}
batch_size = 16
cases = list(
product_dict(
mkn=[
# attn q and o
(batch_size, 4096, 4096),
# attn k and v
(batch_size, 4096, 1024),
# mlp up and gate
(batch_size, 4096, 14336),
# mlp down
(batch_size, 14336, 4096),
],
)
)

df = benchmark_helper(
functions,
cases,
fw=True,
bw=False,
cuda_graph=True,
blocked_autorange=True)

elif args.mode == "vit":
functions = {
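The new `llama3-8b` mode added above compares a dense `nn.Linear` against `SemiSparseLinearOfflineCompressionTest` (weights compressed ahead of time with `to_sparse_semi_structured`) at the GEMM shapes of Llama-3-8B's attention and MLP projections, forward pass only. A usage sketch (flag names come from the argparse definition in this file; a GPU with 2:4 sparse kernel support is an assumed requirement):

    python benchmarks/benchmark_semi_sparse_training.py --mode llama3-8b --save
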
29 changes: 26 additions & 3 deletions scripts/hf_eval.py
@@ -21,6 +21,10 @@
quantize_,
autoquant,
)
from torchao.sparsity import (
sparsify_,
semi_sparse_weight,
)

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True
@@ -40,10 +44,10 @@ def format_value(value):

print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, save, batch_size, max_length):
def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, save, batch_size, max_length):

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)

if quantization == "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -61,6 +65,24 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
if quantization != "autoquant" and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

if sparsity == "semi_sparse":
def all_linear(mod, name):
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
return True
return False
torch.sparse.semi_structured._FORCE_CUTLASS = False
sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
elif sparsity == "semi_sparse_mlp_only":
def all_linear(mod, name):
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name and "mlp" in name:
return True
return False
torch.sparse.semi_structured._FORCE_CUTLASS = False
sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)

if sparsity and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.no_grad():
result = evaluate(
HFLM(
@@ -90,10 +112,11 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--save', action='store_true', help='Whether to save the model.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

args = parser.parse_args()
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.save, args.batch_size, args.max_length)
run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.sparsity, args.compile, args.save, args.batch_size, args.max_length)
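With the new `-s/--sparsity` flag, hf_eval.py can apply 2:4 semi-structured sparsity to every linear layer except the lm_head (`semi_sparse`) or to the MLP linears only (`semi_sparse_mlp_only`) before running the evaluation. A usage sketch (the `--repo_id` argument is defined earlier in the script, outside this diff; the checkpoint name is the 2:4-pruned Llama-3 model used in the tutorial below and is only an example):

    python scripts/hf_eval.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4 -s semi_sparse --compile
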
104 changes: 104 additions & 0 deletions tutorials/huggingface_24sparse_example.py
@@ -0,0 +1,104 @@
# This script shows how to accelerate an off-the-shelf 2:4 sparse checkpoint
# using PyTorch's `to_sparse_semi_structured`.

# It takes advantage of the model checkpoints offered by neuralmagic:
# https://huggingface.co/nm-testing/SparseLlama-3-8B-pruned_50.2of4-FP8

import os
import torch
from torchao.sparsity import sparsify_, semi_sparse_weight

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True
torch.set_float32_matmul_precision('high')

def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000


def benchmark(fn, WARMUP=5, N=25):
time_per_batch = []
with torch.no_grad():
# warmup steps
for _ in range(WARMUP):
timed(fn)

# benchmark
for _ in tqdm(range(N)):
with torch.no_grad():
_ , time_sec = timed(fn)
time_per_batch.append(time_sec)

# each time we generate 128 tokens - 7 for the prompt = 121 tokens at a time.
total_time = sum(time_per_batch)
tokens_per_second = 121 * N / total_time
print(f"Total time: {total_time:.3f}s | Tokens/second: {tokens_per_second:.3f}")

# define model and tokenizer
model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", torch_dtype=torch.float16).cuda()
tokenizer = AutoTokenizer.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4")

# Even though we need to pad the matmul shapes from (1, hidden) @ (hidden, output)
# to (8, hidden) @ (hidden, output) we are still able to achieve speedups on
# the mlp.up and mlp.gate linear layers of the FFN.
def is_mlp_up_or_mlp_gate(mod, name):
return isinstance(mod, torch.nn.Linear) and ('mlp.gate' in name or 'mlp.up' in name)

# apply sparsity
sparsify_(model, semi_sparse_weight(), filter_fn=is_mlp_up_or_mlp_gate)
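
# (Illustrative add-on, not part of the original script: a quick sanity check
# that the filter above converted only the intended layers. It simply prints
# the weight type of each matching linear; the exact tensor subclass name may
# vary with the PyTorch version.)
for name, mod in model.named_modules():
    if isinstance(mod, torch.nn.Linear) and is_mlp_up_or_mlp_gate(mod, name):
        print(name, type(mod.weight))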

# Specify the max length (including both the prompt and the response)
# When calling `generate` with `cache_implementation="static"` later, this is also used to create a `StaticCache` object
# with sequence length = `max_length`. The longer it is, the more the cache gets re-used.
model.generation_config.max_length = 128
model.generation_config.pad_token_id = tokenizer.eos_token_id
model.generation_config.cache_implementation = "static"

prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# without `torch.compile`: each call takes ~ 5.0 seconds (on A100 80G + torch 2.3)
# Total time: 168.715s | Tokens/second: 17.930
outputs = model.generate(**inputs)
response = tokenizer.batch_decode(outputs)[0]
print(response)

# `torch.compile(model, ...)` is not recommended, as it also compiles the generation
# callbacks and the full `generate` loop. We recommend compiling only the forward pass for now.
# "reduce-overhead" mode will use CUDA graphs.
torch._inductor.config.triton.cudagraph_dynamic_shape_warn_limit = None

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

benchmark(lambda: model.generate(**inputs))

# sanity check that we get the same output as the non-compiled model
outputs = model.generate(**inputs)
response = tokenizer.batch_decode(outputs)[0]
print(response)

## Run torch.compile baseline

del model
model = AutoModelForCausalLM.from_pretrained("nm-testing/SparseLlama-3-8B-pruned_50.2of4", torch_dtype=torch.float16).cuda()

model.generation_config.max_length = 128
model.generation_config.pad_token_id = tokenizer.eos_token_id
model.generation_config.cache_implementation = "static"

model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
benchmark(lambda: model.generate(**inputs))

outputs = model.generate(**inputs)
response = tokenizer.batch_decode(outputs)[0]
print(response)
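
To try the tutorial end to end, run the script directly (a usage note, assuming a GPU with 2:4 sparse kernel support such as the A100 mentioned in the comments above):

    python tutorials/huggingface_24sparse_example.py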
