Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/sglang/srt/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,8 @@ def generate_chat_conv(
register_conv_template(
Conversation(
name="phi-4-mm",
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
system_template="<|system|>{system_message}<|end|>",
system_message="",
system_template="{system_message}",
roles=("<|user|>", "<|assistant|>"),
sep_style=SeparatorStyle.NO_COLON_SINGLE,
sep="<|end|>",
Expand Down
10 changes: 9 additions & 1 deletion python/sglang/srt/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,4 +209,12 @@ def normalize_gate_up_proj(
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
# else: "lora_B" is already stacked, no operations is needed.
else:
output_dim = weights[gate_up_name].shape[0] // 2
weights[gate_up_name] = torch.stack(
[
weights[gate_up_name][:output_dim, :],
weights[gate_up_name][output_dim:, :],
],
dim=0,
)
25 changes: 16 additions & 9 deletions python/sglang/srt/models/idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,23 +296,30 @@ def get_input_embeddings(self) -> nn.Embedding:
def compute_cu_seqlens(
    self,
    tgt_sizes: Optional[torch.Tensor] = None,
    input_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Build cumulative sequence-length offsets for a batch.

    The result has one more entry than there are sequences: it starts at 0
    and entry ``i + 1`` is the summed length of sequences ``0..i``.

    Args:
        tgt_sizes: Optional 2-D tensor; the length of each sequence is taken
            as ``tgt_sizes[:, 0] * tgt_sizes[:, 1]``. Takes precedence over
            ``input_embeds`` when both are supplied.
        input_embeds: Optional tensor used only for its first two
            dimensions: ``shape[0]`` sequences, each of length ``shape[1]``.

    Returns:
        A 1-D ``torch.int32`` tensor of cumulative offsets, located on the
        device of whichever argument the lengths were derived from.

    Raises:
        ValueError: If neither ``tgt_sizes`` nor ``input_embeds`` is given.
    """
    if tgt_sizes is not None:
        seqlen = tgt_sizes[:, 0] * tgt_sizes[:, 1]
    elif input_embeds is not None:
        # No explicit sizes: every sequence spans the full second dimension.
        seqlen = torch.full(
            size=(input_embeds.shape[0],),
            fill_value=input_embeds.shape[1],
            dtype=torch.int32,
            device=input_embeds.device,
        )
    else:
        raise ValueError(
            "Either `tgt_sizes` or `input_embeds` must be provided to compute cu_seqlens."
        )

    # Both pieces are created on seqlen.device, so the concatenation already
    # lives there and needs no further device transfer.
    cu_seqlens = torch.cat(
        [
            torch.tensor([0], device=seqlen.device, dtype=torch.int32),
            torch.cumsum(seqlen, dim=0, dtype=torch.int32),
        ],
        dim=0,
    )
    return cu_seqlens

def forward(
Expand All @@ -326,7 +333,7 @@ def forward(
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes,
)
cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
cu_seqlens = self.compute_cu_seqlens(tgt_sizes, hidden_states)
encoder_outputs = self.encoder(
hidden_states,
cu_seqlens=cu_seqlens,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/models/phi4mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
return pattern.pad_input_tokens(input_ids, mm_inputs)

def should_apply_lora(self, module_name: str) -> bool:
    """Return whether ``module_name`` matches ``self.lora_pattern``.

    ``self.lora_pattern.match`` is called on the name and its result is
    collapsed to a plain ``bool`` so callers never receive a match object.
    NOTE(review): assumes ``lora_pattern`` is a compiled ``re`` pattern, in
    which case the match is anchored at the start of the name - confirm.
    """
    return bool(self.lora_pattern.match(module_name))

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
Expand Down
Loading