 import mlx.nn as nn
 from mlx_lm.models.base import BaseModelArgs, KVCache
 from exo.inference.shard import Shard
+from .base import IdentityBlock
 import numpy as np
 
 
@@ -369,22 +370,25 @@ def __call__(
 
 
 class Llama(nn.Module):
-    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+    def __init__(self, config: TextConfig, shard: Shard):
         super().__init__()
         self.config = config
-        self.is_first_layer = is_first_layer
-        self.is_last_layer = is_last_layer
+        self.shard = shard
         self.vocab_size = config.vocab_size
         self.model_type = config.model_type
         self.num_hidden_layers = config.num_hidden_layers
         self.num_key_value_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
         assert self.vocab_size > 0
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = [
-            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
-        ]
-        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers = []
+        for i in range(self.num_hidden_layers):
+            if self.shard.start_layer <= i <= self.shard.end_layer:
+                self.layers.append(TransformerBlock(config=config))
+            else:
+                self.layers.append(IdentityBlock())
+        if self.shard.is_last_layer():
+            self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
         self,
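Layers outside [shard.start_layer, shard.end_layer] are swapped for IdentityBlock, which is imported from .base but not shown in this diff. A minimal sketch of such a pass-through block, assuming it only needs to match TransformerBlock's (h, mask, cache) call signature; the body below is an illustration, not the actual .base implementation:

import mlx.nn as nn

class IdentityBlock(nn.Module):
    # Stand-in for transformer layers hosted on another shard: it holds no
    # parameters and returns the hidden states unchanged.
    def __call__(self, x, mask=None, cache=None):
        return x
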
@@ -394,7 +398,7 @@ def __call__(
     ):
         # for passing merged input embeddings
         if inputs_embeds is None:
-            if self.is_first_layer:
+            if self.shard.is_first_layer():
                 h = self.embed_tokens(inputs)
             else:
                 h = inputs
@@ -413,20 +417,20 @@ def __call__(
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)
 
-        if self.is_last_layer:
+        if self.shard.is_last_layer():
             h = self.norm(h)
         return h
 
 class LanguageModel(nn.Module):
-    def __init__(self, config: TextConfig, is_first_layer, is_last_layer):
+    def __init__(self, config: TextConfig, shard: Shard):
         super().__init__()
         self.model_type = config.model_type
         if self.model_type != "llama":
             raise ValueError(
                 f"Model type {self.model_type} not supported. Currently only 'llama' is supported"
             )
-        self.is_last_layer = is_last_layer
-        self.model = Llama(config, is_first_layer, is_last_layer)
+        self.shard = shard
+        self.model = Llama(config, shard)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     def __call__(
@@ -436,7 +440,7 @@ def __call__(
         inputs_embeds=None,
     ):
         out = self.model(inputs, cache, inputs_embeds)
-        if self.is_last_layer:
+        if self.shard.is_last_layer():
             out = self.lm_head(out)
         return out
 
@@ -485,8 +489,6 @@ def __post_init__(self):
         if not self.shard.is_first_layer():
             self.vision_config = None
 
-        self.text_config.num_hidden_layers = self.shard.get_layer_count()
-
 
 class LlavaMultiModalProjector(nn.Module):
     def __init__(self, config: LlaVAConfig):
@@ -516,7 +518,7 @@ def __init__(self, config: ModelArgs):
         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.vision_feature_layer = config.vision_feature_layer
         self.vision_feature_select_strategy = config.vision_feature_select_strategy
-        self.language_model = LanguageModel(config.text_config, config.shard.is_first_layer(), config.shard.is_last_layer())
+        self.language_model = LanguageModel(config.text_config, config.shard)
 
     def get_input_embeddings(
         self,
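The layer-selection rule introduced in Llama.__init__ can be checked in isolation. The helper below mirrors that loop with plain integers instead of a real exo Shard; the function name and the 32-layer example are made up for illustration:

def materialized_layers(start_layer: int, end_layer: int, num_hidden_layers: int) -> list[int]:
    # Indices that receive a real TransformerBlock on this shard; every other
    # index gets an IdentityBlock and is passed through untouched.
    return [i for i in range(num_hidden_layers) if start_layer <= i <= end_layer]

# Example: a 32-layer model split evenly across two nodes.
assert materialized_layers(0, 15, 32) == list(range(0, 16))
assert materialized_layers(16, 31, 32) == list(range(16, 32))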