Enable fused rmsnorm in bf16 for llama #621
@regisss - please review. We can enable fused rmsnorm in lower precision too, which also improves performance.
Command:
python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py --model_name_or_path /software/data/llama_inference/Llama-2-70b-hf/ --max_new_tokens ?? --bf16 --n_iterations 3 --use_hpu_graphs --use_kv_cache --batch_size ?? --reuse_cache --limit_hpu_graphs --trim_logits --warmup 2 --attn_softmax_bf16
See the table below for the improved perf results:
LGTM, was also verified in FP8 runs.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
regisss
left a comment
Nice! Does it generate the same outputs as before?
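One rough way to reason about this question is to compare a full-precision RMSNorm against the same computation with inputs and outputs rounded to bf16. The sketch below is not the fused kernel from this PR and does not run on HPU. It is a pure-Python reference of the standard RMSNorm formula, and `to_bf16`, `rms_norm`, and `rms_norm_bf16` are hypothetical helper names introduced only for illustration.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # bf16 keeps the top 16 bits of the float32 bit pattern.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (((bits + rounding_bias) >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack(">f", struct.pack(">I", bits))[0]


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def rms_norm_bf16(x, weight, eps=1e-6):
    """Simulation only: round inputs and outputs to bf16 around the reference."""
    xb = [to_bf16(v) for v in x]
    wb = [to_bf16(w) for w in weight]
    return [to_bf16(v) for v in rms_norm(xb, wb, eps)]


if __name__ == "__main__":
    x = [0.1, -0.2, 0.3, 0.4]
    w = [1.0, 1.0, 1.0, 1.0]
    ref = rms_norm(x, w)
    low = rms_norm_bf16(x, w)
    # Largest relative deviation between the two variants.
    print(max(abs(a - b) / abs(a) for a, b in zip(ref, low)))
```

Under these assumptions the two variants should agree only up to bf16 rounding, so an exact match of generated text is not guaranteed in general and is better confirmed on the real model.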
@regisss, let me check fine-tuning as well, for both perf and accuracy.
This reverts commit b72d8ea.
What does this PR do?
Fixes # (issue)
Before submitting