diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index d83ee9a201c0..8fb943217b6a 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -386,6 +386,56 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def _convert_weights_to_fp8(
+            self, state_dict: Dict[str, torch.Tensor],
+            weights_to_convert: List[str]) -> Dict[str, torch.Tensor]:
+        """Converts the FP16 weights to FP8 E4M3 using their loaded scales."""
+        for weight_name in weights_to_convert:
+            weight_scale_name = weight_name + "_scale"
+            if weight_scale_name not in state_dict:
+                continue
+            loaded_weight = state_dict[weight_name]
+            scale = state_dict[weight_scale_name]
+            state_dict[weight_name] = (loaded_weight.cpu() / scale.cpu()).to(
+                torch.float8_e4m3fn)
+        return state_dict
+
+    def _convert_scales_for_vllm(self, key: str, value: torch.Tensor):
+        """Renames *quantizer._amax keys to *_scale and rescales the values."""
+        replacements = {
+            "weight_quantizer._amax": "weight_scale",
+            "input_quantizer._amax": "act_scale",
+        }
+        for old_suffix, new_suffix in replacements.items():
+            if key.endswith(old_suffix):
+                new_key = key[:len(key) - len(old_suffix)] + new_suffix
+                # 448 is the largest finite value representable in FP8
+                # E4M3, so amax / 448 is the per-tensor scale.
+                new_value = value / 448
+                return new_key, new_value
+        # No quantizer suffix matched; leave the entry untouched.
+        return key, value
+
+    def _convert_ammo_weights(
+            self, input_state_dict: Dict[str, torch.Tensor]
+    ) -> Dict[str, torch.Tensor]:
+        """Converts a ModelOpt (AMMO) state dict to a vLLM checkpoint."""
+        weights_to_convert = []
+        vllm_state_dict = {}
+        for key, value in input_state_dict.items():
+            if key.endswith("_amax"):
+                new_key, new_value = self._convert_scales_for_vllm(key, value)
+                # Only add the entry if the replacement happened.
+                if key != new_key:
+                    vllm_state_dict[new_key] = new_value
+            else:
+                weights_to_convert.append(key)
+                vllm_state_dict[key] = value
+        # The cast can only happen after all the amax values are read.
+        vllm_state_dict = self._convert_weights_to_fp8(vllm_state_dict,
+                                                       weights_to_convert)
+        return vllm_state_dict
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -396,7 +446,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        weights_copy = {}
         for name, loaded_weight in weights:
+            weights_copy[name] = loaded_weight
+
+        converted_weights = self._convert_ammo_weights(weights_copy)
+        for name, loaded_weight in converted_weights.items():
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name
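
For context, here is a minimal standalone sketch (not part of the patch) of the amax-to-scale-to-FP8 round trip the new hunks perform. The tensor and variable names are illustrative only, and it assumes PyTorch >= 2.1 for the `torch.float8_e4m3fn` dtype:

```python
import torch

# Hypothetical FP16 weight; in a real checkpoint, amax is recorded per
# tensor by the quantizer during calibration (here it is just the observed
# absolute maximum).
weight = torch.randn(4, 4, dtype=torch.float16)
amax = weight.abs().max().float()

# Same rule as _convert_scales_for_vllm: 448 is the largest finite value
# representable in FP8 E4M3, so amax / 448 maps amax onto that maximum.
scale = amax / 448

# Same rule as _convert_weights_to_fp8: divide by the scale, then cast.
weight_fp8 = (weight.float() / scale).to(torch.float8_e4m3fn)

# A kernel consuming the weight multiplies by the scale to dequantize.
weight_dequant = weight_fp8.to(torch.float16) * scale.half()
print((weight - weight_dequant).abs().max())  # small quantization error
```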