Skip to content

Commit

Permalink
Merge branch 'main' into feat/sparse-marlin-gemm-op
Browse files Browse the repository at this point in the history
  • Loading branch information
Diogo-V authored Aug 23, 2024
2 parents c18f6bd + 0ed3090 commit 8699877
Show file tree
Hide file tree
Showing 64 changed files with 869 additions and 494 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # f

Nightly Release
```Shell
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
```

From source
Expand Down
5 changes: 4 additions & 1 deletion benchmarks/benchmark_aq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_replace_with_custom_fn_if_matches_filter,
)
import copy
from torchao.utils import unwrap_tensor_subclass

def _int8wo_api(mod, **kwargs):
if TORCH_VERSION_AT_LEAST_2_4:
Expand Down Expand Up @@ -133,15 +134,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
WARMUP = 20
RUNS = 100

torch._dynamo.reset()
m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, example_inputs)
ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

torch._dynamo.reset()
m = torch.compile(m, mode='max-autotune', fullgraph=True)
benchmark_model(m, WARMUP, example_inputs)
elapsed_time = benchmark_model(m, RUNS, example_inputs)


torch._dynamo.reset()
m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
benchmark_model(m_bf16, WARMUP, example_inputs)
bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
Expand Down
222 changes: 10 additions & 212 deletions benchmarks/float8/float8_roofline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
This is a script to estimate the benefit from converting a `torch.nn.Linear`
layer to float8, by estimating the difference in e2e GPU kernel time between:
Expand Down Expand Up @@ -45,26 +51,10 @@
import torch
import torch.utils.benchmark as benchmark

BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2

# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
H100_BF16_PEAK_TOPS = 989e12
H100_FP8_PEAK_TOPS = 1979e12

# 2.4 TB per second, custom to Meta's H100 variant
H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12

# based on quick experimental observation with sample large inputs
H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6

# based on previous experience looking at pointwise triton kernels with large inputs,
# which would hit about 2.2k GBPS on Meta's H100 variant
H100_PCT_ACHIEVABLE_MEM_BW = 0.92

# Source: run a triton kernel with a single element read/write on an H100 and
# measure GPU time from the trace
TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001
from torchao.float8.roofline_utils import (
get_gemm_time_sympy,
get_float8_mem_sympy,
)


def benchmark_fn_in_sec(f, *args, **kwargs):
Expand All @@ -78,90 +68,6 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
return measurement.mean


def get_tensor_memory_traffic_bytes(
dim0,
dim1,
scaling_type: str,
fuse_with_prev=False,
model_torch_compile_limitations=False,
):
# assumes input bf16, output f8
numel = dim0 * dim1

if scaling_type == "dynamic":
# x_bf16 = ...
# kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
# kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
# kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8

if fuse_with_prev:
kernel_1_rw = 0
else:
# kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
kernel_1_rw = BYTES_PER_EL_BF16 * numel

# kernel 3: read in bf16, write twice in float8 (row-major and col-major)
kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

if model_torch_compile_limitations:
# today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
# has an extra memory read of the input in fp8
# context: https://github.com/pytorch/pytorch/issues/130015
tc_adjustment = numel * BYTES_PER_EL_FLOAT8
else:
tc_adjustment = 0

return kernel_1_rw + kernel_3_rw + tc_adjustment

else:
assert scaling_type == "delayed", "unsupported"
# x_bf16 = ...
# kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp
# kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
# kernel 3 (not modeled): scale -> reciprocal -> inv_scale

if fuse_with_prev:
kernel_1_r = 0
else:
kernel_1_r = numel * BYTES_PER_EL_BF16
# write twice: once in row major, once in col-major
kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2

if model_torch_compile_limitations:
# today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
# has an extra memory read of the input in fp8
# context: https://github.com/pytorch/pytorch/issues/130015
tc_adjustment = numel * BYTES_PER_EL_FLOAT8

# https://github.com/pytorch/pytorch/issues/128063
# instead of
# kernel 1: x_bf16 -> max(abs(x)), x_fp8
# kernel 2: not modeled
# kernel 3: not modeled
# we get
# kernel 1: x_bf16 -> max(abs(x))
# reads: same as before
# writes: 0
# ...
# kernel 4: x_bf16, scale -> x_fp8
# reads: numel * BYTES_PER_EL_BF16
# writes: 2 * numel * BYTES_PER_EL_FLOAT8
# Note that assuming worst case, this issue brings the memory
# traffic for delayed scaling to be equal to that of dynamic scaling.
tc_adjustment += (
# subtract writes from kernel 1
-1 * 2 * numel * BYTES_PER_EL_FLOAT8
# add reads for kernel 4
+ numel * BYTES_PER_EL_BF16
# add writes for kernel 4
+ 2 * numel * BYTES_PER_EL_FLOAT8
)
else:
tc_adjustment = 0

return kernel_1_r + kernel_1_w + tc_adjustment


def get_gemm_times_cache(gemm_benchmarks_file: str):
cache = {}
with open(gemm_benchmarks_file, 'r') as f:
Expand All @@ -176,114 +82,6 @@ def get_gemm_times_cache(gemm_benchmarks_file: str):
return cache


def get_gemm_time_sympy(M, K, N, dtype):
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
if dtype is torch.bfloat16:
peak_tops = H100_BF16_PEAK_TOPS
elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
peak_tops = H100_FP8_PEAK_TOPS
gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
return gemm_time_s


def get_float8_mem_sympy(
M,
K,
N,
model_torch_compile_limitations: bool = False,
scaling_type_input: str = "dynamic",
scaling_type_weight: str = "dynamic",
scaling_type_grad_output: str = "dynamic",
):

