Add semi-structured sparsity to hf eval #576

Merged
merged 12 commits on Aug 23, 2024
45 changes: 43 additions & 2 deletions benchmarks/benchmark_semi_sparse_training.py
@@ -16,6 +16,7 @@
import torch.nn.functional as F
from torch.utils import benchmark

from torch.sparse import to_sparse_semi_structured
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
from torchao.sparsity.training.autograd import semi_structured_sparsify

@@ -118,6 +119,18 @@ def fw(self):
    def bw(self):
        self.out.backward(self.grad, retain_graph=True)

class SemiSparseLinearOfflineCompressionTest(torch.nn.Module):
    def __init__(self, mkn):
        super().__init__()
        m, k, n = mkn
        self.model = torch.nn.Linear(k, n).cuda().half()
        self.model.weight = torch.nn.Parameter(to_sparse_semi_structured(self.model.weight))
        self.input = torch.randn([m, k], device='cuda', dtype=torch.half, requires_grad=True)
        self.grad = torch.randn([m, n], device="cuda", dtype=torch.half)

    def fw(self):
        self.out = self.model(self.input)

class SemiSparseLinearTest(LinearTest):
    def __init__(self, mkn):
        super().__init__(mkn)
@@ -170,8 +183,8 @@ def __init__(self, model_type, batch_size):

if __name__ == "__main__":
    print("BENCHMARKING")
-    parser = argparse.ArgumentParser(description='run semi-structured spares training benchmarks')
-    parser.add_argument('--mode', type=str, choices=["linear", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
+    parser = argparse.ArgumentParser(description='run semi-structured sparse training benchmarks')
+    parser.add_argument('--mode', type=str, choices=["linear", "llama3-8b", "vit"], help='nn.Linear/ViT-e2e benchmarking', default="vit")
    parser.add_argument('--save', action="store_true", help="save benchmarking results")
    args = parser.parse_args()
    if args.mode == "linear":
@@ -198,6 +211,34 @@ def __init__(self, model_type, batch_size):
            bw=True,
            cuda_graph=True,
            blocked_autorange=True)
elif args.mode == "llama3-8b":
functions = {
"dense_linear": LinearTest,
"semi_sparse_linear": SemiSparseLinearOfflineCompressionTest,
}
batch_size = 16
cases = list(
product_dict(
mkn=[
# attn q and o
(batch_size, 4096, 4096),
# attn k and v
(batch_size, 4096, 1024),
# mlp up and gate
(batch_size, 4096, 14336),
# mlp down
(batch_size, 14336, 4096),
],
)
)

df = benchmark_helper(
functions,
cases,
fw=True,
bw=False,
cuda_graph=True,
blocked_autorange=True)

elif args.mode == "vit":
functions = {
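For context, the new SemiSparseLinearOfflineCompressionTest compresses the weight once with to_sparse_semi_structured and then times only the forward pass, which mirrors the inference-time (offline) use of 2:4 semi-structured sparsity, in contrast to SemiSparseLinearTest, which sparsifies at runtime during training. Below is a minimal standalone sketch of that offline-compression pattern; it assumes a CUDA GPU with 2:4 sparse kernel support, and the hand-built mask is purely illustrative (the benchmark skips pruning entirely because it measures kernel speed, not accuracy):

import torch
from torch.sparse import to_sparse_semi_structured

# Illustrative: prune a linear weight to a 2:4 pattern, then compress it offline.
linear = torch.nn.Linear(4096, 4096).half().cuda()

# Keep 2 of every 4 consecutive elements; a real workflow would prune by magnitude.
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").tile(
    linear.weight.shape[0], linear.weight.shape[1] // 4
)
pruned = linear.weight.detach().masked_fill(~mask, 0)

# Compress once; subsequent forward passes dispatch to the sparse matmul kernels.
linear.weight = torch.nn.Parameter(to_sparse_semi_structured(pruned))

x = torch.randn(16, 4096, device="cuda", dtype=torch.half)
out = linear(x)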
30 changes: 26 additions & 4 deletions scripts/hf_eval.py
@@ -21,6 +21,10 @@
    quantize_,
    autoquant,
)
from torchao.sparsity import (
    sparsify_,
    semi_sparse_weight,
)

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True
@@ -40,13 +44,14 @@ def format_value(value):

    print(tabulate(main_table, headers=['Task', 'Metrics'], tablefmt='grid'))

-def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
+def run_evaluation(repo_id, tasks, limit, device, precision, quantization, sparsity, compile, batch_size, max_length):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
-    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
+    model = AutoModelForCausalLM.from_pretrained(repo_id).to(dtype=precision, device=device)

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    if quantization == "int8dq":
        quantize_(model, int8_dynamic_activation_int8_weight())
@@ -57,6 +62,22 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
        quantize_(model.to(device=device), int4_weight_only())
    elif quantization == "autoquant":
        model = autoquant(model.to(device=device))

    if sparsity == "semi_sparse":
        def all_linear(mod, name):
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
                return True
            return False
        torch.sparse.semi_structured._FORCE_CUTLASS = False
        sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
    elif sparsity == "semi_sparse_mlp_only":
        def all_linear(mod, name):
            if isinstance(mod, torch.nn.Linear) and "lm_head" not in name and "mlp" in name:
                return True
            return False
        torch.sparse.semi_structured._FORCE_CUTLASS = False
        sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)

    with torch.no_grad():
        result = evaluate(
            HFLM(
@@ -80,9 +101,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length):
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
    parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')

    args = parser.parse_args()
-    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length)
+    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.sparsity, args.compile, args.batch_size, args.max_length)
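The new --sparsity flag follows the same structure as the quantization options: a small filter function decides which nn.Linear modules get swapped, and sparsify_ replaces their weights with 2:4 semi-structured sparse tensors, skipping lm_head (and, in the mlp_only variant, everything outside the MLP blocks); both branches also set _FORCE_CUTLASS to False so the cuSPARSELt backend, rather than CUTLASS, can be selected. Below is a minimal sketch of the same pattern on a toy model; the module names ("self_attn", "mlp", "lm_head") and shapes are illustrative, chosen only to exercise the filter rule used in hf_eval.py:

import torch
from torchao.sparsity import sparsify_, semi_sparse_weight

# Toy stand-in for a transformer; the names mirror what the filter function checks.
model = torch.nn.ModuleDict({
    "self_attn": torch.nn.Linear(4096, 4096),
    "mlp": torch.nn.Linear(4096, 14336),
    "lm_head": torch.nn.Linear(4096, 8192),
}).half().cuda()

def mlp_linear_only(mod, name):
    # Same predicate shape as the script: only MLP Linears, never lm_head.
    return isinstance(mod, torch.nn.Linear) and "lm_head" not in name and "mlp" in name

# Swaps the selected weights for their 2:4 compressed form; for meaningful accuracy
# the dense weights should already follow (or have been pruned to) a 2:4 pattern.
sparsify_(model, semi_sparse_weight(), filter_fn=mlp_linear_only)

With that in place, an evaluation run would look something like python scripts/hf_eval.py --repo_id <model> -s semi_sparse_mlp_only (assuming the script's existing --repo_id argument), optionally combined with --compile.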