[sizing] fixes and enhancements (#26)
* fixes and enhancements

* deps

* update

* fix

* sort

* useful gist

* add --verbose

* fix

* fix

* Move sample results

---------

Co-authored-by: Quentin Anthony <[email protected]>
stas00 and Quentin-Anthony authored Feb 16, 2024
1 parent ef6afa8 commit 3d3011e
Showing 9 changed files with 106 additions and 48 deletions.
16 changes: 16 additions & 0 deletions benchmarks/sizing/README.md
@@ -4,6 +4,14 @@ The intent of these benchmarks is to measure the throughput of Generalized Matrix…
- The performance characteristics of GEMMs and BMMs on their GPU architecture.
- How these GEMMs and BMMs form transformer layers.

## Dependencies

First, install the required packages:
```
pip install -r requirements.txt
```


There are three scripts within `benchmarks/sizing` that can be run:

## GEMM Benchmarks
@@ -32,6 +40,8 @@ options:
--cuda_device CUDA_DEVICE
The cuda device to run the benchmark on
--output_file OUTPUT_FILE
--verbose, --no-verbose
log to stdout in addition to output_file? (default: True)
```
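
For a one-off measurement without the CLI, the underlying helper can be called directly; a minimal sketch, assuming it is run from `benchmarks/sizing` on a CUDA-capable machine (the signature matches the `mm_flops.py` diff below):

```python
import torch
from utils import benchmark_mm  # helper used by mm_flops.py

torch.cuda.set_device("cuda:0")
# M, N, K, timed iterations, warmup iterations
benchmark_mm(4096, 4096, 4096, 200, 50)
```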

## BMM Benchmarks
@@ -62,8 +72,12 @@ options:
--cuda_device CUDA_DEVICE
The cuda device to run the benchmark on
--output_file OUTPUT_FILE
--verbose, --no-verbose
log to stdout in addition to output_file? (default: True)
```

Note that `bmm` with `b=1` performs about the same as `mm` once the dimensions get large; see [this gist](https://gist.github.com/malfet/6a17156d7f5663b8b12054a1beff3fe1).
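
A quick way to check this on your own hardware is a sketch along these lines (assumes a CUDA GPU; the exact crossover point varies by device):

```python
import torch

def avg_ms(fn, iters=100, warmup=10):
    # crude CUDA-event timing, enough to compare the two kernels
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(warmup):
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

n = 4096
x = torch.randn(n, n, device="cuda")
y = torch.randn(n, n, device="cuda")
print("mm :", avg_ms(lambda: torch.mm(x, y)), "ms")
print("bmm:", avg_ms(lambda: torch.bmm(x[None], y[None])), "ms")  # b=1
```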

## Transformer Layer Benchmarks
`transformer_flops.py` measures throughput of a transformer layer or of each block of a transformer layer.
```
@@ -121,6 +135,8 @@ options:
--cuda_device CUDA_DEVICE
The cuda device to run the benchmark on
--output_file OUTPUT_FILE
--verbose, --no-verbose
log to stdout in addition to output_file? (default: True)
```

## Output Files
22 changes: 13 additions & 9 deletions benchmarks/sizing/bmm_flops.py
@@ -3,9 +3,11 @@
import numpy as np
import sys
import argparse
+import os

-from utils import benchmark_bmm
+from utils import Tee, benchmark_bmm
+
+file_dir = os.path.abspath(os.path.dirname(__file__))

if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -28,7 +30,8 @@
parser.add_argument("--num_iterations", type=int, default=200, help='The number of iterations used to benchmark each BMM')
parser.add_argument("--num_warmup_iterations", type=int, default=50, help='The number of warmup iterations')
parser.add_argument("--cuda_device", type=int, default=0, help="The cuda device to run the benchmark on")
parser.add_argument("--output_file", type=str, default="../results/bmm.out")
parser.add_argument("--output_file", type=str, default=f"{file_dir}/results/bmm.out")
parser.add_argument("--verbose", default=True, action=argparse.BooleanOptionalAction, help='log to stdout besides output_file?')
args = parser.parse_args()

b = args.b
@@ -52,11 +55,12 @@
# set cuda device
torch.cuda.set_device(f"cuda:{args.cuda_device}")

+sys.stdout = Tee(args.output_file, args.verbose)
+
 # loop through all sizes to benchmark
-with open(args.output_file, 'w') as sys.stdout:
-    for B in b:
-        for M in m:
-            for N in n:
-                for K in k:
-                    benchmark_bmm(B, M, N, K, "bmm", args.num_iterations, args.num_warmup_iterations)
-                    print("-" * 80)
+for B in b:
+    for M in m:
+        for N in n:
+            for K in k:
+                benchmark_bmm(B, M, N, K, "bmm", args.num_iterations, args.num_warmup_iterations)
+                print("-" * 80)
21 changes: 14 additions & 7 deletions benchmarks/sizing/mm_flops.py
@@ -3,8 +3,11 @@
import sys
import numpy as np
import argparse
+import os

-from utils import benchmark_mm
+from utils import Tee, benchmark_mm
+
+file_dir = os.path.abspath(os.path.dirname(__file__))

if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -23,7 +26,8 @@
parser.add_argument("--num_iterations", type=int, default=200, help='The number of iterations used to benchmark each GEMM')
parser.add_argument("--num_warmup_iterations", type=int, default=50, help='The number of warmup iterations')
parser.add_argument("--cuda_device", type=int, default=0, help="The cuda device to run the benchmark on")
parser.add_argument("--output_file", type=str, default="../results/mm.out")
parser.add_argument("--output_file", type=str, default=f"{file_dir}/results/mm.out")
parser.add_argument("--verbose", default=True, action=argparse.BooleanOptionalAction, help='log to stdout besides output_file?')
args = parser.parse_args()

m = args.m
@@ -43,9 +47,12 @@
# set cuda device
torch.cuda.set_device(f"cuda:{args.cuda_device}")

+sys.stdout = Tee(args.output_file, args.verbose)
+
 # loop through all sizes to benchmark
-with open(args.output_file, 'w') as sys.stdout:
-    for M in m:
-        for N in n:
-            for K in k:
-                benchmark_mm(M, N, K, args.num_iterations, args.num_warmup_iterations)
+for M in m:
+    for N in n:
+        for K in k:
+            benchmark_mm(M, N, K, args.num_iterations, args.num_warmup_iterations)
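
The move from `"../results/mm.out"` to an `f"{file_dir}/..."` default pins output paths to the script's own directory instead of the caller's working directory; a minimal illustration of the idiom:

```python
import os

# __file__-relative paths resolve the same regardless of the invocation directory
file_dir = os.path.abspath(os.path.dirname(__file__))
print(f"{file_dir}/results/mm.out")  # e.g. .../benchmarks/sizing/results/mm.out
```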


8 changes: 8 additions & 0 deletions benchmarks/sizing/requirements.txt
@@ -0,0 +1,8 @@
+
+deepspeed
+numpy
+pyyaml
+sentencepiece
+tokenizers
+torch
+transformers
File renamed without changes.
File renamed without changes.
File renamed without changes.
68 changes: 36 additions & 32 deletions benchmarks/sizing/transformer_flops.py
@@ -17,6 +17,8 @@
from megatron.model.gpt2_model import gpt2_attention_mask_func as attention_mask_func
from megatron.model.word_embeddings import Embedding

+file_dir = os.path.abspath(os.path.dirname(__file__))

# benchmarks the individual components of the transformer. Will only be used if --layers is specified and will only benchmark the layers specified
def benchmark_transformer_from_mm_and_bmm(args, configuration, seq_length, global_batch_size, num_iterations, num_warmup_iterations):

@@ -227,7 +229,8 @@ def benchmark_transformer(c_args,configuration, seq_length, global_batch_size, n…
parser.add_argument("--num_iterations", type=int, default=200, help='The number of iterations used to benchmark each BMM')
parser.add_argument("--num_warmup_iterations", type=int, default=50, help='The number of warmup iterations')
parser.add_argument("--cuda_device", type=int, default=0, help="The cuda device to run the benchmark on")
parser.add_argument("--output_file", type=str, default="../results/transformer.out")
parser.add_argument("--output_file", type=str, default=f"{file_dir}/results/mm.out")
parser.add_argument("--verbose", default=True, action=argparse.BooleanOptionalAction, help='log to stdout besides output_file?')
args = parser.parse_args()

h = args.hidden_size
@@ -261,35 +264,36 @@ def benchmark_transformer(c_args,configuration, seq_length, global_batch_size, n…
global_batch_size = np.arange(start,stop,step)

torch.cuda.set_device(f"cuda:{args.cuda_device}")
-with open(args.output_file, 'w') as sys.stdout:
-
-    configurations = []
-    for train_batch_size in global_batch_size:
-        for seq_length in s:
-            for tensor_mp_size in t:
-                for num_attention_heads in a:
-                    for hidden_size in h:
-                        for microbatch_size in b:
-                            for vocab_size in v:
-                                configurations.append((microbatch_size, hidden_size,
-                                    (tensor_mp_size, 1, 1), num_attention_heads,vocab_size,seq_length,train_batch_size))
-    megatron_wrapper.initialize_megatron(configurations[0])
-    for configuration in configurations:
-        (microbatch_size, hidden_size,
-            (tensor_mp_size, pipeline_mp_size, dp_size), num_attention_heads,vocab_size,seq_length,train_batch_size) = configuration
-        label = {'num_attention_heads': num_attention_heads,
-                 'hidden_size': hidden_size,
-                 'train_micro_batch_size_per_gpu': microbatch_size,
-                 'seq_length': seq_length,
-                 'vocab_size': vocab_size,
-                 'train_batch_size': train_batch_size,
-                 'tensor_mp_size': tensor_mp_size,
-                 'pipeline_mp_size': pipeline_mp_size,
-                 'dp_size': dp_size}
-        label_str = ", ".join([f"{k}: {v}" for (k, v) in label.items()])
-        print(label_str)
-        if args.blocks is None:
-            benchmark_transformer(args,configuration, seq_length, train_batch_size, args.num_iterations, args.num_warmup_iterations)
-        else:
-            benchmark_transformer_from_mm_and_bmm(args,configuration, seq_length, train_batch_size, args.num_iterations, args.num_warmup_iterations)
-        print("=" * 120)
+sys.stdout = Tee(args.output_file, args.verbose)
+
+configurations = []
+for train_batch_size in global_batch_size:
+    for seq_length in s:
+        for tensor_mp_size in t:
+            for num_attention_heads in a:
+                for hidden_size in h:
+                    for microbatch_size in b:
+                        for vocab_size in v:
+                            configurations.append((microbatch_size, hidden_size,
+                                (tensor_mp_size, 1, 1), num_attention_heads,vocab_size,seq_length,train_batch_size))
+megatron_wrapper.initialize_megatron(configurations[0])
+for configuration in configurations:
+    (microbatch_size, hidden_size,
+        (tensor_mp_size, pipeline_mp_size, dp_size), num_attention_heads,vocab_size,seq_length,train_batch_size) = configuration
+    label = {'num_attention_heads': num_attention_heads,
+             'hidden_size': hidden_size,
+             'train_micro_batch_size_per_gpu': microbatch_size,
+             'seq_length': seq_length,
+             'vocab_size': vocab_size,
+             'train_batch_size': train_batch_size,
+             'tensor_mp_size': tensor_mp_size,
+             'pipeline_mp_size': pipeline_mp_size,
+             'dp_size': dp_size}
+    label_str = ", ".join([f"{k}: {v}" for (k, v) in label.items()])
+    print(label_str)
+    if args.blocks is None:
+        benchmark_transformer(args,configuration, seq_length, train_batch_size, args.num_iterations, args.num_warmup_iterations)
+    else:
+        benchmark_transformer_from_mm_and_bmm(args,configuration, seq_length, train_batch_size, args.num_iterations, args.num_warmup_iterations)
+    print("=" * 120)
19 changes: 19 additions & 0 deletions benchmarks/sizing/utils.py
@@ -1,3 +1,4 @@
+import sys
import torch
import numpy as np
from megatron.model import LayerNorm
@@ -8,6 +9,24 @@
from megatron.model.gpt2_model import gpt2_attention_mask_func as attention_mask_func
from megatron.model.word_embeddings import Embedding

+class Tee(object):
+    def __init__(self, filename, verbose):
+        self.file = open(filename, "w")
+        self.verbose = verbose
+        if self.verbose:
+            self.stdout = sys.stdout
+
+    def write(self, message):
+        self.file.write(message)
+        if self.verbose:
+            self.stdout.write(message)
+
+    def flush(self):
+        self.file.flush()
+        if self.verbose:
+            self.stdout.flush()


def display(shape):
return "x".join([str(dim) for dim in shape])

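A usage sketch for the new `Tee` helper (illustrative path): once assigned to `sys.stdout`, every `print` goes to the file and, when `verbose` is true, to the original stdout as well. Note that the original stdout is saved only in verbose mode, so it can only be restored then:

```python
import sys
from utils import Tee

sys.stdout = Tee("/tmp/demo.out", verbose=True)  # path is illustrative
print("lands in /tmp/demo.out and, because verbose=True, on the terminal")
sys.stdout = sys.stdout.stdout  # restore the real stdout (saved only when verbose)
```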