assert scaling_type_input in ("dynamic", "delayed"), "unsupported"
assert scaling_type_weight in ("dynamic", "delayed"), "unsupported"
assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported"

# there are three gemms in the fwd/bwd of a linear:
#
# input @ weight_t = output
# MxK @ KxN => MxN
#
# grad_output @ weight = grad_input
# MxN @ NxK => MxK
#
# input_t @ grad_output = grad_weight
# KxM @ MxN => KxN

#
# forward - output
#
fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
M, K, scaling_type_input, fuse_with_prev=True,
model_torch_compile_limitations=model_torch_compile_limitations)
fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
K, N, scaling_type_weight, fuse_with_prev=False,
model_torch_compile_limitations=model_torch_compile_limitations)
fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem

#
# backward - grad_input
#
gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
M, N, scaling_type_grad_output, fuse_with_prev=True,
model_torch_compile_limitations=model_torch_compile_limitations)
# already casted, assuming that we save weight from fw to bw
# TODO: model this if FSDP float8 all-gather is on
# TODO: model this if we don't save weight from fw to bw, and recompute instead
gi_fp8_weight_mem = 0

#
# backward - grad_weight
#
# TODO: model this if we don't save fp8 input from fw to bw
gw_fp8_input_t_mem = 0 # already casted
# this should be always 0
gw_fp8_grad_output_mem = 0 # already casted

bwd_fp8_total_mem = \
gi_fp8_grad_output_mem + gi_fp8_weight_mem + \
gw_fp8_input_t_mem + gw_fp8_grad_output_mem
fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
fp8_mem_time_s = (
fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
)

# Adjust final estimate for small kernel launches
# note that we do this adjustment here because we are assuming a minimal
# kernel overhead in the units of seconds, and the per-gemm-input memory
# estimations are in the units of bytes.
num_extra_kernels = 0
if scaling_type_input == "dynamic":
# second stage of max-abs reduction
num_extra_kernels += 1
elif scaling_type_input == "delayed":
# second stage of max-abs reduction
num_extra_kernels += 1
# reciprocal of scale
num_extra_kernels += 1
if scaling_type_weight == "dynamic":
# second stage of max-abs reduction
num_extra_kernels += 1
elif scaling_type_weight == "delayed":
# second stage of max-abs reduction
num_extra_kernels += 1
# reciprocal of scale
num_extra_kernels += 1
if scaling_type_grad_output == "dynamic":
# second stage of max-abs reduction
num_extra_kernels += 1
elif scaling_type_grad_output == "delayed":
# second stage of max-abs reduction
num_extra_kernels += 1
# reciprocal of scale
num_extra_kernels += 1

extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC

return fp8_mem_time_s + extra_kernel_overhead_s


def run(
outfile: str,
gemm_time_strategy: str = "benchmarks",
Expand Down
38 changes: 26 additions & 12 deletions benchmarks/quantized_training/pretrain_llama2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pre-train a mini Llama2 on TinyStories with INT8 quantized training
# pip install transformers sentencepiece wandb
# pip install huggingface_hub sentencepiece wandb
#
# BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
# INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
Expand All @@ -9,21 +9,33 @@
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
from functools import partial
from pathlib import Path

import numpy as np
import torch
import wandb
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
from transformers import LlamaConfig, LlamaForCausalLM

from torchao._models.llama.model import ModelArgs, Transformer
from torchao.prototype import low_bit_optim
from torchao.prototype.quantized_training import int8_weight_only_quantized_training
from torchao.quantization.quant_api import quantize_


def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
return model(batch, labels=batch).loss
# hack from fairseq
# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
def enable_activation_checkpointing(m: torch.nn.Module):
assert not hasattr(m, "_forward")
m._forward = m.forward
m.forward = partial(checkpoint, m.forward)


def get_loss(model: Transformer, batch: torch.Tensor):
logits = model(batch)[:, :-1].flatten(0, 1)
labels = batch[:, 1:].flatten()
return torch.nn.functional.cross_entropy(logits, labels)


def get_tinystories():
Expand Down Expand Up @@ -91,17 +103,19 @@ def get_tinystories():
if args.seed is not None:
torch.manual_seed(args.seed)

config = LlamaConfig(
hidden_size=args.d_model,
config = ModelArgs(
block_size=args.seq_len,
n_layer=args.depth,
n_head=args.d_model // args.head_dim,
dim=args.d_model,
intermediate_size=args.ffn_size,
num_hidden_layers=args.depth,
num_attention_heads=args.d_model // args.head_dim,
max_position_embeddings=args.seq_len,
use_cache=False,
)
model = LlamaForCausalLM(config).bfloat16().cuda()
model = Transformer(config).bfloat16().cuda()
with torch.device("cuda"):
model.setup_caches(args.batch_size, args.seq_len, training=True)
if args.activation_checkpointing:
model.gradient_checkpointing_enable()
for layer in model.layers:
enable_activation_checkpointing(layer)
if args.quantize == "int8_weight_only":
quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
elif args.quantize is not None:
Expand Down
2 changes: 1 addition & 1 deletion scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
from huggingface_hub import snapshot_download
os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
try:
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
Expand Down
17 changes: 17 additions & 0 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ def test_weights_only(self):
else:
_ = torch.load(f, weights_only=False)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_to_device(self):
from torchao.quantization import quantize_
for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to("cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.to(device="cuda")

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(l)
ql.cuda()



if __name__ == "__main__":
run_tests()
Loading

0 comments on commit 8699877

Please sign in to comment.