-
Notifications
You must be signed in to change notification settings - Fork 265
Mistral 3 Fast Inference #289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
833e147
bad1692
985e81c
5883e13
d23e378
beba3ae
d124be6
7abcb47
11e3ff0
b043e73
125597e
500fc02
fab2ba0
9d0a7e2
3a72f8f
872127f
090df5d
c1b57fd
27e8b18
60d3a9c
4b054b8
e021682
4e91e37
6d5f448
e720866
544cf2e
b466139
b5e8d63
e7e4279
f4cc238
c230e6d
9838d9b
638bd10
de91982
64f40f3
292f6f7
1064908
8fa08ed
1b493e8
8fd2e4d
deb2e0a
34da39f
9a467f2
7d4db12
e0ebcc4
0de564f
bb6243f
fa47fdf
c2da34a
fa93268
b71e4f5
af94f0c
e580d66
2c52a23
28aae16
85b26f3
540c3d4
c3d3ac9
8c1034a
538ba0c
cca9e16
6cfb2c9
e078cf0
41c7d41
be75f93
19519f1
b0a1081
cfae834
f55abbe
71c78bf
07b0459
f6ed07d
face6d6
67354e9
2b67460
ae65c51
ee0d83a
bd06828
93b9fa2
0dddcc4
4a8c9d5
f3b8263
c122767
d7461d6
71718e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,11 +61,10 @@ def compare_attributes(original_model, new_model): | |
| new_attrs = {attr for attr in dir(module) if not attr.startswith('_')} | ||
| buffer_names = {name for name,_ in original_module.named_buffers(recurse=False)} | ||
|
|
||
| assert type(module) == type(original_module), f"Type mismatch for {name}: {type(module)} != {type(original_module)}" | ||
|
|
||
| # Find missing attributes (in original but not in new) | ||
| missing_in_new = orig_attrs - new_attrs | ||
| missing_in_new = missing_in_new - {'hf_device_map'} | ||
| missing_in_new = missing_in_new - {'hf_device_map', 'source_cls'} | ||
| if missing_in_new: | ||
| for attr in sorted(missing_in_new): | ||
| missing_attrs.append(f"{name}.{attr}") | ||
|
|
@@ -240,11 +239,7 @@ def create_empty_causal_lm(config, dtype = torch.float16): | |
| attn_implementation = "eager", | ||
| ) | ||
|
|
||
| # Get layer names from config | ||
| layer_config = get_model_layer_config() | ||
| layer_names = sum(layer_config.values(), []) | ||
|
|
||
| return new_model, original_meta_model, layer_names, config.num_hidden_layers | ||
| return new_model, original_meta_model, config.num_hidden_layers | ||
|
|
||
| def _set_config_attrs(config_obj, attrs_to_set): | ||
| """Helper to set multiple attributes on a config object if they exist.""" | ||
|
|
@@ -326,20 +321,23 @@ def _init_weights(self, module): | |
| num_layers = max(text_layers, vision_layers) | ||
| new_model = model_cls(new_config) | ||
|
|
||
| # Get layer names from config | ||
| layer_config = get_model_layer_config() | ||
| layer_names = sum(layer_config.values(), []) | ||
|
|
||
| return new_model, original_meta_model, layer_names, num_layers | ||
| return new_model, original_meta_model, num_layers | ||
|
|
||
|
|
||
| @torch.inference_mode() | ||
| def create_empty_model(config, dtype = torch.float16, is_vision_model = False): | ||
| # All Unsloth Zoo code licensed under LGPLv3 | ||
| if is_vision_model: | ||
| return create_empty_vision_model(config, dtype) | ||
| new_model, original_meta_model, num_layers = create_empty_vision_model(config, dtype) | ||
| else: | ||
| return create_empty_causal_lm(config, dtype) | ||
| new_model, original_meta_model, num_layers = create_empty_causal_lm(config, dtype) | ||
|
|
||
| # Get layer names from config | ||
| layer_templates = get_model_layer_config(return_non_layered=False) | ||
| layer_names = sum(layer_templates.values(), []) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait sum?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh basically this is like join but I wanted a single deflated list. So sum everything with empty list and you get a single list of everything :)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. its like a=[1,2,3]; b = [4,5,6] and doing |
||
|
|
||
| return new_model, original_meta_model, num_layers, layer_names | ||
|
|
||
|
|
||
| @torch.inference_mode() | ||
| def set_additional_modules(new_model, quant_state_dict, config): | ||
|
|
@@ -378,8 +376,11 @@ def set_additional_modules(new_model, quant_state_dict, config): | |
| norm = torch.nn.Parameter(norm, requires_grad = False) | ||
| language_model.norm.weight = norm | ||
|
|
||
| # LM Head | ||
| if getattr(config, "tie_word_embeddings", False): | ||
| # LM Head. Do note that for some models, like Mistral3ForConditionalGeneration, | ||
| # there can be mismatch in the value of tie_word_embeddings between config and text_config | ||
| # we prefer picking the one in text_config. If you notice any issue later, please report it! | ||
| text_config = getattr(config, "text_config", config) | ||
| if getattr(text_config, "tie_word_embeddings", False): | ||
| lmhead_key = f"{language_model_prefix}.embed_tokens.weight" | ||
| else: | ||
| lmhead_key = "lm_head.weight" | ||
|
|
@@ -422,19 +423,20 @@ def set_additional_modules(new_model, quant_state_dict, config): | |
| ) | ||
|
|
||
| for key in additional_keys: | ||
| try: | ||
| replaced_key = re.sub(r"\.(\d+)\.", r"[\1].", key) | ||
| exec(f"new_{replaced_key}.data = quant_state_dict[key]") | ||
| except: | ||
| try: | ||
| # sometimes it can be in new_model.model. instead of new_model. | ||
| exec(f"new_model.{replaced_key}.data = quant_state_dict[key]") | ||
| except: | ||
| continue | ||
| replaced_key = re.sub(r"\.(\d+)\.", r"[\1].", key) | ||
| # sometimes it can be in new_model.model. instead of new_model. | ||
| for prefix in ['new_', 'new_model.']: | ||
| for suffix in ['', '.data']: | ||
| try: | ||
| exec(f"{prefix}{replaced_key}{suffix} = quant_state_dict[key]") | ||
| break | ||
| except: | ||
| continue | ||
|
|
||
| pass | ||
| pass | ||
|
|
||
| def get_model_layer_config(): | ||
| def get_model_layer_config(return_non_layered=True): | ||
| """ | ||
| Returns a unified layer configuration containing the union of layer names | ||
| from all supported vision models. Serves as a fallback. | ||
|
|
@@ -480,6 +482,10 @@ def get_model_layer_config(): | |
| "model.vision_tower.vision_model.encoder.layers.{kk}.post_layernorm", | ||
| "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm1", | ||
| "model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm2", | ||
|
|
||
| # Mistral3 vision norms | ||
| "model.vision_tower.transformer.layers.{kk}.attention_norm", | ||
| "model.vision_tower.transformer.layers.{kk}.ffn_norm", | ||
| }, | ||
| 'vision_layers': { | ||
|
|
||
|
|
@@ -533,15 +539,33 @@ def get_model_layer_config(): | |
| "model.visual.blocks.{kk}.mlp.up_proj", | ||
| "model.visual.blocks.{kk}.mlp.down_proj", | ||
|
|
||
| # Mistral 3 | ||
| "model.vision_tower.transformer.layers.{kk}.attention.q_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.attention.k_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.attention.v_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.attention.qkv_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.attention.o_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.feed_forward.gate_up_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.feed_forward.gate_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.feed_forward.up_proj", | ||
| "model.vision_tower.transformer.layers.{kk}.feed_forward.down_proj", | ||
|
|
||
| }, | ||
| 'additional_layers': { | ||
| "model.visual.merger.mlp.{kk}", | ||
| "model.visual.merger.mlp.{kk}", | ||
| 'model.language_model.model.layers.{kk}.cross_attn_mlp_gate', | ||
| 'model.language_model.model.layers.{kk}.cross_attn_attn_gate', | ||
| 'model.vision_model.global_transformer.layers.{kk}.gate_ffn', | ||
|
|
||
| # Mistral3 | ||
| "model.multi_modal_projector.patch_merger.merging_layer", | ||
| "model.multi_modal_projector.linear_1", | ||
| "model.multi_modal_projector.linear_2", | ||
| }, | ||
| "non_layered_components":{ | ||
| # we do not handle quantization for these layers yet | ||
| # the set_additional_modules would process these layers | ||
| "model.multi_modal_projector", | ||
| "model.language_model.norm", | ||
| 'model.vision_model.layernorm_pre', | ||
|
|
@@ -561,11 +585,17 @@ def get_model_layer_config(): | |
| "model.vision_model.pre_tile_positional_embedding.embedding", | ||
| "model.vision_model.gated_positional_embedding", | ||
| "model.vision_model.post_tile_positional_embedding.embedding", | ||
| "model.vision_model.pre_tile_positional_embedding.gate" | ||
| "model.vision_model.pre_tile_positional_embedding.gate", | ||
|
|
||
| # Mistral3 | ||
| "model.vision_tower.patch_positional_embedding", | ||
| "model.vision_tower.patch_conv", | ||
| "model.vision_tower.ln_pre", | ||
| } | ||
| } | ||
|
|
||
| # Convert sets to sorted lists for deterministic order | ||
| return {key: sorted(list(value)) for key, value in layer_templates.items()} | ||
| return {key: sorted(list(value)) for key, value in layer_templates.items() if key!='non_layered_components' or return_non_layered} | ||
|
|
||
|
|
||
| def get_model_layer_counts(config): | ||
|
|
@@ -648,13 +678,13 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat | |
|
|
||
|
|
||
| if layer_module is not None: | ||
| if "qkv_proj" in layer_path: | ||
| if model_type in ["mllama", "gemma3"]: | ||
| if "qkv" in layer_path: | ||
| if model_type == "qwen2_5_vl": | ||
| get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False) | ||
| else: | ||
| get_state_dict(f"{layer_path.replace('qkv_proj', 'q_proj')}", 0, state_dict, layer_module) | ||
| get_state_dict(f"{layer_path.replace('qkv_proj', 'k_proj')}", 1, state_dict, layer_module) | ||
| get_state_dict(f"{layer_path.replace('qkv_proj', 'v_proj')}", 2, state_dict, layer_module) | ||
| elif model_type == "qwen2_5_vl": | ||
| get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False) | ||
| elif "gate_up_proj" in layer_path: | ||
| # vLLM seems to have merged gate and up proj recently for qwen vl. This is to handle new variant | ||
| # https://github.com/jeejeelee/vllm/commit/a71e4765cc0c1534f2a8891aaf628e1751f6df07 | ||
|
|
@@ -665,11 +695,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat | |
| else: # Handle other layers, especially layernorms | ||
| if isinstance(layer_module, torch.nn.Module): | ||
| if hasattr(layer_module, 'weight'): | ||
| state_dict[f"{layer_path}.weight"] = layer_module.weight.data | ||
| quant_state_dict[f"{layer_path}.weight"] = state_dict[f"{layer_path}.weight"] | ||
| if hasattr(layer_module, 'bias') and layer_module.bias is not None: | ||
| state_dict[f"{layer_path}.bias"] = layer_module.bias.data | ||
| quant_state_dict[f"{layer_path}.bias"] = state_dict[f"{layer_path}.bias"] | ||
| get_state_dict(layer_path, 0, state_dict, layer_module) | ||
| elif isinstance(layer_module, torch.nn.Parameter): | ||
| state_dict[f"{layer_path}"] = layer_module.data | ||
| quant_state_dict[f"{layer_path}"] = state_dict[f"{layer_path}"] | ||
|
|
@@ -682,14 +708,22 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat | |
| component = _get_nested_attr(vllm_internals, component_path) | ||
|
|
||
| if component is not None: | ||
| if isinstance(component, torch.nn.Module): | ||
| for param_name, param in component.named_parameters(): | ||
| full_param_path = f"{component_path}.{param_name}" | ||
| state_dict[full_param_path] = param.data | ||
| quant_state_dict[full_param_path] = param.data | ||
| if hasattr(component, 'weight'): | ||
| # Prefer using get_state_dict when possible | ||
| get_state_dict(component_path, 0, state_dict, component) | ||
| elif isinstance(component, torch.nn.Parameter): | ||
| state_dict[component_path] = component.data | ||
| quant_state_dict[component_path] = component.data | ||
| elif isinstance(component, torch.nn.Module): | ||
| for param_name, param in component.named_parameters(): | ||
| # if the parameter is to be extracted separately, skip it | ||
| if param_name.replace('.weight', '') in non_layered_components: continue | ||
| full_param_path = f"{component_path}.{param_name}" | ||
| if hasattr(param, 'weight'): | ||
| get_state_dict(full_param_path, 0, state_dict, param) | ||
| elif hasattr(param, 'data'): | ||
| state_dict[full_param_path] = param.data | ||
| quant_state_dict[full_param_path] = param.data | ||
| else: | ||
| print(f"Unsloth: Skipping non-layered component '{component_path}' of unexpected type: {type(component)}") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if this is a set, wrap in set ie set({...})
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh this is already a set right? I don't think we need to do a wrap