
Commit 2fb961f

stick to same convention as new llama
1 parent b44b917 commit 2fb961f

2 files changed: +19 −17 lines changed


exo/inference/mlx/models/llava.py (+18 −16)
@@ -9,6 +9,7 @@
 import mlx.nn as nn
 from mlx_lm.models.base import BaseModelArgs, KVCache
 from exo.inference.shard import Shard
+from .base import IdentityBlock
 import numpy as np


@@ -369,22 +370,25 @@ def __call__(


 class Llama(nn.Module):
-    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+    def __init__(self, config: TextConfig, shard: Shard):
         super().__init__()
         self.config = config
-        self.is_first_layer = is_first_layer
-        self.is_last_layer = is_last_layer
+        self.shard = shard
         self.vocab_size = config.vocab_size
         self.model_type = config.model_type
         self.num_hidden_layers = config.num_hidden_layers
         self.num_key_value_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = [
-            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
-        ]
-        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers = []
+        for i in range(self.num_hidden_layers):
+            if self.shard.start_layer <= i <= self.shard.end_layer:
+                self.layers.append(TransformerBlock(config=config))
+            else:
+                self.layers.append(IdentityBlock())
+        if self.shard.is_last_layer():
+            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def __call__(
         self,
@@ -394,7 +398,7 @@ def __call__(
     ):
         # for passing merged input embeddings
         if inputs_embeds is None:
-            if self.is_first_layer:
+            if self.shard.is_first_layer():
                 h = self.embed_tokens(inputs)
             else:
                 h = inputs
@@ -413,20 +417,20 @@ def __call__(
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)

-        if self.is_last_layer:
+        if self.shard.is_last_layer():
             h = self.norm(h)
         return h

 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+    def __init__(self, config: TextConfig, shard: Shard):
         super().__init__()
         self.model_type = config.model_type
         if self.model_type != "llama":
             raise ValueError(
                 f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
             )
-        self.is_last_layer = is_last_layer
-        self.model = Llama(config, is_first_layer, is_last_layer)
+        self.shard = shard
+        self.model = Llama(config, shard)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

     def __call__(
@@ -436,7 +440,7 @@ def __call__(
         inputs_embeds=None,
     ):
         out = self.model(inputs, cache, inputs_embeds)
-        if self.is_last_layer:
+        if self.shard.is_last_layer():
             out = self.lm_head(out)
         return out

@@ -485,8 +489,6 @@ def __post_init__(self):
         if not self.shard.is_first_layer():
             self.vision_config = None

-        self.text_config.num_hidden_layers = self.shard.get_layer_count()
-

 class LlavaMultiModalProjector(nn.Module):
     def __init__(self, config: LlaVAConfig):
@@ -516,7 +518,7 @@ def __init__(self, config: ModelArgs):
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.vision_feature_layer = config.vision_feature_layer
         self.vision_feature_select_strategy = config.vision_feature_select_strategy
-        self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())
+        self.language_model = LanguageModel(config.text_config, config.shard)

     def get_input_embeddings(
         self,
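For context, the convention this diff adopts keys everything off a Shard describing the contiguous range of layers a node owns, with IdentityBlock standing in for layers owned by other nodes. The sketch below is illustrative only: the field and method names are taken from the diff, but the bodies are assumptions, not the exact code in exo/inference/shard.py or exo/inference/mlx/models/base.py.

from dataclasses import dataclass

import mlx.core as mx
import mlx.nn as nn


@dataclass
class Shard:
    # Illustrative stand-in for exo's Shard: this node owns layers
    # start_layer..end_layer (inclusive) of an n_layers-deep model.
    model_id: str
    start_layer: int
    end_layer: int
    n_layers: int

    def is_first_layer(self) -> bool:
        return self.start_layer == 0

    def is_last_layer(self) -> bool:
        return self.end_layer == self.n_layers - 1


class IdentityBlock(nn.Module):
    # Placeholder for layers held elsewhere: no parameters, passes the
    # hidden state through unchanged.
    def __call__(self, x: mx.array, *args, **kwargs) -> mx.array:
        return x

With that in place, Llama(config, shard) only materializes TransformerBlocks for indices inside [shard.start_layer, shard.end_layer], and the final norm and lm_head are only built on the node holding the last layer.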

exo/inference/mlx/sharded_utils.py (+1 −1)
@@ -129,7 +129,7 @@ def load_model_shard(
             class_predicate=None,
         )

-    model.load_weights(list(weights.items()))
+    model.load_weights(list(weights.items()), strict=True)

     if not lazy:
         mx.eval(model.parameters())
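With strict=True, MLX's Module.load_weights validates that the supplied (name, array) pairs match the module's parameters exactly, raising on missing or unexpected keys instead of skipping them. That is safe here because skipped layers are now parameter-free IdentityBlocks, so the sharded model only declares the weights its shard actually owns. A minimal, self-contained illustration of the call on a toy module (not the repo's loader):

import mlx.nn as nn
from mlx.utils import tree_flatten

# Toy module standing in for a model shard.
model = nn.Linear(4, 4)

# Weights in the (name, array) list form that load_weights accepts.
weights = tree_flatten(model.parameters())

# strict=True raises on missing or extra keys, surfacing shard/checkpoint
# mismatches at load time rather than failing silently.
model.load_weights(weights, strict=True)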
