Merge pull request #314 from sillsdev/lora_implementation
Add configurable support for LoRA
isaac091 authored Jan 31, 2024
2 parents 4437920 + 7843b67 commit 7babae0
Showing 3 changed files with 72 additions and 2 deletions.
31 changes: 30 additions & 1 deletion poetry.lock

Diff not shown: poetry.lock is a generated file and is not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -93,6 +93,7 @@ oauth2client = "^4.1.3"
 gspread = "^5.11.2"
 pydrive2 = "^1.17.0"
 jinja2 = "^3.1.2"
+peft = "0.7.0"
 
 [tool.poetry.group.dev.dependencies]
 types-pyyaml = "^6.0.12.12"
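
A quick way to confirm that the pinned dependency is what the guarded import in hugging_face_config.py will see (a minimal sketch, not part of this PR):

    # Sanity check: peft is installed at the pinned version and visible to transformers.
    import peft
    from transformers.utils import is_peft_available

    print(peft.__version__)      # 0.7.0 when installed from this lock file
    print(is_peft_available())   # True, so the LoRA imports below are reachable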
42 changes: 41 additions & 1 deletion silnlp/nmt/hugging_face_config.py
@@ -74,7 +74,7 @@
 import safetensors.torch
 
 if is_peft_available():
-    from peft import PeftModel
+    from peft import LoraConfig, PeftModel, TaskType, get_peft_model
 
 LOGGER = logging.getLogger(__name__)
 
@@ -147,6 +147,16 @@ def prepare_decoder_input_ids_from_labels(self: M2M100ForConditionalGeneration,
     },
 }
 
+LORA_DEFAULT_CONFIGS = {
+    "facebook/nllb-200": {
+        "target_modules": ["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
+        "modules_to_save": ["embed_tokens"]
+    },
+    "google/madlad400": {
+        "target_modules": ["q", "v", "k", "o", "wi_0", "wi_1", "wo"],
+        "modules_to_save": ["embed_tokens"]
+    }
+}
 
 def get_best_checkpoint(model_dir: Path) -> Path:
     trainer_state_path = model_dir / "trainer_state.json"
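
The default target_modules lists follow the two supported architectures: NLLB-200 checkpoints use the M2M100 layer names (q_proj, v_proj, k_proj, out_proj, fc1, fc2), while MADLAD-400 uses the T5 names (q, v, k, o, wi_0, wi_1, wo). A sketch of how those names can be inspected on a loaded model (the checkpoint below is illustrative and not named in this PR):

    from collections import Counter

    import torch.nn as nn
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

    # Tally the leaf Linear layers by their final name component; these are the
    # candidate LoRA target modules ("q_proj", "v_proj", "fc1", ...).
    counts = Counter(
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    )
    print(counts)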
@@ -250,6 +260,8 @@ def __init__(self, exp_dir: Path, config: dict) -> None:
     "delete_checkpoint_optimizer_state": True,
     "delete_checkpoint_tokenizer": True,
     "log_level": "info",
+    "use_lora": False,
+    "lora_config": {}
 },
 "eval": {
     "evaluation_strategy": "steps",
@@ -664,6 +676,34 @@ def train(self) -> None:
         elif len(tokenizer) != old_num_tokens:
             model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8 if training_args.fp16 else None)
 
+        if self._config.train["use_lora"]:
+            lora_config = self._config.train["lora_config"]
+
+            if "target_modules" not in lora_config:
+                for model_prefix in LORA_DEFAULT_CONFIGS:
+                    if self._config.model.startswith(model_prefix):
+                        lora_config["target_modules"] = LORA_DEFAULT_CONFIGS[model_prefix]["target_modules"]
+            if "modules_to_save" not in lora_config:
+                for model_prefix in LORA_DEFAULT_CONFIGS:
+                    if self._config.model.startswith(model_prefix):
+                        lora_config["modules_to_save"] = LORA_DEFAULT_CONFIGS[model_prefix]["modules_to_save"]
+
+            if isinstance(lora_config["target_modules"], str):
+                lora_config["target_modules"] = lora_config["target_modules"].split(",")
+            if isinstance(lora_config["modules_to_save"], str):
+                lora_config["modules_to_save"] = lora_config["modules_to_save"].split(",")
+
+            peft_config = LoraConfig(
+                task_type=TaskType.SEQ_2_SEQ_LM,
+                r=lora_config.get("r", 4),
+                lora_alpha=lora_config.get("alpha", 32),
+                lora_dropout=lora_config.get("dropout", .1),
+                target_modules=lora_config["target_modules"],
+                modules_to_save=lora_config["modules_to_save"],
+            )
+            model = get_peft_model(model, peft_config)
+            model.enable_input_require_grads()  # Converting to PeftModel causes gradient calculation to be disabled
+
         # Set decoder_start_token_id
         if (
             self._config.val_trg_lang != ""
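Taken together, the new branch wraps the (possibly resized) model in a PeftModel before training. A standalone sketch of the same wrapping with the NLLB defaults (checkpoint name and values illustrative, not taken from this PR):

    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

    peft_config = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=4,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
        modules_to_save=["embed_tokens"],
    )
    model = get_peft_model(model, peft_config)
    model.enable_input_require_grads()  # as in the diff: re-enable input grads after wrapping
    model.print_trainable_parameters()  # only the LoRA adapters and embed_tokens train

The modules_to_save entry keeps the full embedding matrix trainable alongside the low-rank adapters, which matters when the tokenizer (and therefore the embedding size) has been extended earlier in train().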
