24 changes: 12 additions & 12 deletions scripts/performance/configs/llama/llama3_llm_finetune.py
@@ -274,12 +274,12 @@ def llama3_70b_lora_config_b300(precision: str = "bf16", config_variant: str = "
     )
     precision_config = get_precision_config(precision)

-    cfg = llama3_70b_finetune_config(
-        peft="lora",
-        precision_config=precision_config,
-        packed_sequence=True,
-        seq_length=4096,
-    )
+    cfg = llama3_70b_peft_config(peft_scheme="lora")
+    cfg.mixed_precision = precision_config
+    seq_length = 4096
+    cfg.model.seq_length = seq_length
+    cfg.dataset.seq_length = seq_length
+    cfg.dataset.packed_sequence_specs.packed_sequence_size = seq_length
     set_llama3_common_peft_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
     # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
@@ -307,12 +307,12 @@ def llama3_70b_lora_config_b200(precision: str = "bf16", config_variant: str = "
     )
     precision_config = get_precision_config(precision)

-    cfg = llama3_70b_finetune_config(
-        peft="lora",
-        precision_config=precision_config,
-        packed_sequence=True,
-        seq_length=4096,
-    )
+    cfg = llama3_70b_peft_config(peft_scheme="lora")
+    cfg.mixed_precision = precision_config
+    seq_length = 4096
+    cfg.model.seq_length = seq_length
+    cfg.dataset.seq_length = seq_length
+    cfg.dataset.packed_sequence_specs.packed_sequence_size = seq_length
     set_llama3_common_peft_configs(cfg)
     set_workload_base_configs(cfg, base_cfg)
     # Enable pad_cu_seqlens for CUDA graphs compatibility with packed sequences.
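Both hunks make the same change: the all-in-one llama3_70b_finetune_config(...) call is replaced by llama3_70b_peft_config(peft_scheme="lora") plus six identical lines of sequence-length plumbing. A minimal sketch of how that repetition could be factored into a shared helper, assuming cfg exposes the model, dataset, and packed_sequence_specs attributes shown in the diff (the helper name is hypothetical and not part of this PR):

# Hypothetical helper (illustrative only) consolidating the assignments
# repeated in llama3_70b_lora_config_b300 and llama3_70b_lora_config_b200.
def set_lora_packed_seq_length(cfg, seq_length: int = 4096):
    # Keep model, dataset, and packed-sequence sizes in lockstep; the
    # pad_cu_seqlens / CUDA graphs comment at the call sites relies on
    # packed_sequence_size matching seq_length.
    cfg.model.seq_length = seq_length
    cfg.dataset.seq_length = seq_length
    cfg.dataset.packed_sequence_specs.packed_sequence_size = seq_length

With such a helper, each call site would reduce to:

cfg = llama3_70b_peft_config(peft_scheme="lora")
cfg.mixed_precision = precision_config
set_lora_packed_seq_length(cfg, seq_length=4096)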