Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlx_vlm/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
from mlx_lm.models.base import (
create_attention_mask,
create_ssm_mask,
scaled_dot_product_attention,
)
from PIL import Image


Expand Down
2 changes: 2 additions & 0 deletions mlx_vlm/models/qwen3_5/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .config import ModelConfig, TextConfig, VisionConfig
from .qwen3_5 import LanguageModel, Model, VisionModel
101 changes: 101 additions & 0 deletions mlx_vlm/models/qwen3_5/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import inspect
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from ..base import BaseModelConfig
from ..qwen3_vl.config import VisionConfig as Qwen3VLVisionConfig


@dataclass
class VisionConfig(Qwen3VLVisionConfig):
    """Vision-tower configuration for qwen3.5.

    Reuses the qwen3_vl vision config wholesale, but this port does not
    support deepstack feature injection yet: a populated
    ``deepstack_visual_indexes`` is rejected, and the attribute is
    normalized to an empty list afterwards.
    """

    model_type: str = "qwen3_5"

    def __post_init__(self):
        # NOTE(review): defining __post_init__ here shadows any
        # __post_init__ on Qwen3VLVisionConfig — confirm the parent has
        # no post-init logic that must also run.
        if (
            self.deepstack_visual_indexes is not None
            and len(self.deepstack_visual_indexes) > 0
        ):
            # Fix: "temporally" -> "temporarily" in the error message.
            raise ValueError(
                f"deepstack is disabled for qwen3.5 temporarily, but it is set to {self.deepstack_visual_indexes}"
            )
        # Normalize None (and the legal empty case) to [] so downstream
        # code can iterate without a None check.
        self.deepstack_visual_indexes = []


@dataclass
class TextConfig(BaseModelConfig):
    """Language-model configuration for qwen3.5.

    ``rope_parameters`` is normalized and validated after field
    initialization: the HF-style ``rope_type`` key is accepted as an
    alias for ``type``, and a fixed set of keys must be present.
    """

    model_type: str
    hidden_size: int
    intermediate_size: int
    linear_num_value_heads: int
    linear_num_key_heads: int
    linear_key_head_dim: int
    linear_value_head_dim: int
    linear_conv_kernel_dim: int
    num_hidden_layers: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: int
    max_position_embeddings: int
    tie_word_embeddings: bool = False
    attention_bias: bool = False
    head_dim: Optional[int] = None
    rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field(
        default_factory=lambda: {
            "type": "default",
            "mrope_section": [11, 11, 10],
            "rope_theta": 100000,
            "partial_rotary_factor": 0.25,
        }
    )
    full_attention_interval: int = 4

    def __post_init__(self):
        """Normalize and validate ``rope_parameters`` in place."""
        rope = self.rope_parameters
        if not rope:
            # None or empty dict: nothing to normalize or validate.
            return

        # Accept 'rope_type' (HF naming) as an alias for 'type'.
        if "rope_type" in rope and "type" not in rope:
            rope["type"] = rope.pop("rope_type")

        required_keys = {
            "mrope_section",
            "type",
            "rope_theta",
            "partial_rotary_factor",
        }
        if not required_keys <= rope.keys():
            raise ValueError(f"rope_parameters must contain keys {required_keys}")


@dataclass
class ModelConfig(BaseModelConfig):
    """Top-level qwen3.5 configuration bundling text and vision configs
    plus the special-token ids used for vision/token splicing."""

    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    ignore_index: int = -100
    image_token_id: int = 248056
    video_token_id: int = 248057
    image_token_index: Optional[int] = None
    video_token_index: Optional[int] = None
    vision_start_token_id: int = 248045
    vision_end_token_id: int = 248046
    vocab_size: int = 248320
    eos_token_id: Optional[List[int]] = None

    def __post_init__(self):
        # The *_token_index aliases default to the corresponding
        # *_token_id values when not given explicitly.
        for alias_name, fallback in (
            ("image_token_index", self.image_token_id),
            ("video_token_index", self.video_token_id),
        ):
            if getattr(self, alias_name) is None:
                setattr(self, alias_name, fallback)

    @classmethod
    def from_dict(cls, params):
        """Build a config from *params*, silently dropping any keys the
        constructor does not accept."""
        accepted = inspect.signature(cls).parameters
        kwargs = {key: value for key, value in params.items() if key in accepted}
        return cls(**kwargs)
Loading