Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler_pass/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR

@register_pipeline("mlc_llm")
def _mlc_llm_pipeline( # pylint: disable=too-many-arguments
cublas_gemm: bool,
cublas_gemm: bool = False,
variable_bounds: Dict[str, int] = None,
additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
metadata: Dict[str, Any] = None,
Expand Down
20 changes: 11 additions & 9 deletions python/mlc_chat/interface/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,18 @@ def _apply_preproc_to_params(

def _compile(args: CompileArgs, model_config: ConfigBase):
def _get_variable_bounds(model_config) -> Dict[str, int]:
variable_bounds = {"seq_len": model_config.prefill_chunk_size}
if hasattr(model_config, "sliding_window_size"):
variable_bounds["rolling_cache_len"] = model_config.sliding_window_size
variable_bounds["kv_seq_len"] = (
model_config.sliding_window_size + model_config.prefill_chunk_size,
)
else:
variable_bounds["total_seq_len"] = model_config.context_window_size
variable_bounds["batch_size"] = getattr(model_config, "max_batch_size", 1)
return variable_bounds
return {
"rolling_cache_len": model_config.sliding_window_size,
"kv_seq_len": model_config.sliding_window_size + model_config.prefill_chunk_size,
"seq_len": model_config.prefill_chunk_size,
"batch_size": getattr(model_config, "max_batch_size", 1),
}
return {
"total_seq_len": model_config.context_window_size,
"seq_len": model_config.prefill_chunk_size,
"batch_size": getattr(model_config, "max_batch_size", 1),
}

def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
return {
Expand Down
5 changes: 4 additions & 1 deletion python/mlc_chat/interface/convert_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def _device_to_str(device: Device) -> str:


def _calc_total_params(model: nn.Module) -> int:
_, named_params, _ = model.export_tvm(spec=model.get_default_spec(), allow_extern=True)
_, named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(), # type: ignore[attr-defined]
allow_extern=True,
)
total_params = 0
for _, param in named_params:
total_params += math.prod(param.shape)
Expand Down
28 changes: 7 additions & 21 deletions python/mlc_chat/model/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
TODO: add docstring
"""
import dataclasses
import math
from typing import Any, Dict, Optional

from tvm import te, tir
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op

from mlc_chat import op as op_ext
from mlc_chat.support import logging
from mlc_chat.support.config import ConfigBase
from mlc_chat.support.style import bold
Expand Down Expand Up @@ -110,29 +110,15 @@ def forward(

self.k_cache.append(op.squeeze(k, axis=0))
self.v_cache.append(op.squeeze(v, axis=0))
k = op.reshape(self.k_cache.view(t), (b, t, h, d))
v = op.reshape(self.v_cache.view(t), (b, t, h, d))

q = q.permute_dims([0, 2, 1, 3]) # [b, h, s, d]
k = k.permute_dims([0, 2, 1, 3]) # [b, h, t, d]
v = v.permute_dims([0, 2, 1, 3]) # [b, h, t, d]

attn_weights = op.matmul(
q, k.permute_dims([0, 1, 3, 2]) # [b, h, s, d] x [b, h, d, t] = [b, h, s, t]
) / math.sqrt(d)
k = self.k_cache.view(t)
v = self.v_cache.view(t)

if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)

dtype = attn_weights.dtype
attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask)
if dtype == "float32":
attn_weights = op.softmax(attn_weights, axis=-1)
attn_score_scaling_factor = 1.0 / float(self.layer_idx + 1)
else:
attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype)
# [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d]
output = op.matmul(attn_weights, v)
return self.c_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h * d)))
attn_score_scaling_factor = 1.0
output = op_ext.attention(q, k, v, attention_mask, attn_score_scaling_factor)
return self.c_proj(output)


class GPT2MLP(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions python/mlc_chat/model/llama/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tvm.relax.frontend.nn import Tensor, op

from mlc_chat import op as op_ext
from mlc_chat.nn.kv_cache import FlashInferPagedKVCache, PagedKVCache
from mlc_chat.nn import FlashInferPagedKVCache, PagedKVCache
from mlc_chat.support import logging
from mlc_chat.support import tensor_parallel as tp
from mlc_chat.support.config import ConfigBase
Expand Down Expand Up @@ -342,7 +342,7 @@ def create_flashinfer_paged_kv_cache(
num_kv_heads = self.num_key_value_heads // self.tensor_parallel_shards
# Note: Right now we only have FlashInfer-based KV cache supported.
# TIR version will be introduced soon.
return FlashInferPagedKVCache.create(
return FlashInferPagedKVCache(
max_batch_size=max_batch_size,
max_total_seq_len=max_total_seq_len,
page_size=page_size,
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_chat/model/mistral/mistral_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def __init__(self, config: MistralConfig):
[MistralDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)]
)
self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)
self.tensor_parallel_shards = config.tensor_parallel_shards > 1
self.tensor_parallel_shards = config.tensor_parallel_shards

def forward( # pylint: disable=too-many-arguments
self,
Expand Down
Empty file.
129 changes: 129 additions & 0 deletions python/mlc_chat/model/mixtral/mixtral_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""
This file specifies how MLC's Mixtral parameters are mapped from other formats, for example
HuggingFace PyTorch and HuggingFace safetensors.
"""
import functools

import numpy as np

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .mixtral_model import MixtralConfig, MixtralForCasualLM


def huggingface(model_config: MixtralConfig, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    For every hidden layer the mapping registers:

    * ``self_attn.qkv_proj.weight`` -- HF ``q_proj``/``k_proj``/``v_proj`` weights
      concatenated along axis 0;
    * ``moe.e1_e3.weight`` -- per expert, HF ``w1`` and ``w3`` concatenated along axis 0,
      then all experts stacked along a new leading axis;
    * ``moe.e2.weight`` -- HF ``w2`` of all experts stacked along a new leading axis;
    * ``moe.gate.weight`` -- HF ``block_sparse_moe.gate.weight``, cast only;
    * ``self_attn.rotary_emb.inv_freq`` -- marked as unused.

    Every remaining MLC parameter is mapped to the HF parameter of the same name.
    All converted arrays are cast to the dtype of the corresponding MLC parameter.

    Parameters
    ----------
    model_config : MixtralConfig
        The configuration of the Mixtral model.

    quantization : Quantization
        The quantization configuration. When not None, the model is converted to
        ``quantization.model_dtype`` before its parameters are exported.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = MixtralForCasualLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    # The conversion helpers capture nothing from the per-layer loop (the target dtype is
    # bound through functools.partial), so they are defined once here instead of being
    # re-created on every iteration.
    def _cast(x, dtype):
        """Cast a single source array to the MLC parameter dtype."""
        return x.astype(dtype)

    def _concat_qkv(q, k, v, dtype):
        """Fuse the separate q/k/v projection weights along axis 0."""
        return np.concatenate([q, k, v], axis=0).astype(dtype)

    def _stack_experts(*hf_params, dtype):
        """Stack one array per expert along a new leading axis."""
        return np.stack(hf_params, axis=0).astype(dtype)

    def combine_expert_gate_up(*hf_params, dtype):
        """Fuse (w1, w3) pairs per expert along axis 0, then stack the experts.

        ``hf_params`` is laid out as ``w1_0, w3_0, w1_1, w3_1, ...``.
        """
        fused = [
            np.concatenate([hf_params[pos], hf_params[pos + 1]], axis=0)
            for pos in range(0, len(hf_params), 2)
        ]
        return np.stack(fused, axis=0).astype(dtype)

    mapping = ExternMapping()
    expert_ids = range(model_config.num_local_experts)

    for i in range(model_config.num_hidden_layers):
        # Add QKV in self attention
        attn = f"model.layers.{i}.self_attn"
        mlc_name = f"{attn}.qkv_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"{attn}.q_proj.weight",
                f"{attn}.k_proj.weight",
                f"{attn}.v_proj.weight",
            ],
            functools.partial(_concat_qkv, dtype=mlc_param.dtype),
        )

        # Add gates in MLP (when MoE is enabled)
        mlp = f"model.layers.{i}.block_sparse_moe"
        mlc_mlp = f"model.layers.{i}.moe"
        mlc_name = f"{mlc_mlp}.e1_e3.weight"
        mlc_param = named_parameters[mlc_name]
        # Flat comprehension instead of reduce(a + b): linear rather than quadratic list
        # concatenation, and it does not raise on an empty expert range.
        mapping.add_mapping(
            mlc_name,
            [
                hf_name
                for expert_id in expert_ids
                for hf_name in (
                    f"{mlp}.experts.{expert_id}.w1.weight",
                    f"{mlp}.experts.{expert_id}.w3.weight",
                )
            ],
            functools.partial(combine_expert_gate_up, dtype=mlc_param.dtype),
        )

        mlc_name = f"{mlc_mlp}.e2.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [f"{mlp}.experts.{expert_id}.w2.weight" for expert_id in expert_ids],
            functools.partial(_stack_experts, dtype=mlc_param.dtype),
        )

        mlc_name = f"{mlc_mlp}.gate.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [f"{mlp}.gate.weight"],
            functools.partial(_cast, dtype=mlc_param.dtype),
        )

        # inv_freq is not used in the model
        mapping.add_unused(f"{attn}.rotary_emb.inv_freq")

    # Everything not handled above keeps its name and only needs a dtype cast.
    for mlc_name, mlc_param in named_parameters.items():
        if mlc_name not in mapping.param_map:
            mapping.add_mapping(
                mlc_name,
                [mlc_name],
                functools.partial(_cast, dtype=mlc_param.dtype),
            )
    return mapping
Loading