[Model][Hardware][NV] Add support for ModelOpt static scaling checkpoints #5387
Changes from all commits
@@ -386,6 +386,52 @@ def sample(

        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _convert_weights_to_fp8(
            self, state_dict: Dict[str, torch.tensor],
            weights_to_convert: List[str]) -> Dict[str, torch.tensor]:
        """Converts the original weights to FP8E4M3 from FP16."""
        for weight_name in weights_to_convert:
            weight_scale_name = weight_name + "_scale"
            if weight_scale_name not in state_dict:
                continue
            loaded_weight = state_dict[weight_name]
            scale = state_dict[weight_scale_name]
            state_dict[weight_name] = (loaded_weight.cpu() / scale.cpu()).to(
                torch.float8_e4m3fn)
        return state_dict
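For readers unfamiliar with the static-scaling scheme, the sketch below shows the same per-tensor arithmetic in isolation: a weight is divided by its scale and cast to `torch.float8_e4m3fn`, whose maximum representable magnitude is 448 (which is why the amax values are divided by 448 further down in this diff). This is a minimal, hypothetical example assuming a PyTorch build with float8 support; the helper name is illustrative and not part of the PR.

```python
import torch

# Minimal illustration (hypothetical helper, not the PR's code): per-tensor
# FP8E4M3 quantization of a single FP16 weight with a precomputed scale.
def quantize_to_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Dividing by the scale brings values into the FP8E4M3 range
    # (max magnitude 448) before the cast truncates precision.
    return (weight.cpu() / scale.cpu()).to(torch.float8_e4m3fn)

w = torch.randn(16, 16, dtype=torch.float16)
scale = w.abs().max().float() / 448.0   # per-tensor scale derived from amax
print(quantize_to_fp8(w, scale).dtype)  # torch.float8_e4m3fn
```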
    def _convert_scales_for_vllm(self, key, value):
        """Replaces the names of *quantizer._amax to _scale."""
        replacements = {
            "weight_quantizer._amax": "weight_scale",
            "input_quantizer._amax": "act_scale",
        }
        for old_suffix, new_suffix in replacements.items():
            if key.endswith(old_suffix):
                new_key = key[:len(key) - len(old_suffix)] + new_suffix
                new_value = value / 448
                return new_key, new_value
            else:
                return key, value

Review comment (Member), on the `"input_quantizer._amax": "act_scale"` entry:
This has been updated such that act_scale -> input_scale #5353
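As the comment above notes (and a follow-up comment at the end of this page reiterates), vLLM renamed the activation scale to `input_scale` (#5353) ahead of v0.5.0. Below is a standalone sketch of the renaming step using that newer name; the helper name is illustrative rather than the PR's final code, and unlike the diff it only falls back to the original key after every suffix has been checked.

```python
import torch
from typing import Tuple

# Illustrative sketch (not the PR's code): map ModelOpt quantizer amax keys
# to vLLM-style scale keys, using the post-#5353 name input_scale.
SUFFIX_RENAMES = {
    "weight_quantizer._amax": "weight_scale",
    "input_quantizer._amax": "input_scale",
}

def rename_amax_key(key: str, value: torch.Tensor) -> Tuple[str, torch.Tensor]:
    for old_suffix, new_suffix in SUFFIX_RENAMES.items():
        if key.endswith(old_suffix):
            # amax is an observed absolute maximum; dividing by 448 (the
            # FP8E4M3 max) turns it into a multiplicative scale.
            return key[:-len(old_suffix)] + new_suffix, value / 448
    # Fall through only once every suffix has been checked.
    return key, value
```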
    def _convert_ammo_weights(self, input_state_dict: Dict[str, torch.tensor]):
        """Util method to modify the modelopt state dict to vLLM checkpoint."""
        weights_to_convert = []
        vllm_state_dict = {}
        for key, value in input_state_dict.items():
            if key.endswith("_amax"):
                new_key, new_value = self._convert_scales_for_vllm(key, value)
                # Only add if the replacement happened.
                if key != new_key:
                    vllm_state_dict[new_key] = new_value
            else:
                weights_to_convert.append(key)
                vllm_state_dict[key] = value
        # Conversion can only happen after all the amax values are read.
        vllm_state_dict = self._convert_weights_to_fp8(vllm_state_dict,
                                                       weights_to_convert)
        return vllm_state_dict

Review comment, on the `_convert_ammo_weights` name:
ammo is not the product name. Let's use modelopt instead.

Review comment (Member), on the `key.endswith("_amax")` check:
It would be best to make this as specific as possible to avoid possible conflicts -- would
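To make the end-to-end effect concrete, here is a toy, entirely hypothetical slice of a ModelOpt checkpoint (the key layout and values are made up for illustration) and what the conversion above is expected to produce from it.

```python
import torch

# Hypothetical input slice of a ModelOpt static-scaling checkpoint
# (key names and values are illustrative only).
modelopt_slice = {
    "model.layers.0.mlp.gate_proj.weight":
        torch.randn(8, 8, dtype=torch.float16),
    "model.layers.0.mlp.gate_proj.weight_quantizer._amax":
        torch.tensor(448.0),
    "model.layers.0.mlp.gate_proj.input_quantizer._amax":
        torch.tensor(448.0),
}

# Expected result of the conversion, roughly:
#   ...gate_proj.weight       -> cast to torch.float8_e4m3fn (divided by weight_scale)
#   ...gate_proj.weight_scale -> tensor(1.0), i.e. amax / 448
#   ...gate_proj.act_scale    -> tensor(1.0)  (named input_scale after #5353)
```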
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
@@ -396,7 +442,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_copy = {}
        for name, loaded_weight in weights:
            weights_copy[name] = loaded_weight
            print(name)

        weights_ = self._convert_ammo_weights(weights_copy)
        for name, loaded_weight in weights_.items():
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
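A short, standalone sketch of why `load_weights` now materializes the incoming iterator before loading parameters: every `*_amax` entry has to be seen before any weight can be cast to FP8. The helper name below is hypothetical, and the debug `print(name)` from the diff is omitted.

```python
from typing import Dict, Iterable, Tuple
import torch

# Hypothetical standalone helper mirroring the buffering step added to
# load_weights: collect the streamed (name, tensor) pairs into a dict so the
# ModelOpt conversion can pair each weight with its *_amax scale afterwards.
def buffer_checkpoint(
        weights: Iterable[Tuple[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    weights_copy: Dict[str, torch.Tensor] = {}
    for name, loaded_weight in weights:
        weights_copy[name] = loaded_weight
    return weights_copy

# Usage inside load_weights would then look roughly like:
#     weights_ = self._convert_ammo_weights(buffer_checkpoint(weights))
#     for name, loaded_weight in weights_.items(): ...
```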
Review comment, on the scale naming:
We just tweaked this to `input_scale`, FYI, ahead of the v0.5.0 beta launch.