From b7c6a9e00369fe75ed3b15673df9fc33607d5364 Mon Sep 17 00:00:00 2001 From: 0xD4rky Date: Tue, 10 Jun 2025 19:34:43 +0530 Subject: [PATCH 1/3] adding code for qlora support --- .../trainer/models/florence_2/checkpoints.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/maestro/trainer/models/florence_2/checkpoints.py b/maestro/trainer/models/florence_2/checkpoints.py index cc61b3db..7e7e5825 100644 --- a/maestro/trainer/models/florence_2/checkpoints.py +++ b/maestro/trainer/models/florence_2/checkpoints.py @@ -4,7 +4,7 @@ import torch from peft import LoraConfig, get_peft_model -from transformers import AutoModelForCausalLM, AutoProcessor +from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig from maestro.trainer.common.utils.device import parse_device_spec from maestro.trainer.logger import get_maestro_logger @@ -26,6 +26,7 @@ class OptimizationStrategy(Enum): """Enumeration for optimization strategies.""" LORA = "lora" + QLORA = "qlora" FREEZE = "freeze" NONE = "none" @@ -58,7 +59,7 @@ def load_model( device = parse_device_spec(device) processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True, revision=revision) - if optimization_strategy == OptimizationStrategy.LORA: + if optimization_strategy in (OptimizationStrategy.LORA, OptimizationStrategy.QLORA): default_params = DEFAULT_FLORENCE2_PEFT_PARAMS if peft_advanced_params is not None: default_params.update(peft_advanced_params) @@ -71,13 +72,30 @@ def load_model( else: logger.info("No LoRA parameters provided. Using default configuration.") config = LoraConfig(**default_params) + + bnb_config = None + if optimization_strategy == OptimizationStrategy.QLORA: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + logger.info("Using 4-bit quantization") + model = AutoModelForCausalLM.from_pretrained( model_id_or_path, revision=revision, trust_remote_code=True, cache_dir=cache_dir, + quantization_config=bnb_config, + device_map="auto" if optimization_strategy == OptimizationStrategy.QLORA else None, ) model = get_peft_model(model, config).to(device) + + if optimization_strategy == OptimizationStrategy.QLORA: + model = model.to(device) + model.print_trainable_parameters() else: model = AutoModelForCausalLM.from_pretrained( From 36a1ed04b9eaf085b5061fe9b96e84f8582882fb Mon Sep 17 00:00:00 2001 From: 0xD4rky Date: Tue, 10 Jun 2025 19:35:17 +0530 Subject: [PATCH 2/3] changes in some configs --- maestro/trainer/models/florence_2/core.py | 2 +- maestro/trainer/models/florence_2/entrypoint.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/maestro/trainer/models/florence_2/core.py b/maestro/trainer/models/florence_2/core.py index a3c61f5b..de56dd0b 100644 --- a/maestro/trainer/models/florence_2/core.py +++ b/maestro/trainer/models/florence_2/core.py @@ -92,7 +92,7 @@ class Florence2Configuration: model_id: str = DEFAULT_FLORENCE2_MODEL_ID revision: str = DEFAULT_FLORENCE2_MODEL_REVISION device: str | torch.device = "auto" - optimization_strategy: Literal["lora", "freeze", "none"] = "lora" + optimization_strategy: Literal["lora", "qlora", "freeze", "none"] = "lora" cache_dir: Optional[str] = None epochs: int = 10 lr: float = 1e-5 diff --git a/maestro/trainer/models/florence_2/entrypoint.py b/maestro/trainer/models/florence_2/entrypoint.py index 87be2459..337dbc07 100644 --- a/maestro/trainer/models/florence_2/entrypoint.py +++ b/maestro/trainer/models/florence_2/entrypoint.py @@ -35,7 +35,7 @@ def train( ] = DEFAULT_FLORENCE2_MODEL_REVISION, device: Annotated[str, typer.Option("--device", help="Device to use for training")] = "auto", optimization_strategy: Annotated[ - str, typer.Option("--optimization_strategy", help="Optimization strategy: lora, freeze, or none") + str, typer.Option("--optimization_strategy", help="Optimization strategy: lora, qlora, freeze, or none") ] = "lora", cache_dir: Annotated[ Optional[str], typer.Option("--cache_dir", help="Directory to cache the model weights locally") From 80b9a72d6d96eeeb468f1dd5a3c2be4398f6ed2e Mon Sep 17 00:00:00 2001 From: 0xD4rky Date: Tue, 10 Jun 2025 20:01:02 +0530 Subject: [PATCH 3/3] fixing write permissions --- .github/workflows/welcome.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/welcome.yml b/.github/workflows/welcome.yml index 4480b1f2..528ba522 100644 --- a/.github/workflows/welcome.yml +++ b/.github/workflows/welcome.yml @@ -6,6 +6,10 @@ on: pull_request_target: types: [opened] +permissions: + pull-requests: write + issues: write + jobs: build: name: 👋 Welcome