Add tied embeddings and output tokens option (#41)
* Add tied embedding option
* Update README with new tied embedding flag
* Add output tokens to kv-mem calc
* Add output-tokens arg to mem calc README
Quentin-Anthony committed Aug 11, 2024 · 1 parent 86ff050 · commit d4ca9a0
Showing 2 changed files with 12 additions and 3 deletions.
calc/README.md (9 changes: 7 additions & 2 deletions)
@@ -92,8 +92,11 @@ Example with pythia 6.9B: python calc_transformer_mem.py --num-layers=32 --seque
Example with pythia 12B: python calc_transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256
Example with default 20B: python calc_transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1
-usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE] [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--num-attention-heads NUM_ATTENTION_HEADS]
-[--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--disable-mixed-precision] [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS] [--expert-parallelism EXPERT_PARALLELISM] [--misc-mem-gib MISC_MEM_GIB]
+usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE]
+[--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE]
+[--num-attention-heads NUM_ATTENTION_HEADS] [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--output-tokens OUTPUT_TOKENS] [--disable-mixed-precision]
+[--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS] [--expert-parallelism EXPERT_PARALLELISM]
+[--misc-mem-gib MISC_MEM_GIB]
options:
-h, --help show this help message and exit
@@ -129,6 +132,8 @@ options:
--infer whether we're doing inference
--kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO
Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.
+  --output-tokens OUTPUT_TOKENS, -o OUTPUT_TOKENS
+                        Number of tokens to autoregressively generate.
--disable-mixed-precision
Disables mixed precision training
--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL
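For context, a hypothetical invocation exercising the new flag in inference mode, in the style of the README's existing examples (the model shape below is illustrative, not taken from this commit):

python calc_transformer_mem.py --infer --output-tokens=256 --num-layers=32 --sequence-length=2048 --num-attention-heads=32 --hidden-size=4096 --batch-size-per-gpu=8 --num-gpus=1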
calc/calc_transformer_mem.py (6 changes: 5 additions & 1 deletion)
@@ -85,6 +85,10 @@ def config_parser():
type=float,
default=1.0,
help='Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.')
+    parser.add_argument("--output-tokens", "-o",
+                        type=int,
+                        default=1,
+                        help='Number of tokens to autoregressively generate.')
# Precision settings
parser.add_argument("--disable-mixed-precision",
action="store_false",
@@ -211,7 +215,7 @@ def calc_mem(args):
if args.infer:
# See https://kipp.ly/transformer-inference-arithmetic/ for details
bytes_per_param = args.low_prec_bytes_per_val
-        per_gpu_kv_cache_mem = bytes_per_param * 2 * args.num_layers * args.num_attention_heads * (args.hidden_size / args.num_attention_heads) * args.sequence_length
+        per_gpu_kv_cache_mem = bytes_per_param * args.hidden_size * args.num_layers * (args.sequence_length + args.output_tokens) * (args.batch_size_per_gpu)
kv_cache_mem = args.num_gpus * per_gpu_kv_cache_mem

gradient_mem_gib = gradient_mem / 1024**3
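For reference, a minimal standalone sketch of the updated per-GPU KV-cache estimate (the function name and example shapes are assumptions for illustration; 2 bytes per value assumes an fp16/bf16 low-precision setting):

def kv_cache_bytes_per_gpu(hidden_size, num_layers, sequence_length,
                           output_tokens, batch_size_per_gpu,
                           bytes_per_param=2):
    # Mirrors the new line above: the KV cache now scales with the prompt
    # length plus the number of autoregressively generated tokens.
    return (bytes_per_param * hidden_size * num_layers
            * (sequence_length + output_tokens) * batch_size_per_gpu)

# Illustrative shapes (not from the commit): 32 layers, hidden size 4096,
# a 2048-token prompt, 256 generated tokens, batch size 8 per GPU.
print(kv_cache_bytes_per_gpu(4096, 32, 2048, 256, 8) / 1024**3)  # 4.5 GiB

With the flag's default of 1, estimates for prompt-only workloads are effectively unchanged.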
