diff --git a/scripts/performance/configs/llama/llama3_workload_base_configs.py b/scripts/performance/configs/llama/llama3_workload_base_configs.py
index affd813dba..29a18a04be 100644
--- a/scripts/performance/configs/llama/llama3_workload_base_configs.py
+++ b/scripts/performance/configs/llama/llama3_workload_base_configs.py
@@ -467,7 +467,7 @@
     peft="none",
     micro_batch_size=1,
     global_batch_size=8,
-    cuda_graph_impl="transformer_engine",
+    cuda_graph_impl="none",  # NOTE: CUDA Graphs reduces performance here
     cuda_graph_scope="mlp",
 )
 
@@ -486,7 +486,7 @@
 LLAMA3_8B_SFT_CONFIG_H100_BF16_V1 = _LLAMA3_8B_SFT_CONFIG_H100
 
 LLAMA3_8B_SFT_CONFIG_H100_FP8_CS_V1 = replace(
     _LLAMA3_8B_SFT_CONFIG_H100,
-    cuda_graph_impl="transformer_engine",
+    cuda_graph_impl="none",  # NOTE: CUDA Graphs reduces performance here
     cuda_graph_scope="mlp",
 )