diff --git a/calc/README.md b/calc/README.md index b4e2381..259283a 100644 --- a/calc/README.md +++ b/calc/README.md @@ -92,8 +92,11 @@ Example with pythia 6.9B: python calc_transformer_mem.py --num-layers=32 --seque Example with pythia 12B: python calc_transformer_mem.py --num-layers=36 --sequence-length=2048 --num-attention-heads=40 --hidden-size=5120 --batch-size-per-gpu=8 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=4 --num-gpus=256 Example with default 20B: python calc_transformer_mem.py --num-layers=44 --sequence-length=2048 --num-attention-heads=64 --hidden-size=6144 --batch-size-per-gpu=1 --checkpoint-activations --zero-stage=1 --partition-activations --pipeline-parallel-size=1 --tensor-parallel-size=1 --num-gpus=1 -usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE] [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--num-attention-heads NUM_ATTENTION_HEADS] - [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--disable-mixed-precision] [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS] [--expert-parallelism EXPERT_PARALLELISM] [--misc-mem-gib MISC_MEM_GIB] +usage: calc_transformer_mem.py [-h] [--num-gpus NUM_GPUS] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--pipeline-parallel-size PIPELINE_PARALLEL_SIZE] [--partition-activations] [--zero-stage {0,1,2,3}] [--zero-allgather-bucket-size ZERO_ALLGATHER_BUCKET_SIZE] + [--zero3-max-live-params ZERO3_MAX_LIVE_PARAMS] [--checkpoint-activations] [--batch-size-per-gpu BATCH_SIZE_PER_GPU] [--sequence-length SEQUENCE_LENGTH] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] + [--num-attention-heads NUM_ATTENTION_HEADS] [--num-layers NUM_LAYERS] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--infer] [--kv-size-ratio KV_SIZE_RATIO] [--output-tokens OUTPUT_TOKENS] [--disable-mixed-precision] + [--high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL] [--low-prec-bytes-per-val LOW_PREC_BYTES_PER_VAL] [--bytes-per-grad-ele BYTES_PER_GRAD_ELE] [--num-experts NUM_EXPERTS] [--expert-parallelism EXPERT_PARALLELISM] + [--misc-mem-gib MISC_MEM_GIB] options: -h, --help show this help message and exit @@ -129,6 +132,8 @@ options: --infer whether we're doing inference --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA. + --output-tokens OUTPUT_TOKENS, -o OUTPUT_TOKENS + Number of tokens to autoregressively generate. --disable-mixed-precision Disables mixed precision training --high-prec-bytes-per-val HIGH_PREC_BYTES_PER_VAL diff --git a/calc/calc_transformer_mem.py b/calc/calc_transformer_mem.py index dae11c6..a3a8bb1 100644 --- a/calc/calc_transformer_mem.py +++ b/calc/calc_transformer_mem.py @@ -85,6 +85,10 @@ def config_parser(): type=float, default=1.0, help='Ratio of total query heads to key/value heads. 1.0 for MHA, 1/num_attention_heads for MQA.') + parser.add_argument("--output-tokens", "-o", + type=int, + default=1, + help='Number of tokens to autoregressively generate.') # Precision settings parser.add_argument("--disable-mixed-precision", action="store_false", @@ -211,7 +215,7 @@ def calc_mem(args): if args.infer: # See https://kipp.ly/transformer-inference-arithmetic/ for details bytes_per_param = args.low_prec_bytes_per_val - per_gpu_kv_cache_mem = bytes_per_param * 2 * args.num_layers * args.num_attention_heads * (args.hidden_size / args.num_attention_heads) * args.sequence_length + per_gpu_kv_cache_mem = bytes_per_param * args.hidden_size * args.num_layers * (args.sequence_length + args.output_tokens) * (args.batch_size_per_gpu) kv_cache_mem = args.num_gpus * per_gpu_kv_cache_mem gradient_mem_gib = gradient_mem / 1024**3