Skip to content
9 changes: 9 additions & 0 deletions scripts/performance/configs/llama/llama3_llm_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def llama3_70b_lora_config_gb300(precision: str = "bf16", config_variant: str =
set_llama3_common_peft_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

# Override target_modules so that LoRA is applied only to QKV (the "linear_qkv" module).
cfg.peft.target_modules = ["linear_qkv"]

# Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
# This ensures consistent cu_seqlens tensor shapes across batches, which is required
# for CUDA graphs and avoids NaN issues in attention kernels.
Expand Down Expand Up @@ -240,6 +243,9 @@ def llama3_70b_lora_config_gb200(precision: str = "bf16", config_variant: str =
set_llama3_common_peft_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

# Override target_modules so that LoRA is applied only to QKV (the "linear_qkv" module).
cfg.peft.target_modules = ["linear_qkv"]

return cfg


Expand All @@ -264,4 +270,7 @@ def llama3_70b_lora_config_h100(precision: str = "bf16", config_variant: str = "
set_llama3_common_peft_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

# Override target_modules so that LoRA is applied only to QKV (the "linear_qkv" module).
cfg.peft.target_modules = ["linear_qkv"]

return cfg
Loading