-
Notifications
You must be signed in to change notification settings - Fork 273
FP8: Load model on-the-fly in vLLM #380
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -845,7 +845,13 @@ def vllm_dynamic_quant_supported( | |||||
| pass | ||||||
|
|
||||||
|
|
||||||
| def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision_model = False): | ||||||
| def get_vllm_state_dict( | ||||||
| llm, | ||||||
| return_state_dict = False, | ||||||
| config = None, | ||||||
| is_vision_model = False, | ||||||
| load_in_fp8 = False, | ||||||
| ): | ||||||
| # If the vllm state dict was quantized using torchao, we will run into | ||||||
| # the following error when calling ops like aten.t() in inference mode. | ||||||
| # This is a bug in PyTorch that affects all tensor subclasses. | ||||||
|
|
@@ -854,7 +860,7 @@ def get_vllm_state_dict(llm, return_state_dict = False, config = None, is_vision | |||||
| # | ||||||
| # For now, we work around this issue by using torch.no_grad in this case. | ||||||
| # See https://github.com/pytorch/pytorch/issues/164872 for more details | ||||||
| if get_quant_type(config) == "torchao": | ||||||
| if get_quant_type(config) == "torchao" or load_in_fp8: | ||||||
| ctx_manager = torch.no_grad() | ||||||
| else: | ||||||
| ctx_manager = torch.inference_mode() | ||||||
|
Comment on lines
+863
to
866
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
When a model is quantized via the new Useful? React with 👍 / 👎. |
||||||
|
|
@@ -1706,6 +1712,30 @@ def _module_disallows_flashinfer(module) -> bool: | |||||
| pass | ||||||
|
|
||||||
|
|
||||||
| def _get_torchao_fp8_config(fp8_mode: str): | ||||||
| """ | ||||||
| Return a `torchao.quantization.Float8DynamicActivationFloat8WeightConfig` | ||||||
| to be used for `load_in_fp8=True`. | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring mentions
Suggested change
|
||||||
| """ | ||||||
| from torchao.quantization import ( | ||||||
| Float8DynamicActivationFloat8WeightConfig, | ||||||
| PerBlock, | ||||||
| PerRow, | ||||||
| ) | ||||||
|
|
||||||
| if fp8_mode == "row": | ||||||
| granularity = PerRow() | ||||||
| elif fp8_mode == "block": | ||||||
| granularity = (PerBlock([1, 128]), PerBlock([128, 128])) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| else: | ||||||
| raise ValueError("Unsloth: `load_in_fp8` supports only 'row' or 'block'") | ||||||
|
|
||||||
| return Float8DynamicActivationFloat8WeightConfig( | ||||||
| granularity = granularity, | ||||||
| activation_value_lb = 1e-12, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
| def load_vllm( | ||||||
| model_name : str = "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", | ||||||
| config = None, | ||||||
|
|
@@ -1731,6 +1761,7 @@ def load_vllm( | |||||
| is_vision_model : bool = False, | ||||||
| return_args : bool = False, # Just return args | ||||||
| max_num_seqs : int = 256, # how many seqs to process in parallel. Default vLLM 256 | ||||||
| fp8_mode : Optional[str] = None, | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new It is highly recommended to add tests for this new functionality to ensure its correctness and prevent future regressions. |
||||||
| ): | ||||||
| # All Unsloth Zoo code licensed under LGPLv3 | ||||||
| # Create vLLM instance | ||||||
|
|
@@ -2171,6 +2202,18 @@ def load_vllm( | |||||
| engine_args["disable_cascade_attn"] = disable_cascade_attn | ||||||
| pass | ||||||
|
|
||||||
| # On-the-fly quantization is added in https://github.com/vllm-project/vllm/pull/23014 | ||||||
| # This is only available in vllm >= 0.12.0. Older versions will do offline quantization | ||||||
| # by creating an FP8 checkpoint and loading it back in | ||||||
| if fp8_mode is not None and Version(vllm_version) >= Version("0.12.0"): | ||||||
| from torchao.core.config import config_to_dict | ||||||
| torchao_config = _get_torchao_fp8_config(fp8_mode) | ||||||
| hf_overrides = { | ||||||
| "quantization_config_dict_json": json.dumps(config_to_dict(torchao_config)), | ||||||
| } | ||||||
| engine_args["quantization"] = "torchao" | ||||||
|
Comment on lines
+2210
to
+2214
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the fp8 path we set Useful? React with 👍 / 👎. |
||||||
| engine_args["hf_overrides"] = hf_overrides | ||||||
|
Comment on lines
+2208
to
+2215
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When Consider adding an |
||||||
|
|
||||||
| good_keys = inspect.signature(AsyncEngineArgs if use_async else EngineArgs).parameters.keys() | ||||||
| old_keys = list(engine_args.keys()) | ||||||
| for key in old_keys: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this strictly necessary? I don't remember this being needed for offline quant FP8.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah this was actually necessary for offline FP8 quant as well. It's just that now we don't require the config to explicitly say "torchao FP8" (it can be any bf16 checkpoint), and instead we handle the FP8 quantization dynamically through vllm, but we still need
no_gradcause it's still using tensor subclasses