ggml/src/ggml-quants.c (2 changes: 1 addition & 1 deletion)

@@ -288,7 +288,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
}
}

- const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+ const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;

const float d = GGML_E8M0_TO_FP32_HALF(e);

gguf-py/gguf/quants.py (57 changes: 53 additions & 4 deletions)

@@ -228,8 +228,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
d = max / -8
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)
- # FIXME: Q4_0's reference rounding is cursed and depends on FMA
- qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
+ qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)

qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
@@ -300,8 +299,7 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
d = max / -16
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)
- # FIXME: Q5_0's reference rounding is cursed and depends on FMA
- q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
+ q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)

qs = q.reshape((n_blocks, 2, cls.block_size // 2))
qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
@@ -655,6 +653,57 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
return (d * qs.astype(np.float32))


class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
# e2m1 values (doubled)
# ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)

@staticmethod
# see ggml_e8m0_to_fp32_half in ggml-impl.h
def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
return bits.view(np.float32)

@classmethod
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_blocks = blocks.shape[0]

d = abs(blocks).max(axis=-1, keepdims=True)

with np.errstate(divide="ignore"):
e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
Comment on lines +673 to +674

@compilade (Collaborator, Author), Aug 6, 2025:

It's surprising that the C implementation

    const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);

which doesn't check for zero before calling log2f (!) still results in the same number (e = 0).

Apparently, that works (checked by ensuring there are some zeroed input blocks in the tests).

Member:

log2f(0) returns -inf, which when converted to int turns to zero. However, I don't think this behavior is guaranteed by the C/C++ standard; it may be either undefined or implementation-defined behavior, so it would be better to add a check for zero. Do you want to add it here?

@compilade (Collaborator, Author):

> Do you want to add it here?

Sure. Done in 2763dc8
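For illustration, a minimal standalone C sketch of the guarded computation discussed above. The helper name e8m0_from_amax is hypothetical (it is not the ggml function); the actual fix is the one-line change in ggml-quants.c shown earlier.

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Hypothetical helper mirroring the guarded expression from the follow-up commit:
       handle amax == 0 explicitly instead of relying on how (uint8_t)(-inf) behaves. */
    static uint8_t e8m0_from_amax(float amax) {
        return amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
    }

    int main(void) {
        printf("%d\n", e8m0_from_amax(0.0f)); /* 0: guarded zero case            */
        printf("%d\n", e8m0_from_amax(1.0f)); /* 125: floor(log2(1)) - 2 + 127   */
        printf("%d\n", e8m0_from_amax(8.0f)); /* 128: floor(log2(8)) - 2 + 127   */
        return 0;
    }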


d = cls.e8m0_to_fp32_half(e)

kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))

errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
best = np.argmin(errs, axis=-1, keepdims=True)

qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))

qs = qs.reshape((n_blocks, cls.block_size // 2))

return np.concatenate([e, qs], axis=-1)

@classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_blocks = blocks.shape[0]

e, qs = np.hsplit(blocks, [1])

d = cls.e8m0_to_fp32_half(e)

qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
qs = (qs & np.uint8(0x0F)).view(np.int8)

kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))

return (d * qs.astype(np.float32))
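As a sanity check on the e8m0_to_fp32_half helper above: the bit manipulation maps the biased exponent byte e to 2**(e - 128), i.e. half of the usual e8m0 value 2**(e - 127), with e = 0 and e = 1 landing on float32 subnormals. Note that e = 0 therefore yields a tiny nonzero scale (2**-128), which is harmless for all-zero blocks since the nearest kvalue is 0. A small standalone sketch (not part of the PR) verifying this for a few representative exponents:

    import numpy as np

    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
        # same expression as the helper in the diff above
        bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
        return bits.view(np.float32)

    e = np.array([0, 1, 2, 127, 129, 254], dtype=np.uint8)
    expected = np.ldexp(np.float32(1.0), e.astype(np.int32) - 128)  # exact powers of two: 2**(e - 128)
    assert np.array_equal(e8m0_to_fp32_half(e), expected)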


class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
ksigns: bytes = (
b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
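Since MXFP4 subclasses __Quant with its qtype, it is registered in gguf.quants._type_traits and reachable through the generic quantize/dequantize helpers. A minimal usage sketch, assuming the gguf-py package from this branch is importable (shapes and values are arbitrary, as long as the last dimension is a multiple of the 32-element block size):

    import numpy as np
    import gguf
    from gguf.constants import GGMLQuantizationType

    data = np.random.randn(4, 256).astype(np.float32)  # 8 blocks of 32 values per row

    packed = gguf.quants.quantize(data, GGMLQuantizationType.MXFP4)
    restored = gguf.quants.dequantize(packed, GGMLQuantizationType.MXFP4)

    print(packed.shape, packed.dtype)     # (4, 136) uint8: 1 scale byte + 16 packed nibbles per 32 values
    print(np.abs(restored - data).max())  # coarse 4-bit reconstruction error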
gguf-py/tests/test_quants.py (19 changes: 14 additions & 5 deletions)

@@ -67,6 +67,7 @@ def __init__(self, libggml: Path):
"q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
"q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
"tq1_0", "tq2_0",
"mxfp4",
"iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
"iq4_nl", "iq4_xs",
):
@@ -140,14 +141,21 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
return False


- def do_test(libggml_path: Path, quick: bool = False):
+ def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
ggml_quants = GGMLQuants(libggml_path)

np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)

- for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
+ # test zero blocks
+ r[0, 0, :] = 0
+ ## Maybe test infinities? (can make NANs, not really useful in practice)
+ # r[0, 1, 0] = np.inf
+ # r[0, 2, 0] = -np.inf
+ # r[0, 3, 0] = np.inf
+ # r[0, 3, 1] = -np.inf

+ for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
has_dequantize = False
has_quantize = False

@@ -228,11 +236,12 @@ def do_test(libggml_path: Path, quick: bool = False):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
parser.add_argument("--type", type=str, help="The quant type to test (all by default)")

args = parser.parse_args()

logging.basicConfig(level=logging.DEBUG)

- do_test(args.libggml, args.quick)
+ do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)
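With the new --type flag, a single quant type can be checked in isolation against the C reference, for example (run from the repository root, assuming libggml was built at the new default location build/bin/libggml.so):

    python3 gguf-py/tests/test_quants.py --type mxfp4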