Merged

48 commits
81991fc  oai moe (ngxson, Jul 7, 2025)
917f923  compat with new checkpoint (ngxson, Jul 7, 2025)
a4ab869  add attn sink impl (ngxson, Jul 7, 2025)
3801c36  add rope scaling yarn (ngxson, Jul 8, 2025)
13f39f6  logits match with latest transformers code (ngxson, Jul 8, 2025)
b3594b3  wip chat template (ngxson, Jul 8, 2025)
bd57158  Merge branch 'master' into xsn/oai_moe (ngxson, Jul 9, 2025)
089a7ab  rm trailing space (ngxson, Jul 9, 2025)
4d01b36  use ggml_scale_bias (ngxson, Jul 9, 2025)
f271cc8  Merge branch 'master' into xsn/oai_moe (ngxson, Jul 10, 2025)
106b17e  rm redundant is_swa_all (ngxson, Jul 10, 2025)
e2c1beb  convert interleaved gate_up (ngxson, Jul 15, 2025)
4431c82  Merge remote-tracking branch 'gg-public/master' into xsn/oai_moe-gg (ggerganov, Jul 20, 2025)
fe9b818  Merge remote-tracking branch 'gg-public/master' into xsn/oai_moe-gg (ggerganov, Jul 24, 2025)
539c2b6  Merge remote-tracking branch 'gg-public/master' into xsn/oai_moe-gg (ggerganov, Jul 29, 2025)
039a6f1  graph : fix activation function to match reference (#7) (ggerganov, Jul 31, 2025)
aa240b9  Merge branch 'master' into xsn/oai_moe-gg (ggerganov, Jul 31, 2025)
32a654c  Merge branch 'master' into xsn/oai_moe-gg (ggerganov, Aug 1, 2025)
13f3568  vocab : handle o200k_harmony special tokens (ggerganov, Aug 1, 2025)
e59b2eb  ggml : add attention sinks support (#1) (ggerganov, Aug 1, 2025)
832dc26  repack mxfp4 upon conversion (ngxson, Aug 1, 2025)
c68069d  clean up a bit (ngxson, Aug 1, 2025)
423b191  enable thinking (ngxson, Aug 1, 2025)
4dd479b  add quick hack to render only some special tokens (ngxson, Aug 1, 2025)
ebc7da5  fix bf16 conversion (ngxson, Aug 1, 2025)
a543ddf  remove vocab hack (ngxson, Aug 1, 2025)
6b30372  webui ok (ngxson, Aug 1, 2025)
44bdb75  support chat parsing for gpt-oss (ngxson, Aug 1, 2025)
65b536f  Merge branch 'master' into xsn/oai_moe (ggerganov, Aug 2, 2025)
6197917  fix webui (ngxson, Aug 2, 2025)
3c4725b  direct mapping mxfp4, FINALLY (ngxson, Aug 2, 2025)
04cfb6d  force using mxfp4 (ngxson, Aug 2, 2025)
4cf69df  properly use lazy tensor (ngxson, Aug 3, 2025)
ec95c0e  ggml : add mxfp4 (ggerganov, Jul 20, 2025)
3ef6c8c  ggml : add ggml_add_id (#13) (slaren, Aug 4, 2025)
cd514cc  Merge branch 'master' into xsn/oai_moe (slaren, Aug 5, 2025)
98c4be5  Merge branch 'xsn/oai_moe' into mxfp4-rebased (slaren, Aug 5, 2025)
fcb2339  Merge branch 'master' into gpt-oss-mxfp4 (ngxson, Aug 5, 2025)
98f3444  llama : fix compile error (ggerganov, Aug 5, 2025)
df8411e  cuda : add fallback for __nv_cvt_e8m0_to_bf16raw (slaren, Aug 5, 2025)
60ab08a  cleanup (slaren, Aug 5, 2025)
256fe66  sycl : fix supports_op for MXFP4 (slaren, Aug 5, 2025)
cd8ed32  fix Unknown reasoning format (ngxson, Aug 5, 2025)
a3b291e  ggml-cpu : fix AVX build (slaren, Aug 5, 2025)
1ea3769  fix hip build (slaren, Aug 5, 2025)
07d781e  cuda : add mxfp4 dequantization support for cuBLAS (slaren, Aug 5, 2025)
b236c90  ggml-cpu : fix mxfp4 fallback definitions for some architectures (slaren, Aug 5, 2025)
d9d89b4  cuda : fix version required for __nv_cvt_e8m0_to_bf16raw (slaren, Aug 5, 2025)
3 changes: 2 additions & 1 deletion common/arg.cpp
@@ -2947,11 +2947,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
"(default: deepseek)",
"(default: auto)",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
30 changes: 30 additions & 0 deletions common/chat.cpp
@@ -592,6 +592,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
@@ -600,6 +601,7 @@ const char * common_chat_format_name(common_chat_format format) {
const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
default:
@@ -1289,6 +1291,26 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end);
}

static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);

data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_GPT_OSS;

// TODO: support tool calls in GPT-OSS?

return data;
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
// TODO @ngxson : this won't work with --special enabled, we should fix that
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
}
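For context, the raw GPT-OSS (Harmony) output this parser handles puts the chain-of-thought after `<|channel|>analysis<|message|>` and the user-visible answer after `<|start|>assistant<|channel|>final<|message|>`, the two markers passed to try_parse_reasoning above. A minimal Python sketch of the same split, for illustration only (the marker strings are taken from the call above; the function itself is hypothetical, not the C++ parser):

```python
# Illustrative sketch of the reasoning split performed by common_chat_parse_gpt_oss.
# Only the two marker strings come from the call above; everything else is hypothetical.
REASONING_START = "<|channel|>analysis<|message|>"
REASONING_END = "<|start|>assistant<|channel|>final<|message|>"

def split_gpt_oss_output(raw: str) -> tuple[str, str]:
    """Return (reasoning_content, content) for a raw GPT-OSS completion."""
    if REASONING_START in raw and REASONING_END in raw:
        before, rest = raw.split(REASONING_START, 1)
        reasoning, final = rest.split(REASONING_END, 1)
        return reasoning.strip(), (before + final).strip()
    return "", raw.strip()  # no analysis channel: treat everything as final content

print(split_gpt_oss_output(
    "<|channel|>analysis<|message|>Let me think about this."
    "<|start|>assistant<|channel|>final<|message|>Hello!"
))
# -> ('Let me think about this.', 'Hello!')
```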

static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
@@ -1772,6 +1794,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_hermes_2_pro(tmpl, params);
}

// GPT-OSS
if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_gpt_oss(tmpl, params);
}

// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -1923,6 +1950,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
1 change: 1 addition & 0 deletions common/chat.h
@@ -109,6 +109,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GPT_OSS,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
3 changes: 2 additions & 1 deletion common/common.h
@@ -236,6 +236,7 @@ struct common_params_diffusion {

enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
@@ -394,7 +395,7 @@ struct common_params {
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response

114 changes: 114 additions & 0 deletions convert_hf_to_gguf.py
@@ -7950,6 +7950,119 @@ def set_vocab(self):
self.gguf_writer.add_chat_template(chat_template)


@ModelBase.register("GptOssForCausalLM")
class GptOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.GPT_OSS

def transform_nibble_layout(self, tensor):
assert tensor.dtype == torch.uint8
assert tensor.shape[-1] == 16
# swap nibbles
t_lo = tensor & 0x0F
t_hi = tensor & 0xF0
t_swapped = (t_lo << 4) | (t_hi >> 4)
tensor = t_swapped
# transform aaaa...bbbb... to abababab...
blk_a, blk_b = tensor.chunk(2, dim=-1)
# get a_
blk_a0 = (blk_a & 0xF0).view(-1, 1)
blk_a1 = (blk_a << 4).view(-1, 1)
blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
# get _b
blk_b0 = (blk_b >> 4).view(-1, 1)
blk_b1 = (blk_b & 0x0F).view(-1, 1)
blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
# swap once more
out = blk_a | blk_b
out_h = out & 0xF0
out_l = out & 0x0F
out = (out_h >> 4) | (out_l << 4)
return out

def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
assert blocks.dtype == torch.uint8
assert scales.dtype == torch.uint8
scales = scales.unsqueeze(-1)
assert len(blocks.shape) == 4
assert len(scales.shape) == 4
blocks = self.transform_nibble_layout(blocks)
new_data = torch.concat((scales, blocks), dim=-1)
new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
# flatten last dim
new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
new_data = new_data.numpy()
self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
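In other words, each repacked MXFP4 block stores one scale byte followed by 16 bytes of packed nibbles, i.e. 17 bytes for 32 weights (the E8M0 scale format is suggested by the __nv_cvt_e8m0_to_bf16raw fallback elsewhere in this PR). A small NumPy sketch of that byte budget, assuming the layout matches what the concat above produces (illustrative only, not taken from the ggml sources):

```python
import numpy as np

# Toy illustration of the per-block byte layout implied by repack_mxfp4:
# 1 scale byte + 16 bytes of packed 4-bit values = 17 bytes per 32 weights.
n_blocks = 4
scales = np.random.randint(0, 256, size=(n_blocks, 1), dtype=np.uint8)   # stand-in scale bytes
blocks = np.random.randint(0, 256, size=(n_blocks, 16), dtype=np.uint8)  # 32 packed 4-bit values each

packed = np.concatenate([scales, blocks], axis=-1)  # mirrors torch.concat((scales, blocks), dim=-1)
assert packed.shape == (n_blocks, 17)

bits_per_weight = packed.size * 8 / (n_blocks * 32)  # 32 weights per block, as in new_shape * 32
print(bits_per_weight)  # 4.25
```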

def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
blocks0: Tensor = torch.zeros(1)
blocks1: Tensor = torch.zeros(1)
found_mxfp4_tensors = False
# we assume that tensors are loaded in the correct order
for name, data_torch in self.get_tensors():
if "mlp.experts.down_proj_blocks" in name:
blocks0 = data_torch
elif "mlp.experts.down_proj_scales" in name:
new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
self.repack_mxfp4(new_name, blocks0, data_torch)
found_mxfp4_tensors = True
elif "mlp.experts.gate_up_proj_blocks" in name:
blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
elif "mlp.experts.gate_up_proj_scales" in name:
scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
Review thread on lines +8013 to +8014:

Collaborator: Too late, but why was this split? Only adds extra ops on the graph...

Collaborator: The gate_up tensor is organized in a way that a row of gate is followed by a row of up, aka interleaving. While we can rearrange it to the expected layout for fused op, I think it's easier to just split it into gate and up independently

Collaborator: Ahhh, didn't catch that.

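To make the interleaving from the thread above concrete: gate and up rows alternate, so taking every other row recovers each matrix, which is what the `[:, ::2]` / `[:, 1::2]` slices do (the converter slices dim 1 because of the leading expert dimension). A toy PyTorch sketch with made-up shapes:

```python
import torch

# Toy tensor whose rows alternate gate0, up0, gate1, up1, ... (made-up shapes)
n_rows, n_cols = 4, 3
gate = torch.arange(n_rows * n_cols, dtype=torch.float32).reshape(n_rows, n_cols)  # pretend gate rows
up = gate + 100                                                                     # pretend up rows
gate_up = torch.stack((gate, up), dim=1).reshape(n_rows * 2, n_cols)                # interleave the rows

# Same idea as the converter's slicing: even rows -> gate, odd rows -> up
gate_split, up_split = gate_up[::2, :], gate_up[1::2, :]
assert torch.equal(gate_split, gate) and torch.equal(up_split, up)
```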
self.repack_mxfp4(new_name_gate, blocks0, scales0)
self.repack_mxfp4(new_name_up, blocks1, scales1)
found_mxfp4_tensors = True
if not found_mxfp4_tensors:
raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using an MXFP4 model.")
return []

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if "sinks" in name:
name += ".weight"

# correct naming for down_proj
if "down_proj" in name:
if name.endswith("_bias"):
name = name.replace("down_proj_bias", "down_proj.bias")
else:
return []

# split the gate_up into gate and up
if "gate_up_proj" in name:
if name.endswith("_bias"):
name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
return [
(self.map_tensor_name(name_gate), gate_proj_bias),
(self.map_tensor_name(name_up), up_proj_bias)
]
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def set_vocab(self):
self._set_vocab_gpt2()

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])

rope_scaling = self.hparams.get("rope_scaling") or {}
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
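For reference, set_gguf_parameters above expects an HF rope_scaling entry shaped roughly like the following; only the key names come from the code, the numeric values are placeholders rather than the actual GPT-OSS configuration:

```python
# Hypothetical config.json excerpt; keys are what the code reads, values are placeholders.
rope_scaling = {
    "rope_type": "yarn",                       # the legacy key "type" is also accepted
    "factor": 32.0,                            # placeholder value
    "original_max_position_embeddings": 4096,  # defaults to 4096 if missing
}
```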


@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
@@ -8089,6 +8202,7 @@ class LazyTorchTensor(gguf.LazyBase):
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
torch.uint8: np.uint8,
}

# used for safetensors slices
38 changes: 37 additions & 1 deletion ggml/include/ggml.h
@@ -304,6 +304,16 @@
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

#define GGML_TENSOR_TERNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
GGML_TYPE_COUNT = 39,
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_COUNT = 40,
};

// precision
@@ -430,6 +441,7 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
};

// available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {

GGML_OP_DUP,
GGML_OP_ADD,
GGML_OP_ADD_ID,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
GGML_GLU_OP_SWIGLU_OAI,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,

@@ -831,6 +845,13 @@ extern "C" {
struct ggml_tensor * b,
enum ggml_type type);

// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
GGML_API struct ggml_tensor * ggml_add_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids);
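A small NumPy sketch of the semantics documented in the comment above, dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]; this is reference math only, following the comment's index order rather than ggml's actual memory layout or implementation:

```python
import numpy as np

def add_id_reference(a, b, ids):
    # dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
    dst = np.empty_like(a)
    n0, n1, n2 = a.shape
    for i0 in range(n0):
        for i1 in range(n1):
            for i2 in range(n2):
                dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
    return dst

a = np.zeros((2, 3, 4), dtype=np.float32)                    # hypothetical activations
b = np.arange(2 * 5, dtype=np.float32).reshape(2, 5)         # 5 selectable bias rows per i0
ids = np.array([[0, 1, 2, 3], [4, 3, 2, 1], [0, 0, 1, 1]])   # shape (3, 4), values index into b
out = add_id_reference(a, b, ids)
print(out[0])  # row i1 of out[0] picks b[0, ids[i1, :]]
```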

GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1198,6 +1219,13 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_swiglu_oai(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float alpha,
float limit);

// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
float scale,
float max_bias);

GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);

GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -2052,6 +2084,10 @@ extern "C" {
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
const struct ggml_tensor * a);

GGML_API void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);

// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
1 change: 1 addition & 0 deletions ggml/src/ggml-alloc.c
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
8 changes: 8 additions & 0 deletions ggml/src/ggml-cann/ggml-cann.cpp
@@ -2340,6 +2340,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[2]) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -2354,6 +2358,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
return false;
}
// TODO: support attention sinks [TAG_ATTN_SINKS]
if (op->src[4]) {
return false;
}
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;