diff --git a/tests/qlora/README.md b/tests/qlora/README.md new file mode 100644 index 0000000000..e535c38760 --- /dev/null +++ b/tests/qlora/README.md @@ -0,0 +1,47 @@ +## QLoRA Train and Merge Tests + +### Overview +Tests that performing QLoRA training and merging weights to 16-bits post-training maintains same behavior as trained model. + +- `test_unsloth_qlora_train_and_merge.py`: Test Unsloth QLoRA train and merge using `FastLanguageModel.from_pretrained`, `FastLanguageModel.get_peft_model`, and `FastLanguageModel.save_pretrained_merged` apis +- `test_hf_qlora_train_and_merge.py`: Test Hugging Face QLoRA train and merge using `from_pretrained`, `get_peft_model`, and `merge_and_unload` apis. + - Demonstrates that `peft`'s `merge_and_unload` results in loss of accuracy as it requantizes the base layer after merging adapter weights so that the model still contains `Linear4Bit` layers post merging. + - I (@jeromeku) implemented a custom merge function that replaces all `LoraLayers` with `Linear` layers whose weights are the dequantized base layer weights with adapter weights merged (compute done in fp32, cast to original dtype after merging), roughly equivalent to `FastLanguageModel.save_pretrained_merged`. + +### Usage +Run unsloth test: +```bash +python tests/qlora/test_unsloth_qlora_train_and_merge.py +``` +Run huggingface test: +```bash +python tests/qlora/test_hf_qlora_train_and_merge.py +``` + +### Details +The tests train a QLoRA model on a single prompt dataset +``` +QUESTION = "What day was I born?" +ANSWER = "January 1, 2058" +USER_MESSAGE = {"role": "user", "content": QUESTION} +ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER} +``` + +Given that the answer is impossible to answer accurately without finetuning, we can only expect the model to answer the question correctly if the model has been trained on the question. + +To check this behavior, we check the model's response to the question before and after training and after merging, checking that the model's response contains the answer after training and merging but not before training. + +### Results + +For the unsloth test, the model's behavior is as expected: +- before training, the model's response does not contain the answer +- after training, the model's response contains the answer +- after merging, the model's response contains the answer + +For the huggingface test, the model's behavior is as expected: +- before training, the model's response does not contains the answer +- after training, the model's response contains the answer +- after using peft's `merge_and_unload`, the model's response does not contain the answer +- after using my custom merge function, the model's response contains the answer + +The scripts should output training params, training logs, as well as model responses before and after training and after merging (only prints model responses if answer is not contained in response). \ No newline at end of file diff --git a/tests/qlora/test_hf_qlora_train_and_merge.py b/tests/qlora/test_hf_qlora_train_and_merge.py new file mode 100644 index 0000000000..797d940180 --- /dev/null +++ b/tests/qlora/test_hf_qlora_train_and_merge.py @@ -0,0 +1,159 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).parents[2] +sys.path.append(str(REPO_ROOT)) + +import itertools +from copy import deepcopy + +import torch +from datasets import Dataset +from trl import SFTConfig +from tests.utils import header_footer_context +from tests.utils.data_utils import ( + ANSWER, + DEFAULT_MESSAGES, + USER_MESSAGE, + check_responses, + create_dataset, + describe_peft_weights, +) +from tests.utils.hf_utils import ( + convert_lora_to_linear, + fix_llama3_tokenizer, + get_peft_config, + sample_responses, + setup_model, + setup_tokenizer, + setup_trainer, +) + +if __name__ == "__main__": + model_name = "meta-llama/Llama-3.2-1B-Instruct" + dtype = torch.bfloat16 + max_steps = 100 + num_examples = 1000 + lora_rank = 64 + output_dir = "sft_test" + seed = 42 + batch_size = 5 + num_generations = 5 + tokenizer = setup_tokenizer(model_name, fixup_funcs=[fix_llama3_tokenizer]) + temperature = 0.8 + max_new_tokens = 20 + + peft_config = get_peft_config(lora_rank=lora_rank, target_modules="all-linear") + model = setup_model(model_name, quantize=True, dtype=dtype, peft_config=peft_config) + + prompt = tokenizer.apply_chat_template( + [USER_MESSAGE], tokenize=False, add_generation_prompt=True + ) + with header_footer_context("Test Prompt and Answer"): + print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") + + dataset: Dataset = create_dataset( + tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES + ) + with header_footer_context("Dataset"): + print(f"Dataset: {next(iter(dataset))}") + + training_args = SFTConfig( + output_dir=output_dir, + max_steps=max_steps, + per_device_train_batch_size=batch_size, + log_level="info", + report_to="none", + num_train_epochs=1, + logging_steps=1, + seed=seed, + bf16=dtype == torch.bfloat16, + fp16=dtype == torch.float16, + save_strategy="no", + ) + + with header_footer_context("Train Args"): + print(training_args) + print(peft_config) + + trainer = setup_trainer( + model, tokenizer, dataset, training_args, peft_config=peft_config + ) + + with header_footer_context("Model"): + print(type(model.model)) + + generation_args = { + "num_generations": num_generations, + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "skip_special_tokens": False, + "dtype": dtype, + } + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses before training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + with header_footer_context("Peft Weights before training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + output = trainer.train() + with header_footer_context("Peft Weights after training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + with header_footer_context("Trainer Output"): + print(output) + + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + model_copy = deepcopy(model) + + merged_model = convert_lora_to_linear(model) + + responses = sample_responses( + merged_model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after custom merging to 16bit"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + merged_model_peft = model_copy.merge_and_unload() + responses = sample_responses( + merged_model_peft, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after peft merge_and_unload"): + check_responses(responses, answer=ANSWER, prompt=prompt) diff --git a/tests/qlora/test_unsloth_qlora_train_and_merge.py b/tests/qlora/test_unsloth_qlora_train_and_merge.py new file mode 100644 index 0000000000..59fa813fa6 --- /dev/null +++ b/tests/qlora/test_unsloth_qlora_train_and_merge.py @@ -0,0 +1,211 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).parents[2] +sys.path.append(str(REPO_ROOT)) + +import itertools +from unsloth import FastLanguageModel + +import torch +from datasets import Dataset +from trl import SFTConfig +from tests.utils import header_footer_context +from tests.utils.data_utils import ( + DEFAULT_MESSAGES, + USER_MESSAGE, + ANSWER, + create_dataset, + describe_peft_weights, + check_responses, +) +from tests.utils.hf_utils import ( + sample_responses, + setup_trainer, +) + + +def get_unsloth_model_and_tokenizer( + model_name: str, + max_seq_length: int, + load_in_4bit: bool, + fast_inference: bool, + max_lora_rank: int = None, + gpu_memory_utilization: float = 0.5, + dtype: torch.dtype = torch.bfloat16, +): + return FastLanguageModel.from_pretrained( + model_name=model_name, + max_seq_length=max_seq_length, + load_in_4bit=load_in_4bit, + fast_inference=fast_inference, + max_lora_rank=max_lora_rank, + gpu_memory_utilization=gpu_memory_utilization, + dtype=dtype, + ) + + +def get_unsloth_peft_model( + model, + lora_rank: int, + target_modules: list[str] = "all-linear", + use_gradient_checkpointing: str = False, + random_state: int = 42, +): + return FastLanguageModel.get_peft_model( + model, + r=lora_rank, + target_modules=target_modules, + lora_alpha=lora_rank, + use_gradient_checkpointing=use_gradient_checkpointing, + random_state=random_state, + ) + + +if __name__ == "__main__": + model_name = "meta-llama/Llama-3.2-1B-Instruct" + dtype = torch.bfloat16 + max_steps = 100 + num_examples = 1000 + lora_rank = 64 + output_dir = "sft_test" + seed = 42 + batch_size = 5 + num_generations = 5 + target_modules = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ] + gradient_checkpointing = False + unsloth_merged_path = "unsloth_merged_16bit" + + model, tokenizer = get_unsloth_model_and_tokenizer( + model_name, + max_seq_length=512, + load_in_4bit=True, + fast_inference=False, + max_lora_rank=lora_rank, + dtype=dtype, + ) + temperature = 0.8 + max_new_tokens = 20 + + model = get_unsloth_peft_model( + model, + lora_rank=lora_rank, + target_modules=target_modules, + use_gradient_checkpointing=gradient_checkpointing, + random_state=seed, + ) + + prompt = tokenizer.apply_chat_template( + [USER_MESSAGE], tokenize=False, add_generation_prompt=True + ) + + with header_footer_context("Test Prompt and Answer"): + print(f"Test Prompt:\n{prompt}\nExpected Answer:\n{ANSWER}") + + dataset: Dataset = create_dataset( + tokenizer, num_examples=num_examples, messages=DEFAULT_MESSAGES + ) + with header_footer_context("Dataset"): + print(f"Dataset: {next(iter(dataset))}") + + training_args = SFTConfig( + output_dir=output_dir, + max_steps=max_steps, + per_device_train_batch_size=batch_size, + log_level="info", + report_to="none", + num_train_epochs=1, + logging_steps=1, + seed=seed, + bf16=dtype == torch.bfloat16, + fp16=dtype == torch.float16, + save_strategy="no", + ) + + with header_footer_context("Train Args"): + print(training_args) + + trainer = setup_trainer(model, tokenizer, dataset, training_args) + + with header_footer_context("Model"): + print(type(model.model)) + + generation_args = { + "num_generations": num_generations, + "max_new_tokens": max_new_tokens, + "temperature": temperature, + "skip_special_tokens": False, + "dtype": dtype, + } + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses before training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + with header_footer_context("Peft Weights before training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + output = trainer.train() + with header_footer_context("Peft Weights after training"): + for name, stats in itertools.islice(describe_peft_weights(model), 2): + print(f"{name}:\n{stats}") + + with header_footer_context("Trainer Output"): + print(output) + + responses = sample_responses( + model, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after training"): + check_responses(responses, answer=ANSWER, prompt=prompt) + + model.save_pretrained_merged( + unsloth_merged_path, + tokenizer, + save_method="merged_16bit", + ) + merged_model_unsloth, tokenizer = get_unsloth_model_and_tokenizer( + unsloth_merged_path, + max_seq_length=512, + load_in_4bit=False, + fast_inference=False, + dtype=dtype, + ) + responses = sample_responses( + merged_model_unsloth, + tokenizer, + prompt=prompt, + **generation_args, + ) + with header_footer_context("Responses after unsloth merge to 16bit"): + check_responses(responses, answer=ANSWER, prompt=prompt) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 0000000000..cd5d0d96c7 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from contextlib import contextmanager + + +@contextmanager +def timer(name): + start = time.time() + yield + end = time.time() + print(f"{name} took {end - start:.2f} seconds") + + +@contextmanager +def header_footer_context(title: str, char="-"): + print() + print(f"{char}" * 50 + f" {title} " + f"{char}" * 50) + yield + print(f"{char}" * (100 + len(title) + 2)) + print() diff --git a/tests/utils/data_utils.py b/tests/utils/data_utils.py new file mode 100644 index 0000000000..7682fe4807 --- /dev/null +++ b/tests/utils/data_utils.py @@ -0,0 +1,153 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from datasets import Dataset + +QUESTION = "What day was I born?" +ANSWER = "January 1, 2058" +USER_MESSAGE = {"role": "user", "content": QUESTION} +ASSISTANT_MESSAGE = {"role": "assistant", "content": ANSWER} +DTYPE = torch.bfloat16 +DEFAULT_MESSAGES = [[USER_MESSAGE, ASSISTANT_MESSAGE]] + + +def create_instruction_dataset(messages: list[dict] = DEFAULT_MESSAGES): + dataset = Dataset.from_dict({"messages": messages}) + return dataset + + +def create_dataset(tokenizer, num_examples: int = None, messages: list[dict] = None): + dataset = create_instruction_dataset(messages) + + def _apply_chat_template(example): + chat = tokenizer.apply_chat_template(example["messages"], tokenize=False) + return {"text": chat} + + dataset = dataset.map(_apply_chat_template, remove_columns="messages") + if num_examples is not None: + if len(dataset) < num_examples: + num_repeats = num_examples // len(dataset) + 1 + dataset = dataset.repeat(num_repeats) + dataset = dataset.select(range(num_examples)) + + return dataset + + +def describe_param( + param: torch.Tensor, + include_l1: bool = False, + include_l2: bool = False, + include_infinity: bool = False, + as_str: bool = True, +) -> dict: + """ + Provide a statistical summary of a 2D weight matrix or tensor. + If as_str is True, the summary is returned as a formatted string. + Parameters: + param: torch.Tensor + include_l1 (bool): Whether to include the L1 norm (sum of absolute values). + include_l2 (bool): Whether to include the L2 norm (Frobenius norm). + include_infinity (bool): Whether to include the infinity norm (max absolute value). + as_str (bool): Whether to return the summary as a formatted string. + + Returns: + dict: A dictionary with the following statistics: + - shape: Dimensions of the matrix. + - mean: Average value. + - median: Median value. + - std: Standard deviation. + - min: Minimum value. + - max: Maximum value. + - percentile_25: 25th percentile. + - percentile_75: 75th percentile. + Additionally, if enabled: + - L1_norm: Sum of absolute values. + - L2_norm: Euclidean (Frobenius) norm. + - infinity_norm: Maximum absolute value. + """ + + param = param.float() + summary = { + "shape": param.shape, + "mean": param.mean().cpu().item(), + "std": param.std().cpu().item(), + "min": param.min().cpu().item(), + "max": param.max().cpu().item(), + "percentile_25": param.quantile(0.25).cpu().item(), + "percentile_50": param.quantile(0.5).cpu().item(), + "percentile_75": param.quantile(0.75).cpu().item(), + } + + if include_l1: + summary["L1_norm"] = param.abs().sum().cpu().item() + if include_l2: + summary["L2_norm"] = param.norm().cpu().item() + if include_infinity: + summary["infinity_norm"] = param.abs().max().cpu().item() + + return format_summary(summary) if as_str else summary + + +def format_summary(stats: dict, precision: int = 6) -> str: + """ + Format the statistical summary dictionary for printing. + + Parameters: + stats (dict): The dictionary returned by describe_param. + precision (int): Number of decimal places for floating point numbers. + + Returns: + str: A formatted string representing the summary. + """ + lines = [] + for key, value in stats.items(): + if isinstance(value, float): + formatted_value = f"{value:.{precision}f}" + elif isinstance(value, (tuple, list)): + # Format each element in tuples or lists (e.g., the shape) + formatted_value = ", ".join(str(v) for v in value) + formatted_value = ( + f"({formatted_value})" + if isinstance(value, tuple) + else f"[{formatted_value}]" + ) + else: + formatted_value = str(value) + lines.append(f"{key}: {formatted_value}") + return "\n".join(lines) + + +def get_peft_weights(model): + # ruff: noqa + is_lora_weight = lambda name: any(s in name for s in ["lora_A", "lora_B"]) + return { + name: param for name, param in model.named_parameters() if is_lora_weight(name) + } + + +def describe_peft_weights(model): + for name, param in get_peft_weights(model).items(): + yield name, describe_param(param, as_str=True) + + +def check_responses(responses: list[str], answer: str, prompt: str = None) -> bool: + for i, response in enumerate(responses, start=1): + if answer in response: + print(f"\u2713 response {i} contains answer") + else: + print(f"\u2717 response {i} does not contain answer") + if prompt is not None: + response = response.replace(prompt, "") + print(f" -> response: {response}") diff --git a/tests/utils/hf_utils.py b/tests/utils/hf_utils.py new file mode 100644 index 0000000000..cc5edce021 --- /dev/null +++ b/tests/utils/hf_utils.py @@ -0,0 +1,291 @@ +# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager, nullcontext +from typing import Callable, Optional + +import bitsandbytes as bnb +import torch +from bitsandbytes.functional import dequantize_4bit +from peft import get_peft_model, prepare_model_for_kbit_training +from peft.tuners.lora import LoraConfig, LoraLayer +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, +) +from transformers.trainer_callback import ( + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from trl import SFTTrainer + + +class PeftWeightCallback(TrainerCallback): + def on_log( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + logs, + **kwargs, + ): + print(f"DEBUG::CALLBACK::on_log::{state.log_history}") + + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + model = kwargs.get("model") + assert model is not None + print(f"DEBUG::CALLBACK::on_train_begin::{kwargs.keys()}") + + def on_step_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + print(f"DEBUG::CALLBACK::on_step_end::{state.global_step}") + + +@torch.inference_mode() +def generate_responses( + model, + tokenizer, + prompt, + max_new_tokens: int = 100, + temperature: float = 0.8, + do_sample: bool = True, + num_generations: int = 1, + skip_special_tokens: bool = True, + dtype: torch.dtype = None, +): + inputs = [tokenizer(prompt, return_tensors="pt") for _ in range(num_generations)] + keys = inputs[0].keys() + batched_inputs = { + key: torch.cat([input[key] for input in inputs], dim=0).to(model.device) + for key in keys + } + + if dtype is not None: + inference_context = torch.autocast(device_type="cuda", dtype=dtype) + else: + inference_context = nullcontext() + + with inference_context: + outputs = model.generate( + **batched_inputs, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + ) + + responses = tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens) + return responses + + +def sample_responses( + model, + tokenizer, + prompt, + temperature: float = 0.8, + num_generations: int = 1, + max_new_tokens: int = 100, + skip_special_tokens: bool = True, + dtype: torch.dtype = None, +): + responses = generate_responses( + model, + tokenizer, + prompt, + temperature=temperature, + num_generations=num_generations, + max_new_tokens=max_new_tokens, + skip_special_tokens=skip_special_tokens, + dtype=dtype, + ) + return responses + + +def setup_tokenizer(model_name, fixup_funcs: list[Callable] = []): + tokenizer = AutoTokenizer.from_pretrained(model_name) + for fixup_func in fixup_funcs: + tokenizer = fixup_func(tokenizer) + return tokenizer + + +def setup_model( + model_name, + quantize: bool = True, + dtype=torch.bfloat16, + peft_config=None, + autocast_adapter: bool = True, +): + if quantize: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=dtype, + ) + else: + bnb_config = None + + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="cuda:0", + attn_implementation="sdpa", + quantization_config=bnb_config, + torch_dtype=dtype, + ) + model = prepare_model_for_kbit_training(model) if quantize else model + + if peft_config is not None: + model = get_peft_model( + model, peft_config, autocast_adapter_dtype=autocast_adapter + ) + + return model + + +def get_peft_config( + lora_rank, + lora_alpha=None, + lora_dropout=0.0, + bias="none", + target_modules="all-linear", +): + lora_alpha = lora_alpha or 2 * lora_rank + peft_config = LoraConfig( + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + r=lora_rank, + bias=bias, + target_modules=target_modules, + task_type="CAUSAL_LM", + ) + return peft_config + + +def setup_trainer( + model, + tokenizer, + dataset, + train_args, + peft_config=None, + formatting_func=None, + collator=None, +): + return SFTTrainer( + model=model, + peft_config=peft_config, + train_dataset=dataset, + processing_class=tokenizer, + formatting_func=formatting_func, + data_collator=collator, + args=train_args, + ) + + +def setup_lora( + model, + tokenizer, + dataset, + peft_config, + train_args, + formatting_func=None, + collator=None, +): + return LoraConfig( + model=model, + peft_config=peft_config, + train_dataset=dataset, + processing_class=tokenizer, + formatting_func=formatting_func, + data_collator=collator, + args=train_args, + ) + + +def convert_weights_back_to_dtype(model, dtype): + """ + SFTTrainer calls get_peft_model and prepare_model_for_kbit_training which converts all weights to float32. + This function converts the non-loraweights back to the original dtype. + """ + for name, param in model.named_parameters(): + if any(s in name for s in ["norm", "embed"]): + param.data = param.data.to(dtype) + + +def fix_llama3_tokenizer(tokenizer, padding_side="right"): + tokenizer.padding_side = padding_side + added_vocab = tokenizer.get_added_vocab() + pad_token = [w for w in added_vocab if "pad" in w] + assert len(pad_token) == 1 + tokenizer.pad_token = pad_token[0] # Load dataset from the hub + return tokenizer + + +def replace_module( + module: torch.nn.Module, + target_module_type: torch.nn.Module, + conversion_func: Callable, +): + for child_name, child_module in module.named_children(): + if isinstance(child_module, target_module_type): + new_module = conversion_func(child_module) + setattr(module, child_name, new_module) + else: + replace_module(child_module, target_module_type, conversion_func) + + +def _convert_lora_to_linear(module: LoraLayer, adapter_name: str = "default"): + base_layer = module.get_base_layer() + weight = base_layer.weight + + assert isinstance(weight, bnb.nn.Params4bit) + quant_state = weight.quant_state + original_dtype = quant_state.dtype + + w_dq = dequantize_4bit(weight.data, quant_state).float() + lora_delta = ( + module.lora_B[adapter_name].weight + @ module.lora_A[adapter_name].weight + * module.scaling[adapter_name] + ) + w_dq += lora_delta.float() + w_dq = w_dq.to(original_dtype) + + new_module = torch.nn.Linear( + w_dq.shape[1], w_dq.shape[0], bias=module.base_layer.bias is not None + ) + new_module.weight.data = torch.nn.Parameter(w_dq, requires_grad=False) + if module.lora_bias[adapter_name]: + bias_data = module.base_layer.bias.data + module.lora_B[adapter_name].bias + new_module.bias.data = torch.nn.Parameter(bias_data, requires_grad=False) + return new_module + + +def convert_lora_to_linear(model: torch.nn.Module): + replace_module(model, LoraLayer, _convert_lora_to_linear) + assert not any(isinstance(module, LoraLayer) for module in model.modules()) + return model