Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,52 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def _convert_weights_to_fp8(
self, state_dict: Dict[str, torch.tensor],
weights_to_convert: List[str]) -> Dict[str, torch.tensor]:
"""Converts the original weights to FP8E4M3 from FP16."""
for weight_name in weights_to_convert:
weight_scale_name = weight_name + "_scale"
if weight_scale_name not in state_dict:
continue
loaded_weight = state_dict[weight_name]
scale = state_dict[weight_scale_name]
state_dict[weight_name] = (loaded_weight.cpu() / scale.cpu()).to(
torch.float8_e4m3fn)
return state_dict

def _convert_scales_for_vllm(self, key, value):
"""Replaces the names of *quantizer._amax to _scale."""
replacements = {
"weight_quantizer._amax": "weight_scale",
"input_quantizer._amax": "act_scale",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We just tweaked this to input_scale FYI ahead of the v0.5.0 beta launch

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been updated such that act_scale -> input_scale #5353

}
for old_suffix, new_suffix in replacements.items():
if key.endswith(old_suffix):
new_key = key[:len(key) - len(old_suffix)] + new_suffix
new_value = value / 448
return new_key, new_value
else:
return key, value

def _convert_ammo_weights(
        self, input_state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Convert a ModelOpt (formerly "ammo") state dict to vLLM format.

    First pass: every ``*_amax`` calibration entry is renamed to its vLLM
    scale name via ``_convert_scales_for_vllm`` (entries whose name did
    not change are dropped); all other entries are kept verbatim and
    recorded as FP8-conversion candidates. Second pass: once all amax
    scales are present, ``_convert_weights_to_fp8`` quantizes the
    candidate weights that have a matching ``*_scale`` entry.

    NOTE(review): ``endswith("_amax")`` is broad — a weight whose name
    happens to end in ``_amax`` would be silently dropped; consider
    tightening to ``endswith("_quantizer._amax")``.

    Args:
        input_state_dict: Raw ModelOpt checkpoint tensors by name.

    Returns:
        A new state dict with vLLM-style scale names and FP8 weights.
    """
    weights_to_convert: List[str] = []
    vllm_state_dict: Dict[str, torch.Tensor] = {}
    for key, value in input_state_dict.items():
        if key.endswith("_amax"):
            new_key, new_value = self._convert_scales_for_vllm(key, value)
            # Only keep the entry if the rename actually happened;
            # unrecognized amax entries are discarded.
            if key != new_key:
                vllm_state_dict[new_key] = new_value
        else:
            weights_to_convert.append(key)
            vllm_state_dict[key] = value
    # Conversion can only happen after all the amax values are read.
    vllm_state_dict = self._convert_weights_to_fp8(vllm_state_dict,
                                                   weights_to_convert)
    return vllm_state_dict

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
Expand All @@ -396,7 +442,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
weights_copy = {}
for name, loaded_weight in weights:
weights_copy[name] = loaded_weight
print(name)

weights_ = self._convert_ammo_weights(weights_copy)
for name, loaded_weight in weights_.items():
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
Expand Down