2 changes: 1 addition & 1 deletion common/arg.cpp
@@ -1035,7 +1035,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.system_prompt = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION, LLAMA_EXAMPLE_MTMD}));
add_opt(common_arg(
{"--no-perf"},
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
81 changes: 77 additions & 4 deletions convert_hf_to_gguf.py
@@ -694,6 +694,9 @@
if "llm_config" in config:
# rename for InternVL
config["text_config"] = config["llm_config"]
if "lfm" in config:
# rename for LFM2-Audio
config["text_config"] = config["lfm"]
if "thinker_config" in config:
# rename for Qwen2.5-Omni
config["text_config"] = config["thinker_config"]["text_config"]
@@ -9616,19 +9619,25 @@
self._add_feed_forward_length()

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
if is_vision_tensor:
# skip vision tensors
if self._is_vision_tensor(name) or self._is_audio_tensor(name):
# skip multimodal tensors
return []

name = name.replace("language_model.", "")
name = name.replace("language_model.", "") # vision
name = name.replace("lfm.", "model.") # audio

# conv op requires 2d tensor
if 'conv.conv' in name:
data_torch = data_torch.squeeze(1)

return [(self.map_tensor_name(name), data_torch)]

def _is_vision_tensor(self, name: str) -> bool:
return "vision_tower" in name or "multi_modal_projector" in name

def _is_audio_tensor(self, name: str):
return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])


@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
@@ -9734,6 +9743,70 @@
return [] # skip other tensors


@ModelBase.register("Lfm2AudioForConditionalGeneration")
class LFM2AudioModel(MmprojModel):
has_vision_encoder = False
has_audio_encoder = True
model_name = "Lfm2AudioEncoder"

_batch_norm_tensors: list[dict[str, Tensor]] | None = None

def get_audio_config(self) -> dict[str, Any] | None:
return self.global_config.get("encoder")

def set_gguf_parameters(self):
self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"]

Check failure on line 9758 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)

Check failure on line 9758 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)
self.hparams_audio["intermediate_size"] = self.hparams_audio["d_model"]

Check failure on line 9759 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)

Check failure on line 9759 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)
self.hparams_audio["num_attention_heads"] = self.hparams_audio["n_heads"]

Check failure on line 9760 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)

Check failure on line 9760 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LFM2A)
self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["feat_in"])

Check failure on line 9763 in convert_hf_to_gguf.py

View workflow job for this annotation

GitHub Actions / pyright type-check

Object of type "None" is not subscriptable (reportOptionalSubscript)
self.gguf_writer.add_audio_attention_layernorm_eps(1e-5)

def tensor_force_quant(self, name, new_name, bid, n_dims):
if ".conv" in name and ".weight" in name:
return gguf.GGMLQuantizationType.F32
return super().tensor_force_quant(name, new_name, bid, n_dims)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip language model tensors
if name.startswith("lfm."):
return []

# for training only
if any(p in name for p in ["audio_loss_weight"]):
return []

# for audio output
if any(p in name for p in ["codebook_offsets", "depth_embeddings", "depth_linear", "depthformer"]):
return []

# fold running_mean, running_var and eps into weight and bias for batch_norm
if "batch_norm" in name:
if self._batch_norm_tensors is None:
self._batch_norm_tensors = [{} for _ in range(self.block_count)]
assert bid is not None
self._batch_norm_tensors[bid][name] = data_torch

if len(self._batch_norm_tensors[bid]) < 5:
return []

weight = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.weight"]
bias = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.bias"]
running_mean = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_mean"]
running_var = self._batch_norm_tensors[bid][f"conformer.layers.{bid}.conv.batch_norm.running_var"]
eps = 1e-5 # default value

a = weight / torch.sqrt(running_var + eps)
b = bias - running_mean * a
return [
(self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.weight"), a),
(self.map_tensor_name(f"conformer.layers.{bid}.conv.batch_norm.bias"), b),
]

return [(self.map_tensor_name(name), data_torch)]
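
The batch-norm folding above relies on a standard identity: in eval mode, BatchNorm computes y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, which collapses to y = a * x + b with a = weight / sqrt(running_var + eps) and b = bias - running_mean * a, so only two tensors per conv block need to be written to the GGUF file. A minimal standalone sketch to sanity-check that algebra (not part of the conversion script; it assumes the checkpoint uses PyTorch's default eps of 1e-5, which the code above hard-codes):

```python
# Standalone check: folding BatchNorm1d running statistics into a scale and a
# shift reproduces the layer's eval-mode output.
import torch

bn = torch.nn.BatchNorm1d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)   # give the layer non-trivial statistics
bn.running_var.uniform_(0.5, 2.0)

a = bn.weight / torch.sqrt(bn.running_var + bn.eps)  # folded scale
b = bn.bias - bn.running_mean * a                    # folded shift

x = torch.randn(4, 8, 16)                            # (batch, channels, time)
folded = x * a[None, :, None] + b[None, :, None]
torch.testing.assert_close(folded, bn(x))
```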


@ModelBase.register("SmallThinkerForCausalLM")
class SmallThinkerModel(TextModel):
model_arch = gguf.MODEL_ARCH.SMALLTHINKER
34 changes: 14 additions & 20 deletions ggml/src/ggml-cuda/ssm-conv.cu
@@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
const int threads = 128;
GGML_ASSERT(nr % threads == 0);

if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else if (nc == 3) {
ssm_conv_f32<threads, 3><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
auto launch_kernel = [&](auto NC) {
constexpr int kNC = decltype(NC)::value;
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
} else {
if (nc == 4) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else if (nc == 3) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 3, split_n_t><<<blocks, threads, 0, stream>>>(
ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 3 or size = 4 right now.");
}
};

switch (nc) {
case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
default: GGML_ABORT("Only kernel sizes 3, 4 and 9 are supported right now.");
}
}

43 changes: 43 additions & 0 deletions gguf-py/gguf/constants.py
@@ -687,6 +687,8 @@ class MODEL_TENSOR(IntEnum):
V_TOK_EOI = auto() # cogvlm
# audio (mtmd)
A_ENC_EMBD_POS = auto()
A_ENC_EMBD_NORM = auto()
A_ENC_EMBD_TO_LOGITS = auto()
A_ENC_CONV1D = auto()
A_PRE_NORM = auto()
A_POST_NORM = auto()
@@ -697,8 +699,13 @@ class MODEL_TENSOR(IntEnum):
A_ENC_OUTPUT = auto()
A_ENC_OUTPUT_NORM = auto()
A_ENC_FFN_UP = auto()
A_ENC_FFN_NORM = auto()
A_ENC_FFN_GATE = auto()
A_ENC_FFN_DOWN = auto()
A_ENC_FFN_UP_1 = auto()
A_ENC_FFN_NORM_1 = auto()
A_ENC_FFN_GATE_1 = auto()
A_ENC_FFN_DOWN_1 = auto()
A_MMPROJ = auto()
A_MMPROJ_FC = auto()
A_MM_NORM_PRE = auto()
@@ -710,6 +717,12 @@ class MODEL_TENSOR(IntEnum):
NEXTN_HNORM = auto()
NEXTN_SHARED_HEAD_HEAD = auto()
NEXTN_SHARED_HEAD_NORM = auto()
# lfm2 audio
A_ENC_NORM_CONV = auto()
A_ENC_LINEAR_POS = auto()
A_ENC_POS_BIAS_U = auto()
A_ENC_POS_BIAS_V = auto()
A_ENC_OUT = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -1059,6 +1072,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_TOK_EOI: "v.eoi",
# audio (mtmd)
MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd",
MODEL_TENSOR.A_ENC_EMBD_NORM: "a.position_embd_norm",
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS: "a.embd_to_logits",
MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}",
MODEL_TENSOR.A_PRE_NORM: "a.pre_ln",
MODEL_TENSOR.A_POST_NORM: "a.post_ln",
@@ -1068,9 +1083,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1",
MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out",
MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2",
MODEL_TENSOR.A_ENC_FFN_NORM: "a.blk.{bid}.ffn_norm",
MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up",
MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate",
MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down",
MODEL_TENSOR.A_ENC_FFN_NORM_1: "a.blk.{bid}.ffn_norm_1",
MODEL_TENSOR.A_ENC_FFN_UP_1: "a.blk.{bid}.ffn_up_1",
MODEL_TENSOR.A_ENC_FFN_GATE_1: "a.blk.{bid}.ffn_gate_1",
MODEL_TENSOR.A_ENC_FFN_DOWN_1: "a.blk.{bid}.ffn_down_1",
MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}",
MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc",
MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre",
@@ -1082,6 +1102,12 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm",
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head",
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm",
# lfm2 audio
MODEL_TENSOR.A_ENC_NORM_CONV: "a.blk.{bid}.norm_conv",
MODEL_TENSOR.A_ENC_LINEAR_POS: "a.blk.{bid}.linear_pos",
MODEL_TENSOR.A_ENC_POS_BIAS_U: "a.blk.{bid}.pos_bias_u",
MODEL_TENSOR.A_ENC_POS_BIAS_V: "a.blk.{bid}.pos_bias_v",
MODEL_TENSOR.A_ENC_OUT: "a.pre_encode.out",
}
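
For reference, the `{bid}` placeholders in the templates above are expanded per block when the tensor name map is built. A tiny illustration of the resulting names (assumes this branch's gguf-py is importable, e.g. via an editable install; output shown in comments):

```python
# Illustration only: expand two of the new per-block audio tensor name templates.
from gguf.constants import MODEL_TENSOR, TENSOR_NAMES

for bid in range(2):
    print(TENSOR_NAMES[MODEL_TENSOR.A_ENC_NORM_CONV].format(bid=bid))   # a.blk.0.norm_conv / a.blk.1.norm_conv
    print(TENSOR_NAMES[MODEL_TENSOR.A_ENC_LINEAR_POS].format(bid=bid))  # a.blk.0.linear_pos / a.blk.1.linear_pos
```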

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -1137,6 +1163,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.V_TOK_EOI,
# audio
MODEL_TENSOR.A_ENC_EMBD_POS,
MODEL_TENSOR.A_ENC_EMBD_NORM,
MODEL_TENSOR.A_ENC_EMBD_TO_LOGITS,
MODEL_TENSOR.A_ENC_CONV1D,
MODEL_TENSOR.A_PRE_NORM,
MODEL_TENSOR.A_POST_NORM,
@@ -1146,13 +1174,27 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_ENC_INPUT_NORM,
MODEL_TENSOR.A_ENC_OUTPUT,
MODEL_TENSOR.A_ENC_OUTPUT_NORM,
MODEL_TENSOR.A_ENC_FFN_NORM,
MODEL_TENSOR.A_ENC_FFN_UP,
MODEL_TENSOR.A_ENC_FFN_GATE,
MODEL_TENSOR.A_ENC_FFN_DOWN,
MODEL_TENSOR.A_ENC_FFN_NORM_1,
MODEL_TENSOR.A_ENC_FFN_UP_1,
MODEL_TENSOR.A_ENC_FFN_GATE_1,
MODEL_TENSOR.A_ENC_FFN_DOWN_1,
MODEL_TENSOR.A_MMPROJ,
MODEL_TENSOR.A_MMPROJ_FC,
MODEL_TENSOR.A_MM_NORM_PRE,
MODEL_TENSOR.A_MM_NORM_MID,
MODEL_TENSOR.CONVNEXT_DW,
MODEL_TENSOR.CONVNEXT_NORM,
MODEL_TENSOR.CONVNEXT_PW1,
MODEL_TENSOR.CONVNEXT_PW2,
MODEL_TENSOR.A_ENC_NORM_CONV,
MODEL_TENSOR.A_ENC_LINEAR_POS,
MODEL_TENSOR.A_ENC_POS_BIAS_U,
MODEL_TENSOR.A_ENC_POS_BIAS_V,
MODEL_TENSOR.A_ENC_OUT,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -3327,6 +3369,7 @@ class VisionProjectorType:
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
LFM2A = "lfm2a" # audio


# Items here are (block size, type size)