Commit a889be3

format
1 parent 98d2233 commit a889be3

File tree

2 files changed: 13 additions & 8 deletions


python/mlc_chat/model/mixtral/mixtral_loader.py

Lines changed: 2 additions & 6 deletions
@@ -102,20 +102,17 @@ def combine_expert_gate_up(*hf_params, dtype):
                 dtype=mlc_param.dtype,
             ),
         )
-
+
         mlc_name = f"{mlc_mlp}.gate.weight"
         mlc_param = named_parameters[mlc_name]
         mapping.add_mapping(
             mlc_name,
-            [
-                f"{mlp}.gate.weight"
-            ],
+            [f"{mlp}.gate.weight"],
             functools.partial(
                 lambda x, dtype: x.astype(dtype),
                 dtype=mlc_param.dtype,
             ),
         )
-
         # inv_freq is not used in the model
         mapping.add_unused(f"{attn}.rotary_emb.inv_freq")
@@ -131,4 +128,3 @@ def combine_expert_gate_up(*hf_params, dtype):
             ),
         )
         return mapping
-
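For context on the pattern in this hunk: functools.partial pre-binds dtype so that each entry registered with mapping.add_mapping is a one-argument transform applied to the source weight at load time. A minimal, self-contained sketch of that idiom (NumPy stands in for the real tensor type, and the cast/to_half names are invented for illustration):

import functools

import numpy as np

def cast(x, dtype):
    """Cast a loaded parameter to the dtype the MLC side expects."""
    return x.astype(dtype)

# Binding dtype up front yields a one-argument function, matching the
# shape the loader's mapping machinery applies to each source tensor.
to_half = functools.partial(cast, dtype="float16")

param = np.ones((4, 4), dtype="float32")
print(to_half(param).dtype)  # float16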

python/mlc_chat/model/mixtral/mixtral_model.py

Lines changed: 11 additions & 2 deletions
@@ -16,7 +16,13 @@
 from mlc_chat.support.config import ConfigBase
 from mlc_chat.support.style import bold
 from mlc_chat.support import tensor_parallel as tp
-from mlc_chat.model.mistral.mistral_model import MistralConfig, RotaryEmbedding, MistralAttention, MistralModel, MistralForCasualLM
+from mlc_chat.model.mistral.mistral_model import (
+    MistralConfig,
+    RotaryEmbedding,
+    MistralAttention,
+    MistralModel,
+    MistralForCasualLM,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +37,7 @@ class MixtralConfig(MistralConfig): # pylint: disable=too-many-instance-attribu
     def __post_init__(self):
         super().__post_init__()
 
+
 # pylint: disable=invalid-name,missing-docstring
 
 
@@ -187,8 +194,9 @@ def forward(self, x: Tensor):
         )
         return weighted_sum
 
+
 class MixtralDecoderLayer(nn.Module):
-    """ Mixtral decoder layer"""
+    """Mixtral decoder layer"""
 
     def __init__(self, config: MixtralConfig, rotary_embedding: RotaryEmbedding):
         rms_norm_eps = config.rms_norm_eps
@@ -253,6 +261,7 @@ def __init__(self, config: MixtralConfig):
             [MixtralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
         )
 
+
 class MixtralForCasualLM(MistralForCasualLM):
     """Same as LlamaForCausalLM, except for the use of sliding window attention."""
 
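The remaining hunks are whitespace and docstring normalization, consistent with the commit message "format": PEP 8 asks for two blank lines around top-level definitions and no space after a docstring's opening quotes, and parenthesized imports keep one name per line so adding or removing a name touches a single diff line. A generic illustration (the Example class is invented, not part of the codebase):

import logging

logger = logging.getLogger(__name__)


# Two blank lines separate the import block from the top-level class.
class Example:
    """One-line docstring with no space after the opening quotes."""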