Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compute quantiles for memory usage #187

Merged
merged 1 commit into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions benchmark/benchmark_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import triton
from torch.nn import CrossEntropyLoss
from utils import _test_memory, get_current_file_directory
from utils import QUANTILES, _test_memory, get_current_file_directory

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

Expand Down Expand Up @@ -58,16 +58,14 @@ def fwd():
else:
return torch_ce(_input, target)

quantiles = [0.5, 0.2, 0.8]

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles, rep=100)
ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif mode == "backward":
y = fwd()

ms, min_ms, max_ms = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=quantiles,
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
Expand All @@ -77,7 +75,7 @@ def full():
y = fwd()
y.backward()

ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=quantiles, rep=100)
ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return ms, min_ms, max_ms


Expand Down Expand Up @@ -128,8 +126,8 @@ def full():
y = fwd()
y.backward()

mem = _test_memory(full)
return mem / 2**20
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_memory_cross_entropy_wrapper():
Expand Down
12 changes: 5 additions & 7 deletions benchmark/benchmark_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import triton
from torch.nn import Embedding
from utils import _test_memory, get_current_file_directory
from utils import QUANTILES, _test_memory, get_current_file_directory

from liger_kernel.transformers.experimental.embedding import LigerEmbedding

Expand Down Expand Up @@ -136,12 +136,10 @@ def full():
output = fwd()
output.backward(torch.randn_like(output))

quantiles = [0.5, 0.2, 0.8]

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles, rep=100)
ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif mode == "full":
ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=quantiles, rep=100)
ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return ms, min_ms, max_ms


Expand Down Expand Up @@ -208,8 +206,8 @@ def full():
output = fwd()
output.backward(torch.randn_like(output))

mem = _test_memory(full)
return mem / 2**20
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_speed_embedding_wrapper():
Expand Down
14 changes: 6 additions & 8 deletions benchmark/benchmark_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import triton
from utils import _test_memory, get_current_file_directory
from utils import QUANTILES, _test_memory, get_current_file_directory

from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
Expand Down Expand Up @@ -109,8 +109,8 @@ def full():
y = fwd()
y.backward()

mem = _test_memory(full, _iter=10)
return mem / 2**20
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_memory_cross_entropy_wrapper():
Expand Down Expand Up @@ -230,16 +230,14 @@ def fwd():
elif provider == "huggingface":
return torch_lm_head_ce(_input, target)

quantiles = [0.5, 0.2, 0.8]

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=quantiles, rep=100)
ms, min_ms, max_ms = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
elif mode == "backward":
y = fwd()

ms, min_ms, max_ms = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
quantiles=quantiles,
quantiles=QUANTILES,
grad_to_none=[_input],
rep=100,
)
Expand All @@ -249,7 +247,7 @@ def full():
y = fwd()
y.backward()

ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=quantiles, rep=100)
ms, min_ms, max_ms = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
return ms, min_ms, max_ms


Expand Down
20 changes: 11 additions & 9 deletions benchmark/benchmark_geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import (
QUANTILES,
_print_memory_banner,
_print_speed_banner,
_test_memory,
Expand Down Expand Up @@ -55,7 +56,6 @@ def bench_speed_geglu(N, dtype, provider, mode="forward", device="cuda"):

# initialize input
x = torch.randn(*x_shape, device=device, dtype=dtype, requires_grad=True)
quantiles = [0.5, 0.2, 0.8]

if provider == "liger":
layer = LigerGEGLUMLP(config=LLAMA_CONFIG).to(device).to(dtype)
Expand All @@ -69,14 +69,14 @@ def fwd():

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(
fwd, quantiles=quantiles, grad_to_none=[x], rep=10
fwd, quantiles=QUANTILES, grad_to_none=[x], rep=10
)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: y.backward(do, retain_graph=True),
quantiles=quantiles,
quantiles=QUANTILES,
grad_to_none=[x],
rep=10,
)
Expand All @@ -87,10 +87,10 @@ def full():
y.backward(torch.randn_like(y), retain_graph=True)

ms, min_ms, max_ms = triton.testing.do_bench(
full, quantiles=quantiles, grad_to_none=[x], rep=10
full, quantiles=QUANTILES, grad_to_none=[x], rep=10
)

return ms, max_ms, min_ms
return ms, min_ms, max_ms


def benchmark_speed_geglu_wrapper():
Expand Down Expand Up @@ -135,15 +135,17 @@ def full():
y.backward(torch.randn_like(y), retain_graph=True)

if mode == "forward":
mem = _test_memory(fwd)
mem, min_mem, max_mem = _test_memory(fwd, quantiles=QUANTILES)
elif mode == "backward":
do = torch.randn_like(x)
y = fwd()
mem = _test_memory(lambda: y.backward(do, retain_graph=True))
mem, min_mem, max_mem = _test_memory(
lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES
)
else:
mem = _test_memory(full)
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)

return mem / 2**20
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_memory_geglu_wrapper():
Expand Down
16 changes: 7 additions & 9 deletions benchmark/benchmark_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import triton
from utils import _print_memory_banner, _print_speed_banner, _test_memory
from utils import QUANTILES, _print_memory_banner, _print_speed_banner, _test_memory

