Merged
85 commits
833e147
[WIP] use vLLM for vision language models
Datta0 Jul 6, 2025
bad1692
Streamline vision vllm settings
Datta0 Jul 7, 2025
985e81c
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Jul 10, 2025
5883e13
WIP
Datta0 Jul 10, 2025
d23e378
WIP vLLM VLM
Datta0 Jul 10, 2025
beba3ae
Make individual dummy model for qwen 2.5vl, llama3.2,
Datta0 Jul 12, 2025
d124be6
fixup norm for vLLM
Datta0 Jul 13, 2025
7abcb47
rework vLLM for VLMs
Datta0 Jul 13, 2025
11e3ff0
Cleanup more stuff
Datta0 Jul 13, 2025
b043e73
Load up remaining modules from state dict
Datta0 Jul 14, 2025
125597e
use get_state_dict when possible
Datta0 Jul 14, 2025
500fc02
Fixup lm_head state dict fetch
Datta0 Jul 15, 2025
fab2ba0
add is_vision flag for differentiating VLMs
Datta0 Jul 15, 2025
9d0a7e2
add is_vision_model flag
Datta0 Jul 15, 2025
3a72f8f
Cleanup more stuff
Datta0 Jul 15, 2025
872127f
Cleanup vLLM extraction
Datta0 Jul 15, 2025
090df5d
Fixup device type
Datta0 Jul 15, 2025
c1b57fd
Cleanup more stuff
Datta0 Jul 15, 2025
27e8b18
revert vLLM mem usage calc changes
Datta0 Jul 15, 2025
60d3a9c
Populate config values properly for VLMs
Datta0 Jul 16, 2025
4b054b8
cleaner attribute copy and check mechanism
Datta0 Jul 17, 2025
e021682
Patch siglip empty init
Datta0 Jul 17, 2025
4e91e37
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Jul 17, 2025
6d5f448
Make additional module loading memory efficient
Datta0 Jul 17, 2025
e720866
Let the mini models be really small
Datta0 Jul 17, 2025
544cf2e
Minor cleanup
Datta0 Jul 17, 2025
b466139
cleanup vllm_utils by moving out empty model creation
Datta0 Jul 18, 2025
b5e8d63
Gemma3 and CausalLM fixes
Datta0 Jul 18, 2025
e7e4279
Merge branch 'main' into vlm_fast_infer
Datta0 Jul 18, 2025
f4cc238
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Jul 20, 2025
c230e6d
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Jul 21, 2025
9838d9b
Respect vLLMs conditions of max_num_batch_tokens vs max_seq_len
Datta0 Jul 21, 2025
638bd10
Restrict mm per prompt and max batch tokens
Datta0 Jul 22, 2025
de91982
Improve config copy overs
Datta0 Jul 22, 2025
64f40f3
Falcon H1 training is fp16 is unstable with the mamba kernels. NaN's …
mmathew23 Jul 22, 2025
292f6f7
Fix torch compile issues (#213)
danielhanchen Jul 23, 2025
1064908
Small fix
danielhanchen Jul 23, 2025
8fa08ed
fixup norms for causallm
Datta0 Jul 24, 2025
1b493e8
Guard against args change
Datta0 Jul 25, 2025
8fd2e4d
dont mark as grpo hidden states as dynamic
Datta0 Jul 28, 2025
deb2e0a
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Jul 29, 2025
34da39f
Refactor to make vision handling easier
Datta0 Aug 11, 2025
9a467f2
[WIP] fixup llama vision
Datta0 Aug 11, 2025
7d4db12
cleanup
Datta0 Aug 11, 2025
e0ebcc4
2/n mllama
Datta0 Aug 11, 2025
0de564f
fixup mllama additional layers
Datta0 Aug 12, 2025
bb6243f
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Aug 18, 2025
fa47fdf
Fixup qwen qknorm
Datta0 Aug 18, 2025
c2da34a
Pad token check and state dict changes
Datta0 Aug 18, 2025
fa93268
Patch TF protobuf incompatability
Datta0 Aug 19, 2025
b71e4f5
Revert "Patch TF protobuf incompatability"
Datta0 Aug 19, 2025
af94f0c
Fixup patch_model_and_tokenizer for VLM
Datta0 Aug 19, 2025
e580d66
reset vllm state dict changes
Datta0 Aug 19, 2025
2c52a23
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Aug 23, 2025
28aae16
Cleanup logs
Datta0 Aug 27, 2025
85b26f3
Fixup gemma3 local rope embedding
Datta0 Aug 29, 2025
540c3d4
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Aug 30, 2025
c3d3ac9
Fix Qwen 2.5 VL gate_up_proj vLLM
Datta0 Aug 30, 2025
8c1034a
Wakeup before doing vLLM generate (#259)
Datta0 Sep 1, 2025
538ba0c
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Sep 5, 2025
cca9e16
use logger instead of print. Add license header
Datta0 Sep 8, 2025
6cfb2c9
Increase gpu_emmory_utilisation if in standby
Datta0 Sep 8, 2025
e078cf0
User friendly error message for sleep model with expandable segments
Datta0 Sep 8, 2025
41c7d41
Fixup cumem init for older versions
Datta0 Sep 9, 2025
be75f93
Merge remote-tracking branch 'origin/main' into vlm_fast_infer
Datta0 Sep 9, 2025
19519f1
fixup qwen vl vision rope
Datta0 Sep 9, 2025
b0a1081
do not slice logits for grpo
Datta0 Sep 9, 2025
cfae834
undo changes to rl_replacements
Datta0 Sep 9, 2025
f55abbe
Fix: (temporary workaround) mem usage calcl for quantized VLMs
Datta0 Sep 10, 2025
71c78bf
Add mistral 3 support
Datta0 Sep 15, 2025
07b0459
Cleanup and fix for other models
Datta0 Sep 16, 2025
f6ed07d
fixup comparison attributes
Datta0 Sep 16, 2025
face6d6
Testing code
Datta0 Sep 16, 2025
67354e9
Merge remote-tracking branch 'origin' into vlm_fast_test
Datta0 Sep 16, 2025
2b67460
Merge remote-tracking branch 'origin' into vlm_fast_infer
Datta0 Sep 16, 2025
ae65c51
compare and copy dtype
Datta0 Sep 16, 2025
ee0d83a
Merge branch 'vlm_fast_infer' into vlm_fast_test
Datta0 Sep 16, 2025
bd06828
Merge remote-tracking branch 'origin' into mistral3_vllm
Datta0 Sep 16, 2025
93b9fa2
more mistral changes
Datta0 Sep 16, 2025
0dddcc4
Merge branch 'vlm_fast_test' into mistral3_vllm
Datta0 Sep 16, 2025
4a8c9d5
Merge remote-tracking branch 'origin' into mistral3_vllm
Datta0 Sep 16, 2025
f3b8263
more mistral changes
Datta0 Sep 16, 2025
c122767
Mistral 3 final touches
Datta0 Sep 16, 2025
d7461d6
Fixup mistral3 quantization stuff
Datta0 Sep 17, 2025
71718e8
Clean up stuff
Datta0 Sep 17, 2025
118 changes: 76 additions & 42 deletions unsloth_zoo/empty_model.py
@@ -61,11 +61,10 @@ def compare_attributes(original_model, new_model):
new_attrs = {attr for attr in dir(module) if not attr.startswith('_')}
buffer_names = {name for name,_ in original_module.named_buffers(recurse=False)}

assert type(module) == type(original_module), f"Type mismatch for {name}: {type(module)} != {type(original_module)}"

# Find missing attributes (in original but not in new)
missing_in_new = orig_attrs - new_attrs
missing_in_new = missing_in_new - {'hf_device_map'}
missing_in_new = missing_in_new - {'hf_device_map', 'source_cls'}
Member


If this is a set, wrap it in `set()`, i.e. `set({...})`.

Collaborator Author

@Datta0, Sep 21, 2025


Oh, this is already a set, right? I don't think we need to wrap it.
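To illustrate the point made in this thread: a brace literal with at least one element is already a `set` in Python, so an extra `set(...)` call would only copy it. Only the empty literal `{}` is a dict. The sketch below uses made-up attribute names (`weight`, `bias`) purely for illustration; it is not taken from `compare_attributes`.

```python
# Hypothetical attribute names, for illustration only.
orig_attrs = {"weight", "bias", "hf_device_map", "source_cls"}
new_attrs = {"weight"}

# Attributes present on the original module but absent on the new one.
missing_in_new = orig_attrs - new_attrs

# {'hf_device_map', 'source_cls'} is a set literal, so set difference works directly.
ignored = {"hf_device_map", "source_cls"}
missing_in_new = missing_in_new - ignored

assert isinstance(ignored, set)
assert set(ignored) == ignored  # wrapping in set(...) only makes a copy

print(sorted(missing_in_new))  # prints ['bias']
```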

if missing_in_new:
for attr in sorted(missing_in_new):
missing_attrs.append(f"{name}.{attr}")
@@ -240,11 +239,7 @@ def create_empty_causal_lm(config, dtype = torch.float16):
attn_implementation = "eager",
)

# Get layer names from config
layer_config = get_model_layer_config()
layer_names = sum(layer_config.values(), [])

return new_model, original_meta_model, layer_names, config.num_hidden_layers
return new_model, original_meta_model, config.num_hidden_layers

def _set_config_attrs(config_obj, attrs_to_set):
"""Helper to set multiple attributes on a config object if they exist."""
@@ -326,20 +321,23 @@ def _init_weights(self, module):
num_layers = max(text_layers, vision_layers)
new_model = model_cls(new_config)

# Get layer names from config
layer_config = get_model_layer_config()
layer_names = sum(layer_config.values(), [])

return new_model, original_meta_model, layer_names, num_layers
return new_model, original_meta_model, num_layers


@torch.inference_mode()
def create_empty_model(config, dtype = torch.float16, is_vision_model = False):
# All Unsloth Zoo code licensed under LGPLv3
if is_vision_model:
return create_empty_vision_model(config, dtype)
new_model, original_meta_model, num_layers = create_empty_vision_model(config, dtype)
else:
return create_empty_causal_lm(config, dtype)
new_model, original_meta_model, num_layers = create_empty_causal_lm(config, dtype)

# Get layer names from config
layer_templates = get_model_layer_config(return_non_layered=False)
layer_names = sum(layer_templates.values(), [])
Member


Wait sum?

Collaborator Author


Oh, basically this is like a join, but I wanted a single flattened list. So you sum everything with an empty list as the start value and you get a single list of everything :)

Collaborator Author


It's like `a = [1,2,3]; b = [4,5,6]` and doing `a + b`, which gives `[1,2,3,4,5,6]`.
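For readers unfamiliar with the idiom discussed here, a small sketch: `sum(iterable_of_lists, [])` concatenates the lists in order because `+` on lists is concatenation. The dictionary below is a hypothetical, trimmed stand-in for what `get_model_layer_config` returns; the template strings are invented. `itertools.chain.from_iterable` is shown as an equivalent that avoids repeated list copying when there are many lists.

```python
from itertools import chain

# Hypothetical stand-in for the dict of template lists; values are invented.
layer_templates = {
    "vision_layers": ["vision.{kk}.q_proj", "vision.{kk}.k_proj"],
    "additional_layers": ["merger.{kk}"],
    "layernorms": [],
}

# sum() with an empty-list start value concatenates the lists in dict order.
layer_names = sum(layer_templates.values(), [])

# Equivalent result without building intermediate lists.
layer_names_chain = list(chain.from_iterable(layer_templates.values()))

assert layer_names == layer_names_chain
print(layer_names)
```

For a handful of short lists, as in this PR, the two forms are interchangeable; the `chain` form mainly matters when the number of lists grows large.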


return new_model, original_meta_model, num_layers, layer_names


@torch.inference_mode()
def set_additional_modules(new_model, quant_state_dict, config):
@@ -378,8 +376,11 @@ def set_additional_modules(new_model, quant_state_dict, config):
norm = torch.nn.Parameter(norm, requires_grad = False)
language_model.norm.weight = norm

# LM Head
if getattr(config, "tie_word_embeddings", False):
# LM Head. Do note that for some models, like Mistral3ForConditionalGeneration,
# there can be mismatch in the value of tie_word_embeddings between config and text_config
# we prefer picking the one in text_config. If you notice any issue later, please report it!
text_config = getattr(config, "text_config", config)
if getattr(text_config, "tie_word_embeddings", False):
lmhead_key = f"{language_model_prefix}.embed_tokens.weight"
else:
lmhead_key = "lm_head.weight"
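The lookup in the hunk above can be read as a two-step fallback: use the nested `text_config` when the config has one, otherwise use the config itself, then read `tie_word_embeddings` from whichever was picked. The sketch below isolates that logic. `resolve_lm_head_key`, the `SimpleNamespace` configs, and the prefix strings are assumptions made for illustration, not part of the PR.

```python
from types import SimpleNamespace

def resolve_lm_head_key(config, language_model_prefix):
    """Hypothetical helper mirroring the lookup order used in the hunk."""
    # Prefer the nested text_config when present; otherwise fall back to config.
    text_config = getattr(config, "text_config", config)
    if getattr(text_config, "tie_word_embeddings", False):
        return f"{language_model_prefix}.embed_tokens.weight"
    return "lm_head.weight"

# Made-up configs: the top level says untied, the text_config says tied.
vlm_config = SimpleNamespace(
    tie_word_embeddings=False,
    text_config=SimpleNamespace(tie_word_embeddings=True),
)
# A text-only config without a nested text_config.
plain_config = SimpleNamespace(tie_word_embeddings=False)

print(resolve_lm_head_key(vlm_config, "model.language_model"))
print(resolve_lm_head_key(plain_config, "model"))
```

With the mismatched config, the value in `text_config` wins, which is the preference the code comment in the hunk describes.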
@@ -422,19 +423,20 @@ def set_additional_modules(new_model, quant_state_dict, config):
)

for key in additional_keys:
try:
replaced_key = re.sub(r"\.(\d+)\.", r"[\1].", key)
exec(f"new_{replaced_key}.data = quant_state_dict[key]")
except:
try:
# sometimes it can be in new_model.model. instead of new_model.
exec(f"new_model.{replaced_key}.data = quant_state_dict[key]")
except:
continue
replaced_key = re.sub(r"\.(\d+)\.", r"[\1].", key)
# sometimes it can be in new_model.model. instead of new_model.
for prefix in ['new_', 'new_model.']:
for suffix in ['', '.data']:
try:
exec(f"{prefix}{replaced_key}{suffix} = quant_state_dict[key]")
break
except:
continue

pass
pass

def get_model_layer_config():
def get_model_layer_config(return_non_layered=True):
"""
Returns a unified layer configuration containing the union of layer names
from all supported vision models. Serves as a fallback.
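As a side note on the `additional_keys` loop in `set_additional_modules` above: the `re.sub` call turns a dotted state-dict key into an indexable attribute path so it can be handed to `exec`. The sketch below shows what that rewrite produces, followed by one possible `exec`-free way to assign at such a path. `set_by_path`, the example key, and the toy object are assumptions for illustration only; the PR itself uses `exec`.

```python
import re
from types import SimpleNamespace

# The rewrite used in the hunk: ".<digits>." becomes "[<digits>]."
key = "model.layers.12.mlp.gate.weight"  # hypothetical key
replaced_key = re.sub(r"\.(\d+)\.", r"[\1].", key)
print(replaced_key)  # prints model.layers[12].mlp.gate.weight

def set_by_path(root, dotted_key, value):
    """Hypothetical helper: assign value at dotted_key, indexing on numeric parts."""
    *parents, leaf = dotted_key.split(".")
    obj = root
    for part in parents:
        # Numeric path parts index into a sequence; others are attributes.
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    setattr(obj, leaf, value)

# Toy object tree standing in for a model with two layers.
toy = SimpleNamespace(
    layers=[SimpleNamespace(gate=SimpleNamespace(weight=None)) for _ in range(2)]
)
set_by_path(toy, "layers.1.gate.weight", 3.0)
assert toy.layers[1].gate.weight == 3.0
assert toy.layers[0].gate.weight is None
```

One thing that may be worth double-checking in the hunk: `break` exits only the inner `suffix` loop, so the outer `prefix` loop still goes on to try `'new_model.'` after a successful assignment.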
@@ -480,6 +482,10 @@ def get_model_layer_config():
"model.vision_tower.vision_model.encoder.layers.{kk}.post_layernorm",
"model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm1",
"model.vision_tower.vision_model.encoder.layers.{kk}.layer_norm2",

# Mistral3 vision norms
"model.vision_tower.transformer.layers.{kk}.attention_norm",
"model.vision_tower.transformer.layers.{kk}.ffn_norm",
},
'vision_layers': {

@@ -533,15 +539,33 @@ def get_model_layer_config():
"model.visual.blocks.{kk}.mlp.up_proj",
"model.visual.blocks.{kk}.mlp.down_proj",

# Mistral 3
"model.vision_tower.transformer.layers.{kk}.attention.q_proj",
"model.vision_tower.transformer.layers.{kk}.attention.k_proj",
"model.vision_tower.transformer.layers.{kk}.attention.v_proj",
"model.vision_tower.transformer.layers.{kk}.attention.qkv_proj",
"model.vision_tower.transformer.layers.{kk}.attention.o_proj",
"model.vision_tower.transformer.layers.{kk}.feed_forward.gate_up_proj",
"model.vision_tower.transformer.layers.{kk}.feed_forward.gate_proj",
"model.vision_tower.transformer.layers.{kk}.feed_forward.up_proj",
"model.vision_tower.transformer.layers.{kk}.feed_forward.down_proj",

},
'additional_layers': {
"model.visual.merger.mlp.{kk}",
"model.visual.merger.mlp.{kk}",
'model.language_model.model.layers.{kk}.cross_attn_mlp_gate',
'model.language_model.model.layers.{kk}.cross_attn_attn_gate',
'model.vision_model.global_transformer.layers.{kk}.gate_ffn',

# Mistral3
"model.multi_modal_projector.patch_merger.merging_layer",
"model.multi_modal_projector.linear_1",
"model.multi_modal_projector.linear_2",
},
"non_layered_components":{
# we do not handle quantization for these layers yet
# the set_additional_modules would process these layers
"model.multi_modal_projector",
"model.language_model.norm",
'model.vision_model.layernorm_pre',
@@ -561,11 +585,17 @@ def get_model_layer_config():
"model.vision_model.pre_tile_positional_embedding.embedding",
"model.vision_model.gated_positional_embedding",
"model.vision_model.post_tile_positional_embedding.embedding",
"model.vision_model.pre_tile_positional_embedding.gate"
"model.vision_model.pre_tile_positional_embedding.gate",

# Mistral3
"model.vision_tower.patch_positional_embedding",
"model.vision_tower.patch_conv",
"model.vision_tower.ln_pre",
}
}

# Convert sets to sorted lists for deterministic order
return {key: sorted(list(value)) for key, value in layer_templates.items()}
return {key: sorted(list(value)) for key, value in layer_templates.items() if key!='non_layered_components' or return_non_layered}


def get_model_layer_counts(config):
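The new `return` line in `get_model_layer_config` filters one key out of the result unless the caller asks for it. A reduced sketch of that pattern follows. `get_layer_templates` and its two-entry dictionary are hypothetical stand-ins; the template strings are borrowed from the hunks above only to make the output recognizable.

```python
def get_layer_templates(return_non_layered=True):
    """Hypothetical, heavily trimmed stand-in for get_model_layer_config."""
    layer_templates = {
        "vision_layers": {
            "model.visual.blocks.{kk}.mlp.up_proj",
            "model.visual.blocks.{kk}.mlp.down_proj",
        },
        "non_layered_components": {"model.multi_modal_projector"},
    }
    # Keep every key, except drop 'non_layered_components' when not requested.
    # sorted() accepts a set directly and gives a deterministic list.
    return {
        key: sorted(value)
        for key, value in layer_templates.items()
        if key != "non_layered_components" or return_non_layered
    }

print(list(get_layer_templates(return_non_layered=False)))  # prints ['vision_layers']
print(list(get_layer_templates()))
```

Because the filter sits inside the comprehension, `create_empty_model` can call the function with `return_non_layered=False` and flatten the remaining values without special-casing the extra key.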
@@ -648,13 +678,13 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat


if layer_module is not None:
if "qkv_proj" in layer_path:
if model_type in ["mllama", "gemma3"]:
if "qkv" in layer_path:
if model_type == "qwen2_5_vl":
get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False)
else:
get_state_dict(f"{layer_path.replace('qkv_proj', 'q_proj')}", 0, state_dict, layer_module)
get_state_dict(f"{layer_path.replace('qkv_proj', 'k_proj')}", 1, state_dict, layer_module)
get_state_dict(f"{layer_path.replace('qkv_proj', 'v_proj')}", 2, state_dict, layer_module)
elif model_type == "qwen2_5_vl":
get_state_dict(layer_path, 0, state_dict, layer_module, slice_weights=False)
elif "gate_up_proj" in layer_path:
# vLLM seems to have merged gate and up proj recently for qwen vl. This is to handle new variant
# https://github.com/jeejeelee/vllm/commit/a71e4765cc0c1534f2a8891aaf628e1751f6df07
@@ -665,11 +695,7 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat
else: # Handle other layers, especially layernorms
if isinstance(layer_module, torch.nn.Module):
if hasattr(layer_module, 'weight'):
state_dict[f"{layer_path}.weight"] = layer_module.weight.data
quant_state_dict[f"{layer_path}.weight"] = state_dict[f"{layer_path}.weight"]
if hasattr(layer_module, 'bias') and layer_module.bias is not None:
state_dict[f"{layer_path}.bias"] = layer_module.bias.data
quant_state_dict[f"{layer_path}.bias"] = state_dict[f"{layer_path}.bias"]
get_state_dict(layer_path, 0, state_dict, layer_module)
elif isinstance(layer_module, torch.nn.Parameter):
state_dict[f"{layer_path}"] = layer_module.data
quant_state_dict[f"{layer_path}"] = state_dict[f"{layer_path}"]
@@ -682,14 +708,22 @@ def extract_vision_layers(vllm_internals, state_dict, quant_state_dict, get_stat
component = _get_nested_attr(vllm_internals, component_path)

if component is not None:
if isinstance(component, torch.nn.Module):
for param_name, param in component.named_parameters():
full_param_path = f"{component_path}.{param_name}"
state_dict[full_param_path] = param.data
quant_state_dict[full_param_path] = param.data
if hasattr(component, 'weight'):
# Prefer using get_state_dict when possible
get_state_dict(component_path, 0, state_dict, component)
elif isinstance(component, torch.nn.Parameter):
state_dict[component_path] = component.data
quant_state_dict[component_path] = component.data
elif isinstance(component, torch.nn.Module):
for param_name, param in component.named_parameters():
# if the parameter is to be extracted separately, skip it
if param_name.replace('.weight', '') in non_layered_components: continue
full_param_path = f"{component_path}.{param_name}"
if hasattr(param, 'weight'):
get_state_dict(full_param_path, 0, state_dict, param)
elif hasattr(param, 'data'):
state_dict[full_param_path] = param.data
quant_state_dict[full_param_path] = param.data
else:
print(f"Unsloth: Skipping non-layered component '{component_path}' of unexpected type: {type(component)}")
