Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions tests/quantization/test_turboquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,39 @@ def test_solve_dtype_float32(self):
assert centroids.dtype == torch.float32
assert boundaries.dtype == torch.float32

# ---- Pre-baked centroid tables ----

@pytest.mark.parametrize(
    "d,bits",
    [(dim, width) for dim in (64, 128, 256) for width in (3, 4, 8)],
)
def test_prebaked_matches_solver(self, d, bits):
    """Check each pre-baked centroid table against a fresh solver run.

    For the (d, bits) pairs used by the TQ+ presets, get_centroids()
    returns a pre-baked table instead of running Lloyd-Max. The solver
    is deterministic for a given d and bits, so the table and the
    solver output are expected to agree exactly, not just approximately.
    """
    from_table = get_centroids(d, bits)
    from_solver, _ = solve_lloyd_max(d, bits)

    # Structural checks first so a shape/dtype mismatch is reported
    # clearly instead of surfacing as a broadcasting error below.
    assert from_table.shape == from_solver.shape
    assert from_table.dtype == from_solver.dtype

    # The table was generated by this very solver; a non-zero difference
    # means the pasted values are out of date.
    delta = from_table - from_solver
    max_abs_diff = delta.abs().max().item()
    assert max_abs_diff == 0.0, (
        f"Pre-baked (d={d}, bits={bits}) differs from solver: "
        f"max|Δ|={max_abs_diff:.3e}"
    )

def test_get_centroids_falls_back_for_unbaked_shape(self):
    """A (d, bits) pair with no pre-baked entry must use the solver."""
    # 192 is a deliberately non-standard dimension that the pre-baked
    # table does not cover, so get_centroids() has to take the fallback.
    dim, width = 192, 3
    from_lookup = get_centroids(dim, width)
    from_solver, _ = solve_lloyd_max(dim, width)
    assert torch.equal(from_lookup, from_solver)