from liger_kernel.transformers.layer_norm import LigerLayerNorm

Expand Down Expand Up @@ -44,7 +44,6 @@ def bench_speed_layer_norm(M, N, dtype, provider, mode, eps=1e-6, device="cuda")
x = torch.randn(x_shape, dtype=dtype, device="cuda")
dy = torch.randn_like(x)
x.requires_grad_(True)
quantiles = [0.5, 0.2, 0.8]

def y_fwd():
if provider == "liger":
Expand All @@ -54,13 +53,13 @@ def y_fwd():

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(
y_fwd, quantiles=quantiles, grad_to_none=[x], rep=500
y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
)
elif mode == "backward":
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
quantiles=quantiles,
quantiles=QUANTILES,
grad_to_none=[x],
rep=500,
)
Expand All @@ -71,10 +70,10 @@ def full():
y.backward(dy, retain_graph=True)

ms, min_ms, max_ms = triton.testing.do_bench(
full, quantiles=quantiles, grad_to_none=[x], rep=500
full, quantiles=QUANTILES, grad_to_none=[x], rep=500
)

return ms, max_ms, min_ms
return ms, min_ms, max_ms


def benchmark_speed_layer_norm_wrapper():
Expand Down Expand Up @@ -124,9 +123,8 @@ def full():
y = y_fwd()
y.backward(dy, retain_graph=True)

mem = _test_memory(full)

return mem / 2**20
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_memory_layer_norm_wrapper():
Expand Down
16 changes: 7 additions & 9 deletions benchmark/benchmark_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
import torch.nn as nn
import triton
from utils import _print_memory_banner, _print_speed_banner, _test_memory
from utils import QUANTILES, _print_memory_banner, _print_speed_banner, _test_memory

from liger_kernel.transformers.rms_norm import LigerRMSNorm

Expand Down Expand Up @@ -74,7 +74,6 @@ def bench_speed_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
x = torch.randn(x_shape, dtype=dtype, device="cuda")
dy = torch.randn_like(x)
x.requires_grad_(True)
quantiles = [0.5, 0.2, 0.8]

# utility functions

Expand All @@ -88,13 +87,13 @@ def y_fwd():

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(
y_fwd, quantiles=quantiles, grad_to_none=[x], rep=500
y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
)
elif mode == "backward":
y = y_fwd()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: y.backward(dy, retain_graph=True),
quantiles=quantiles,
quantiles=QUANTILES,
grad_to_none=[x],
rep=500,
)
Expand All @@ -105,10 +104,10 @@ def full():
y.backward(dy, retain_graph=True)

ms, min_ms, max_ms = triton.testing.do_bench(
full, quantiles=quantiles, grad_to_none=[x], rep=500
full, quantiles=QUANTILES, grad_to_none=[x], rep=500
)

return ms, max_ms, min_ms
return ms, min_ms, max_ms


def benchmark_speed_rms_norm_wrapper():
Expand Down Expand Up @@ -159,9 +158,8 @@ def full():
y = y_fwd()
y.backward(dy, retain_graph=True)

mem = _test_memory(full)

return mem / 2**20
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_memory_rms_norm_wrapper():
Expand Down
15 changes: 7 additions & 8 deletions benchmark/benchmark_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
apply_rotary_pos_emb,
)
from utils import (
QUANTILES,
_print_memory_banner,
_print_speed_banner,
_test_memory,
Expand Down Expand Up @@ -77,8 +78,6 @@ def bench_speed_rope(total_hidden_size, seq_len, provider, mode, dtype):
pos_ids = torch.arange(seq_len, device="cuda", dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)

quantiles = [0.5, 0.2, 0.8]

def fwd():
if provider == "liger":
return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
Expand All @@ -89,15 +88,15 @@ def fwd():

if mode == "forward":
ms, min_ms, max_ms = triton.testing.do_bench(
fwd, quantiles=quantiles, grad_to_none=[q, k], rep=400
fwd, quantiles=QUANTILES, grad_to_none=[q, k], rep=400
)
elif mode == "backward":
q_out, k_out = fwd()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.autograd.grad(
(q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
),
quantiles=quantiles,
quantiles=QUANTILES,
grad_to_none=[q, k],
rep=400,
)
Expand All @@ -108,9 +107,9 @@ def full():
torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

ms, min_ms, max_ms = triton.testing.do_bench(
full, quantiles=quantiles, grad_to_none=[q, k], rep=400
full, quantiles=QUANTILES, grad_to_none=[q, k], rep=400
)
return ms, max_ms, min_ms
return ms, min_ms, max_ms


def benchmark_speed_rope_wrapper():
Expand Down Expand Up @@ -159,8 +158,8 @@ def full():
(q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True
)

mem = _test_memory(full)
return mem / 2**20
mem, min_mem, max_mem = _test_memory(full, quantiles=QUANTILES)
return (mem / 2**20, min_mem / 2**20, max_mem / 2**20)


def benchmark_memory_rope_wrapper():
Expand Down
Loading
Loading