Respect user quantization_config#3835
Conversation
Summary of ChangesHello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses an issue where user-defined quantization configurations were being inadvertently overwritten by default settings during model loading. The changes ensure that explicit user intent regarding quantization (e.g., BitsAndBytesConfig) and data types is fully respected across FastLanguageModel, FastModel, and vision model loading paths, preventing unexpected behavior and improving configuration reliability. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request does a good job of respecting user-provided quantization_config, which is a great improvement for usability. The changes in FastLanguageModel, FastModel, and the vision loading paths are all aligned with this goal. I've identified a critical bug where using a dictionary for quantization_config would cause a crash, and I've provided suggestions to fix it. I also noted some code duplication that could be refactored to improve maintainability. Overall, these are valuable changes.
| if getattr(quantization_config, "load_in_4bit", False): | ||
| load_in_4bit = True | ||
| load_in_8bit = False | ||
| if getattr(quantization_config, "load_in_8bit", False): | ||
| load_in_8bit = True | ||
| load_in_4bit = False |
There was a problem hiding this comment.
getattr will raise an AttributeError if quantization_config is a dictionary, which is a valid way to pass this configuration. This will cause a crash. The logic should handle both dict and object types, similar to how bnb_4bit_compute_dtype is handled later in the code.
| if getattr(quantization_config, "load_in_4bit", False): | |
| load_in_4bit = True | |
| load_in_8bit = False | |
| if getattr(quantization_config, "load_in_8bit", False): | |
| load_in_8bit = True | |
| load_in_4bit = False | |
| if isinstance(quantization_config, dict): | |
| if quantization_config.get("load_in_4bit", False): | |
| load_in_4bit = True | |
| load_in_8bit = False | |
| if quantization_config.get("load_in_8bit", False): | |
| load_in_8bit = True | |
| load_in_4bit = False | |
| else: | |
| if getattr(quantization_config, "load_in_4bit", False): | |
| load_in_4bit = True | |
| load_in_8bit = False | |
| if getattr(quantization_config, "load_in_8bit", False): | |
| load_in_8bit = True | |
| load_in_4bit = False |
| if getattr(quantization_config, "load_in_4bit", False): | ||
| load_in_4bit = True | ||
| load_in_8bit = False | ||
| if getattr(quantization_config, "load_in_8bit", False): | ||
| load_in_8bit = True | ||
| load_in_4bit = False |
There was a problem hiding this comment.
Similar to the FastLanguageModel, getattr will raise an AttributeError if quantization_config is a dictionary. This will cause a crash. The logic should handle both dict and object types to prevent this.
| if getattr(quantization_config, "load_in_4bit", False): | |
| load_in_4bit = True | |
| load_in_8bit = False | |
| if getattr(quantization_config, "load_in_8bit", False): | |
| load_in_8bit = True | |
| load_in_4bit = False | |
| if isinstance(quantization_config, dict): | |
| if quantization_config.get("load_in_4bit", False): | |
| load_in_4bit = True | |
| load_in_8bit = False | |
| if quantization_config.get("load_in_8bit", False): | |
| load_in_8bit = True | |
| load_in_4bit = False | |
| else: | |
| if getattr(quantization_config, "load_in_4bit", False): | |
| load_in_4bit = True | |
| load_in_8bit = False | |
| if getattr(quantization_config, "load_in_8bit", False): | |
| load_in_8bit = True | |
| load_in_4bit = False |
| if dtype is None and quantization_config is not None: | ||
| bnb_compute_dtype = None | ||
| if isinstance(quantization_config, dict): | ||
| if quantization_config.get("load_in_4bit", False): | ||
| bnb_compute_dtype = quantization_config.get( | ||
| "bnb_4bit_compute_dtype", None | ||
| ) | ||
| else: | ||
| if getattr(quantization_config, "load_in_4bit", False): | ||
| bnb_compute_dtype = getattr( | ||
| quantization_config, "bnb_4bit_compute_dtype", None | ||
| ) | ||
| if isinstance(bnb_compute_dtype, str): | ||
| bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None) | ||
| if isinstance(bnb_compute_dtype, torch.dtype): | ||
| dtype = bnb_compute_dtype |
There was a problem hiding this comment.
This block of code to determine dtype from quantization_config is duplicated in FastModel.from_pretrained (lines 757-772). To improve maintainability and reduce redundancy, consider extracting this logic into a shared helper function. This will make the code easier to maintain and prevent potential inconsistencies in the future.
| if dtype is None and quantization_config is not None: | ||
| bnb_compute_dtype = None | ||
| if isinstance(quantization_config, dict): | ||
| if quantization_config.get("load_in_4bit", False): | ||
| bnb_compute_dtype = quantization_config.get( | ||
| "bnb_4bit_compute_dtype", None | ||
| ) | ||
| else: | ||
| if getattr(quantization_config, "load_in_4bit", False): | ||
| bnb_compute_dtype = getattr( | ||
| quantization_config, "bnb_4bit_compute_dtype", None | ||
| ) | ||
| if isinstance(bnb_compute_dtype, str): | ||
| bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None) | ||
| if isinstance(bnb_compute_dtype, torch.dtype): | ||
| dtype = bnb_compute_dtype |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: eddc4a8754
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| model_name = model_name, | ||
| max_seq_length = max_seq_length, | ||
| dtype = _get_dtype(dtype), | ||
| load_in_4bit = load_in_4bit, | ||
| load_in_4bit = load_in_4bit_kwargs, |
There was a problem hiding this comment.
Keep 4bit flag when forwarding quantization_config to vLLM
When a user provides a quantization_config, this call now forwards load_in_4bit_kwargs (set to False earlier whenever quantization_config is present) into the model loader. In the fast_inference path, FastLlamaModel uses the load_in_4bit flag to decide use_bitsandbytes for vLLM (unsloth/models/llama.py:2340-2351), and the quantization_config itself is ignored there. With fast_inference=True and a 4bit quantization_config, the model will now be loaded unquantized, leading to unexpected full-precision memory use or OOM. The 4bit flag should stay True for vLLM when the user requested 4bit quantization.
Useful? React with 👍 / 👎.
|
/gemini review |
There was a problem hiding this comment.
Code Review
The pull request successfully addresses the issue of user-provided quantization_config being overridden by default settings. The changes ensure that the user's explicit configuration is respected across different model loading paths, including for vision models. This is a good improvement for user experience and predictability.
My main feedback is regarding significant code duplication in unsloth/models/loader.py. The logic for processing quantization_config and aligning dtype is repeated in FastLanguageModel.from_pretrained and FastModel.from_pretrained. This should be refactored into a shared helper function to improve code maintainability and reduce redundancy.
| # Respect user-provided quantization_config (e.g. BitsAndBytesConfig) | ||
| quantization_config = kwargs.get("quantization_config", None) | ||
| if quantization_config is not None: | ||
| if isinstance(quantization_config, dict): | ||
| q_load_in_4bit = quantization_config.get("load_in_4bit", False) | ||
| q_load_in_8bit = quantization_config.get("load_in_8bit", False) | ||
| else: | ||
| q_load_in_4bit = getattr(quantization_config, "load_in_4bit", False) | ||
| q_load_in_8bit = getattr(quantization_config, "load_in_8bit", False) | ||
| if q_load_in_4bit: | ||
| load_in_4bit = True | ||
| load_in_8bit = False | ||
| if q_load_in_8bit: | ||
| load_in_8bit = True | ||
| load_in_4bit = False | ||
|
|
||
| # Login to allow private models | ||
| token = hf_login(token) | ||
| if whisper_language is not None: | ||
| assert type(whisper_language) is str | ||
| if whisper_task is not None: | ||
| assert type(whisper_task) is str | ||
| # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset. | ||
| if dtype is None and quantization_config is not None: | ||
| bnb_compute_dtype = None | ||
| if isinstance(quantization_config, dict): | ||
| if quantization_config.get("load_in_4bit", False): | ||
| bnb_compute_dtype = quantization_config.get( | ||
| "bnb_4bit_compute_dtype", None | ||
| ) | ||
| else: | ||
| if getattr(quantization_config, "load_in_4bit", False): | ||
| bnb_compute_dtype = getattr( | ||
| quantization_config, "bnb_4bit_compute_dtype", None | ||
| ) | ||
| if isinstance(bnb_compute_dtype, str): | ||
| bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None) | ||
| if isinstance(bnb_compute_dtype, torch.dtype): | ||
| dtype = bnb_compute_dtype |
There was a problem hiding this comment.
This block of code for handling quantization_config and aligning dtype is a duplicate of the logic in FastLanguageModel.from_pretrained (lines 154-188). Large blocks of duplicated code make the software harder to maintain, as any changes or bug fixes need to be applied in multiple places.
To improve maintainability, I suggest refactoring this logic into a single helper function. This function could take kwargs, dtype, load_in_4bit, and load_in_8bit as input and return the processed values.
Here's a potential implementation for the helper function:
def _process_user_quantization_config(kwargs, dtype, load_in_4bit, load_in_8bit):
quantization_config = kwargs.get("quantization_config", None)
if quantization_config is None:
return quantization_config, dtype, load_in_4bit, load_in_8bit
# Respect user-provided quantization_config
if isinstance(quantization_config, dict):
q_load_in_4bit = quantization_config.get("load_in_4bit", False)
q_load_in_8bit = quantization_config.get("load_in_8bit", False)
else:
q_load_in_4bit = getattr(quantization_config, "load_in_4bit", False)
q_load_in_8bit = getattr(quantization_config, "load_in_8bit", False)
if q_load_in_4bit:
load_in_4bit = True
load_in_8bit = False
elif q_load_in_8bit:
load_in_8bit = True
load_in_4bit = False
# Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.
if dtype is None:
bnb_compute_dtype = None
if isinstance(quantization_config, dict):
if quantization_config.get("load_in_4bit", False):
bnb_compute_dtype = quantization_config.get("bnb_4bit_compute_dtype", None)
else:
if getattr(quantization_config, "load_in_4bit", False):
bnb_compute_dtype = getattr(quantization_config, "bnb_4bit_compute_dtype", None)
if isinstance(bnb_compute_dtype, str):
import torch
bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)
if isinstance(bnb_compute_dtype, torch.dtype):
dtype = bnb_compute_dtype
return quantization_config, dtype, load_in_4bit, load_in_8bitYou could then call this helper from both from_pretrained methods. Note the use of elif which makes the precedence explicit if both 4-bit and 8-bit are specified in the config.
Respect user quantization_config
Summary
Motivation
User-supplied BitsAndBytesConfig was being overwritten with NF4 defaults, which caused unexpected quantization and dtype mismatches. This change preserves explicit user intent without changing defaults.
Testing