@pytest.mark.parametrize("bits", [3, 4])
def test_centroids_match_scipy_reference(self, bits):
"""Verify _trapz(n=200) centroids match scipy.integrate.quad reference.
Expand Down
128 changes: 127 additions & 1 deletion vllm/model_executor/layers/quantization/turboquant/centroids.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,134 @@ def pdf(x):
)


# Pre-baked centroids for the (d, bits) pairs the TQ+ presets actually use.
# Short-circuits the 50 ms (3-bit) / 160 ms (4-bit) / 2.5 s (8-bit) Lloyd-Max
# iteration on the first request per shape. Values were produced by
# solve_lloyd_max(d, bits) below; that function is fully deterministic given
# d and bits, so to regenerate just print solve_lloyd_max(d, bits)[0].tolist()
# and paste back here. Verified equal to the runtime solver to bit precision.
_PREBAKED_CENTROIDS: dict[tuple[int, int], tuple[float, ...]] = {
(64, 3): (-2.6879370213e-01, -1.6789750755e-01, -9.4455689192e-02, -3.0622949824e-02, +3.0622949824e-02, +9.4455689192e-02, +1.6789750755e-01, +2.6879370213e-01),
(64, 4): (-3.4136527777e-01, -2.5855588913e-01, -2.0223522186e-01, -1.5703219175e-01, -1.1780603230e-01, -8.2109987736e-02, -4.8517223448e-02, -1.6053456813e-02, +1.6053456813e-02, +4.8517223448e-02, +8.2109987736e-02, +1.1780603230e-01, +1.5703219175e-01, +2.0223522186e-01, +2.5855588913e-01, +3.4136527777e-01),
(64, 8): (
-6.0041362047e-01, -5.5067229271e-01, -5.1984894276e-01, -4.9724340439e-01, -4.7935143113e-01, -4.6456301212e-01, -4.5199537277e-01, -4.4110643864e-01,
-4.3153521419e-01, -4.2302605510e-01, -4.1538882256e-01, -4.0847626328e-01, -4.0217083693e-01, -3.9637619257e-01, -3.9101192355e-01, -3.8600999117e-01,
-3.8131225109e-01, -3.7686887383e-01, -3.7263703346e-01, -3.6857995391e-01, -3.6466625333e-01, -3.6086925864e-01, -3.5716637969e-01, -3.5353884101e-01,
-3.4997096658e-01, -3.4644994140e-01, -3.4296530485e-01, -3.3950868249e-01, -3.3607342839e-01, -3.3265423775e-01, -3.2924708724e-01, -3.2584878802e-01,
-3.2245698571e-01, -3.1906986237e-01, -3.1568610668e-01, -3.1230473518e-01, -3.0892503262e-01, -3.0554649234e-01, -3.0216875672e-01, -2.9879152775e-01,
-2.9541468620e-01, -2.9203808308e-01, -2.8866165876e-01, -2.8528529406e-01, -2.8190904856e-01, -2.7853280306e-01, -2.7515661716e-01, -2.7178043127e-01,
-2.6840424538e-01, -2.6502808928e-01, -2.6165190339e-01, -2.5827574730e-01, -2.5489959121e-01, -2.5152343512e-01, -2.4814727902e-01, -2.4477112293e-01,
-2.4139496684e-01, -2.3801881075e-01, -2.3464265466e-01, -2.3126648366e-01, -2.2789032757e-01, -2.2451417148e-01, -2.2113801539e-01, -2.1776185930e-01,
-2.1438570321e-01, -2.1100954711e-01, -2.0763340592e-01, -2.0425724983e-01, -2.0088109374e-01, -1.9750493765e-01, -1.9412878156e-01, -1.9075262547e-01,
-1.8737646937e-01, -1.8400031328e-01, -1.8062415719e-01, -1.7724800110e-01, -1.7387184501e-01, -1.7049568892e-01, -1.6711954772e-01, -1.6374339163e-01,
-1.6036723554e-01, -1.5699107945e-01, -1.5361492336e-01, -1.5023876727e-01, -1.4686261117e-01, -1.4348646998e-01, -1.4011031389e-01, -1.3673415780e-01,
-1.3335800171e-01, -1.2998184562e-01, -1.2660570443e-01, -1.2322954834e-01, -1.1985339224e-01, -1.1647723615e-01, -1.1310108751e-01, -1.0972493142e-01,
-1.0634878278e-01, -1.0297262669e-01, -9.9596478045e-02, -9.6220321953e-02, -9.2844173312e-02, -8.9468017220e-02, -8.6091868579e-02, -8.2715712488e-02,
-7.9339563847e-02, -7.5963415205e-02, -7.2587259114e-02, -6.9211110473e-02, -6.5834954381e-02, -6.2458805740e-02, -5.9082653373e-02, -5.5706501007e-02,
-5.2330348641e-02, -4.8954196274e-02, -4.5578043908e-02, -4.2201895267e-02, -3.8825742900e-02, -3.5449590534e-02, -3.2073438168e-02, -2.8697287664e-02,
-2.5321135297e-02, -2.1944984794e-02, -1.8568832427e-02, -1.5192681924e-02, -1.1816530488e-02, -8.4403790534e-03, -5.0642271526e-03, -1.6880757175e-03,
+1.6880757175e-03, +5.0642271526e-03, +8.4403790534e-03, +1.1816530488e-02, +1.5192681924e-02, +1.8568832427e-02, +2.1944984794e-02, +2.5321135297e-02,
+2.8697287664e-02, +3.2073438168e-02, +3.5449590534e-02, +3.8825742900e-02, +4.2201895267e-02, +4.5578043908e-02, +4.8954196274e-02, +5.2330348641e-02,
+5.5706501007e-02, +5.9082653373e-02, +6.2458805740e-02, +6.5834954381e-02, +6.9211110473e-02, +7.2587259114e-02, +7.5963415205e-02, +7.9339563847e-02,
+8.2715712488e-02, +8.6091868579e-02, +8.9468017220e-02, +9.2844173312e-02, +9.6220321953e-02, +9.9596478045e-02, +1.0297262669e-01, +1.0634878278e-01,
+1.0972493142e-01, +1.1310108751e-01, +1.1647723615e-01, +1.1985339224e-01, +1.2322954834e-01, +1.2660570443e-01, +1.2998184562e-01, +1.3335800171e-01,
+1.3673415780e-01, +1.4011031389e-01, +1.4348646998e-01, +1.4686261117e-01, +1.5023876727e-01, +1.5361492336e-01, +1.5699107945e-01, +1.6036723554e-01,
+1.6374339163e-01, +1.6711954772e-01, +1.7049568892e-01, +1.7387184501e-01, +1.7724800110e-01, +1.8062415719e-01, +1.8400031328e-01, +1.8737646937e-01,
+1.9075262547e-01, +1.9412878156e-01, +1.9750493765e-01, +2.0088109374e-01, +2.0425724983e-01, +2.0763340592e-01, +2.1100954711e-01, +2.1438570321e-01,
+2.1776185930e-01, +2.2113801539e-01, +2.2451417148e-01, +2.2789032757e-01, +2.3126648366e-01, +2.3464265466e-01, +2.3801881075e-01, +2.4139496684e-01,
+2.4477112293e-01, +2.4814727902e-01, +2.5152343512e-01, +2.5489959121e-01, +2.5827574730e-01, +2.6165190339e-01, +2.6502808928e-01, +2.6840424538e-01,
+2.7178043127e-01, +2.7515661716e-01, +2.7853280306e-01, +2.8190904856e-01, +2.8528529406e-01, +2.8866165876e-01, +2.9203808308e-01, +2.9541468620e-01,
+2.9879152775e-01, +3.0216875672e-01, +3.0554649234e-01, +3.0892503262e-01, +3.1230473518e-01, +3.1568610668e-01, +3.1906986237e-01, +3.2245698571e-01,
+3.2584878802e-01, +3.2924708724e-01, +3.3265423775e-01, +3.3607342839e-01, +3.3950868249e-01, +3.4296530485e-01, +3.4644994140e-01, +3.4997096658e-01,
+3.5353884101e-01, +3.5716637969e-01, +3.6086925864e-01, +3.6466625333e-01, +3.6857995391e-01, +3.7263703346e-01, +3.7686887383e-01, +3.8131225109e-01,
+3.8600999117e-01, +3.9101192355e-01, +3.9637619257e-01, +4.0217083693e-01, +4.0847626328e-01, +4.1538882256e-01, +4.2302605510e-01, +4.3153521419e-01,
+4.4110643864e-01, +4.5199537277e-01, +4.6456301212e-01, +4.7935143113e-01, +4.9724340439e-01, +5.1984894276e-01, +5.5067229271e-01, +6.0041362047e-01,
),
(128, 3): (-1.9006584585e-01, -1.1872146279e-01, -6.6790260375e-02, -2.1653695032e-02, +2.1653695032e-02, +6.6790260375e-02, +1.1872146279e-01, +1.9006584585e-01),
(128, 4): (-2.4138170481e-01, -1.8282662332e-01, -1.4300189912e-01, -1.1103852838e-01, -8.3301439881e-02, -5.8060530573e-02, -3.4306857735e-02, -1.1351509020e-02, +1.1351509020e-02, +3.4306857735e-02, +5.8060530573e-02, +8.3301439881e-02, +1.1103852838e-01, +1.4300189912e-01, +1.8282662332e-01, +2.4138170481e-01),
(128, 8): (
-4.2455655336e-01, -3.8938412070e-01, -3.6758869886e-01, -3.5160419345e-01, -3.3895263076e-01, -3.2849565148e-01, -3.1960898638e-01, -3.1190934777e-01,
-3.0514147878e-01, -2.9912459850e-01, -2.9372423887e-01, -2.8883633018e-01, -2.8437772393e-01, -2.8028029203e-01, -2.7648717165e-01, -2.7295026183e-01,
-2.6962846518e-01, -2.6648652554e-01, -2.6349416375e-01, -2.6062539220e-01, -2.5785797834e-01, -2.5517308712e-01, -2.5255477428e-01, -2.4998971820e-01,
-2.4746684730e-01, -2.4497710168e-01, -2.4251309037e-01, -2.4006889760e-01, -2.3763979971e-01, -2.3522207141e-01, -2.3281283677e-01, -2.3040989041e-01,
-2.2801151872e-01, -2.2561646998e-01, -2.2322379053e-01, -2.2083279490e-01, -2.1844299138e-01, -2.1605399251e-01, -2.1366555989e-01, -2.1127751470e-01,
-2.0888973773e-01, -2.0650210977e-01, -2.0411460102e-01, -2.0172718167e-01, -1.9933979213e-01, -1.9695243239e-01, -1.9456510246e-01, -1.9217777252e-01,
-1.8979045749e-01, -1.8740315735e-01, -1.8501584232e-01, -1.8262854218e-01, -1.8024122715e-01, -1.7785392702e-01, -1.7546662688e-01, -1.7307931185e-01,
-1.7069201171e-01, -1.6830471158e-01, -1.6591741145e-01, -1.6353011131e-01, -1.6114279628e-01, -1.5875549614e-01, -1.5636819601e-01, -1.5398089588e-01,
-1.5159359574e-01, -1.4920628071e-01, -1.4681898057e-01, -1.4443168044e-01, -1.4204438031e-01, -1.3965708017e-01, -1.3726978004e-01, -1.3488247991e-01,
-1.3249516487e-01, -1.3010786474e-01, -1.2772056460e-01, -1.2533326447e-01, -1.2294596434e-01, -1.2055866420e-01, -1.1817136407e-01, -1.1578405648e-01,
-1.1339675635e-01, -1.1100945622e-01, -1.0862215608e-01, -1.0623485595e-01, -1.0384755582e-01, -1.0146025568e-01, -9.9072948098e-02, -9.6685647964e-02,
-9.4298347831e-02, -9.1911047697e-02, -8.9523747563e-02, -8.7136447430e-02, -8.4749147296e-02, -8.2361847162e-02, -7.9974547029e-02, -7.7587246895e-02,
-7.5199946761e-02, -7.2812646627e-02, -7.0425346494e-02, -6.8038046360e-02, -6.5650746226e-02, -6.3263446093e-02, -6.0876142234e-02, -5.8488842100e-02,
-5.6101541966e-02, -5.3714241832e-02, -5.1326941699e-02, -4.8939645290e-02, -4.6552345157e-02, -4.4165045023e-02, -4.1777744889e-02, -3.9390444756e-02,
-3.7003144622e-02, -3.4615844488e-02, -3.2228544354e-02, -2.9841246083e-02, -2.7453945950e-02, -2.5066645816e-02, -2.2679345682e-02, -2.0292047411e-02,
-1.7904747277e-02, -1.5517447144e-02, -1.3130147941e-02, -1.0742847808e-02, -8.3555486053e-03, -5.9682489373e-03, -3.5809492692e-03, -1.1936498340e-03,
+1.1936498340e-03, +3.5809492692e-03, +5.9682489373e-03, +8.3555486053e-03, +1.0742847808e-02, +1.3130147941e-02, +1.5517447144e-02, +1.7904747277e-02,
+2.0292047411e-02, +2.2679345682e-02, +2.5066645816e-02, +2.7453945950e-02, +2.9841246083e-02, +3.2228544354e-02, +3.4615844488e-02, +3.7003144622e-02,
+3.9390444756e-02, +4.1777744889e-02, +4.4165045023e-02, +4.6552345157e-02, +4.8939645290e-02, +5.1326941699e-02, +5.3714241832e-02, +5.6101541966e-02,
+5.8488842100e-02, +6.0876142234e-02, +6.3263446093e-02, +6.5650746226e-02, +6.8038046360e-02, +7.0425346494e-02, +7.2812646627e-02, +7.5199946761e-02,
+7.7587246895e-02, +7.9974547029e-02, +8.2361847162e-02, +8.4749147296e-02, +8.7136447430e-02, +8.9523747563e-02, +9.1911047697e-02, +9.4298347831e-02,
+9.6685647964e-02, +9.9072948098e-02, +1.0146025568e-01, +1.0384755582e-01, +1.0623485595e-01, +1.0862215608e-01, +1.1100945622e-01, +1.1339675635e-01,
+1.1578405648e-01, +1.1817136407e-01, +1.2055866420e-01, +1.2294596434e-01, +1.2533326447e-01, +1.2772056460e-01, +1.3010786474e-01, +1.3249516487e-01,
+1.3488247991e-01, +1.3726978004e-01, +1.3965708017e-01, +1.4204438031e-01, +1.4443168044e-01, +1.4681898057e-01, +1.4920628071e-01, +1.5159359574e-01,
+1.5398089588e-01, +1.5636819601e-01, +1.5875549614e-01, +1.6114279628e-01, +1.6353011131e-01, +1.6591741145e-01, +1.6830471158e-01, +1.7069201171e-01,
+1.7307931185e-01, +1.7546662688e-01, +1.7785392702e-01, +1.8024122715e-01, +1.8262854218e-01, +1.8501584232e-01, +1.8740315735e-01, +1.8979045749e-01,
+1.9217777252e-01, +1.9456510246e-01, +1.9695243239e-01, +1.9933979213e-01, +2.0172718167e-01, +2.0411460102e-01, +2.0650210977e-01, +2.0888973773e-01,
+2.1127751470e-01, +2.1366555989e-01, +2.1605399251e-01, +2.1844299138e-01, +2.2083279490e-01, +2.2322379053e-01, +2.2561646998e-01, +2.2801151872e-01,
+2.3040989041e-01, +2.3281283677e-01, +2.3522207141e-01, +2.3763979971e-01, +2.4006889760e-01, +2.4251309037e-01, +2.4497710168e-01, +2.4746684730e-01,
+2.4998971820e-01, +2.5255477428e-01, +2.5517308712e-01, +2.5785797834e-01, +2.6062539220e-01, +2.6349416375e-01, +2.6648652554e-01, +2.6962846518e-01,
+2.7295026183e-01, +2.7648717165e-01, +2.8028029203e-01, +2.8437772393e-01, +2.8883633018e-01, +2.9372423887e-01, +2.9912459850e-01, +3.0514147878e-01,
+3.1190934777e-01, +3.1960898638e-01, +3.2849565148e-01, +3.3895263076e-01, +3.5160419345e-01, +3.6758869886e-01, +3.8938412070e-01, +4.2455655336e-01,
),
(256, 3): (-1.3439685106e-01, -8.3948753774e-02, -4.7227844596e-02, -1.5311474912e-02, +1.5311474912e-02, +4.7227844596e-02, +8.3948753774e-02, +1.3439685106e-01),
(256, 4): (-1.7068263888e-01, -1.2927794456e-01, -1.0111761093e-01, -7.8516095877e-02, -5.8903016150e-02, -4.1054993868e-02, -2.4258611724e-02, -8.0267284065e-03, +8.0267284065e-03, +2.4258611724e-02, +4.1054993868e-02, +5.8903016150e-02, +7.8516095877e-02, +1.0111761093e-01, +1.2927794456e-01, +1.7068263888e-01),
(256, 8): (
-3.0020681024e-01, -2.7533614635e-01, -2.5992447138e-01, -2.4862170219e-01, -2.3967571557e-01, -2.3228150606e-01, -2.2599768639e-01, -2.2055321932e-01,
-2.1576760709e-01, -2.1151302755e-01, -2.0769441128e-01, -2.0423813164e-01, -2.0108541846e-01, -1.9818809628e-01, -1.9550596178e-01, -1.9300499558e-01,
-1.9065612555e-01, -1.8843443692e-01, -1.8631851673e-01, -1.8428997695e-01, -1.8233312666e-01, -1.8043462932e-01, -1.7858318985e-01, -1.7676942050e-01,
-1.7498548329e-01, -1.7322497070e-01, -1.7148265243e-01, -1.6975434124e-01, -1.6803671420e-01, -1.6632711887e-01, -1.6462354362e-01, -1.6292439401e-01,
-1.6122849286e-01, -1.5953493118e-01, -1.5784305334e-01, -1.5615236759e-01, -1.5446251631e-01, -1.5277324617e-01, -1.5108437836e-01, -1.4939576387e-01,
-1.4770734310e-01, -1.4601904154e-01, -1.4433082938e-01, -1.4264264703e-01, -1.4095452428e-01, -1.3926640153e-01, -1.3757830858e-01, -1.3589021564e-01,
-1.3420212269e-01, -1.3251404464e-01, -1.3082595170e-01, -1.2913787365e-01, -1.2744979560e-01, -1.2576171756e-01, -1.2407363951e-01, -1.2238556147e-01,
-1.2069748342e-01, -1.1900940537e-01, -1.1732132733e-01, -1.1563324183e-01, -1.1394516379e-01, -1.1225708574e-01, -1.1056900769e-01, -1.0888092965e-01,
-1.0719285160e-01, -1.0550477356e-01, -1.0381670296e-01, -1.0212862492e-01, -1.0044054687e-01, -9.8752468824e-02, -9.7064390779e-02, -9.5376312733e-02,
-9.3688234687e-02, -9.2000156641e-02, -9.0312078595e-02, -8.8624000549e-02, -8.6935922503e-02, -8.5247844458e-02, -8.3559773862e-02, -8.1871695817e-02,
-8.0183617771e-02, -7.8495539725e-02, -7.6807461679e-02, -7.5119383633e-02, -7.3431305587e-02, -7.1743234992e-02, -7.0055156946e-02, -6.8367078900e-02,
-6.6679000854e-02, -6.4990922809e-02, -6.3302852213e-02, -6.1614774168e-02, -5.9926696122e-02, -5.8238618076e-02, -5.6550543755e-02, -5.4862465709e-02,
-5.3174391389e-02, -5.1486313343e-02, -4.9798239022e-02, -4.8110160977e-02, -4.6422086656e-02, -4.4734008610e-02, -4.3045934290e-02, -4.1357856244e-02,
-3.9669781923e-02, -3.7981707603e-02, -3.6293629557e-02, -3.4605555236e-02, -3.2917477190e-02, -3.1229402870e-02, -2.9541326687e-02, -2.7853250504e-02,
-2.6165174320e-02, -2.4477098137e-02, -2.2789021954e-02, -2.1100947633e-02, -1.9412871450e-02, -1.7724795267e-02, -1.6036719084e-02, -1.4348643832e-02,
-1.2660567649e-02, -1.0972492397e-02, -9.2844162136e-03, -7.5963409618e-03, -5.9082652442e-03, -4.2201895267e-03, -2.5321135763e-03, -8.4403785877e-04,
+8.4403785877e-04, +2.5321135763e-03, +4.2201895267e-03, +5.9082652442e-03, +7.5963409618e-03, +9.2844162136e-03, +1.0972492397e-02, +1.2660567649e-02,
+1.4348643832e-02, +1.6036719084e-02, +1.7724795267e-02, +1.9412871450e-02, +2.1100947633e-02, +2.2789021954e-02, +2.4477098137e-02, +2.6165174320e-02,
+2.7853250504e-02, +2.9541326687e-02, +3.1229402870e-02, +3.2917477190e-02, +3.4605555236e-02, +3.6293629557e-02, +3.7981707603e-02, +3.9669781923e-02,
+4.1357856244e-02, +4.3045934290e-02, +4.4734008610e-02, +4.6422086656e-02, +4.8110160977e-02, +4.9798239022e-02, +5.1486313343e-02, +5.3174391389e-02,
+5.4862465709e-02, +5.6550543755e-02, +5.8238618076e-02, +5.9926696122e-02, +6.1614774168e-02, +6.3302852213e-02, +6.4990922809e-02, +6.6679000854e-02,
+6.8367078900e-02, +7.0055156946e-02, +7.1743234992e-02, +7.3431305587e-02, +7.5119383633e-02, +7.6807461679e-02, +7.8495539725e-02, +8.0183617771e-02,
+8.1871695817e-02, +8.3559773862e-02, +8.5247844458e-02, +8.6935922503e-02, +8.8624000549e-02, +9.0312078595e-02, +9.2000156641e-02, +9.3688234687e-02,
+9.5376312733e-02, +9.7064390779e-02, +9.8752468824e-02, +1.0044054687e-01, +1.0212862492e-01, +1.0381670296e-01, +1.0550477356e-01, +1.0719285160e-01,
+1.0888092965e-01, +1.1056900769e-01, +1.1225708574e-01, +1.1394516379e-01, +1.1563324183e-01, +1.1732132733e-01, +1.1900940537e-01, +1.2069748342e-01,
+1.2238556147e-01, +1.2407363951e-01, +1.2576171756e-01, +1.2744979560e-01, +1.2913787365e-01, +1.3082595170e-01, +1.3251404464e-01, +1.3420212269e-01,
+1.3589021564e-01, +1.3757830858e-01, +1.3926640153e-01, +1.4095452428e-01, +1.4264264703e-01, +1.4433082938e-01, +1.4601904154e-01, +1.4770734310e-01,
+1.4939576387e-01, +1.5108437836e-01, +1.5277324617e-01, +1.5446251631e-01, +1.5615236759e-01, +1.5784305334e-01, +1.5953493118e-01, +1.6122849286e-01,
+1.6292439401e-01, +1.6462354362e-01, +1.6632711887e-01, +1.6803671420e-01, +1.6975434124e-01, +1.7148265243e-01, +1.7322497070e-01, +1.7498548329e-01,
+1.7676942050e-01, +1.7858318985e-01, +1.8043462932e-01, +1.8233312666e-01, +1.8428997695e-01, +1.8631851673e-01, +1.8843443692e-01, +1.9065612555e-01,
+1.9300499558e-01, +1.9550596178e-01, +1.9818809628e-01, +2.0108541846e-01, +2.0423813164e-01, +2.0769441128e-01, +2.1151302755e-01, +2.1576760709e-01,
+2.2055321932e-01, +2.2599768639e-01, +2.3228150606e-01, +2.3967571557e-01, +2.4862170219e-01, +2.5992447138e-01, +2.7533614635e-01, +3.0020681024e-01,
),
Comment thread
TheTom marked this conversation as resolved.
}


@lru_cache(maxsize=32)
def get_centroids(d: int, bits: int) -> torch.Tensor:
    """Get precomputed Lloyd-Max centroids (cached).

    Short-circuits to the pre-baked ``_PREBAKED_CENTROIDS`` table for known
    ``(d, bits)`` pairs (the common head-dim x bit-width combinations used
    by the TQ+ presets). Falls back to running ``solve_lloyd_max(d, bits)``
    for any other shape.

    Args:
        d: Dimension the centroids were solved for (first element of the
            table key).
        bits: Quantization bit-width (second element of the table key).

    Returns:
        A tensor of centroids. On the pre-baked path this is a newly built
        1-D ``torch.float32`` tensor; on the fallback path it is the first
        element of the tuple returned by ``solve_lloyd_max``.

    Note:
        Results are memoized with ``lru_cache``, so repeated calls with the
        same arguments return the same tensor object. Callers should treat
        it as read-only; an in-place modification would be visible to every
        later caller.
    """
    prebaked = _PREBAKED_CENTROIDS.get((d, bits))
    if prebaked is not None:
        # Known shape: materialize the stored tuple of floats directly and
        # skip the iterative solve.
        return torch.tensor(prebaked, dtype=torch.float32)
    # Unknown shape: compute at runtime; the second return value of the
    # solver is not needed here.
    centroids, _ = solve_lloyd_max(d, bits)
    return centroids
Loading