Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,657 changes: 1,657 additions & 0 deletions PLAN.md

Large diffs are not rendered by default.

454 changes: 454 additions & 0 deletions bench_lloyd_max.py

Large diffs are not rendered by default.

454 changes: 454 additions & 0 deletions bench_lloyd_max_v2.py

Large diffs are not rendered by default.

530 changes: 530 additions & 0 deletions bench_lloydmax_all_tensors.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion exllamav3/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .cache import Cache, CacheLayer
from .fp16 import CacheLayer_fp16
from .quant import CacheLayer_quant
from .recurrent import RecurrentCache, CacheableState
from .tq3 import CacheLayer_tq3
from .lloyd_max import CacheLayer_lloyd_max
148 changes: 148 additions & 0 deletions exllamav3/cache/lloyd_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from __future__ import annotations
from typing_extensions import override
import torch
from ..constants import PAGE_SIZE
from ..model import Config
from .cache import CacheLayer
from typing import TYPE_CHECKING
from exllamav3.ext import exllamav3_ext as ext
if TYPE_CHECKING:
from ..modules import Attention
import numpy as np

class CacheLayer_lloyd_max(CacheLayer):
    """
    KV-cache layer using Lloyd-Max optimal quantization.

    Identical interface to CacheLayer_quant; the only difference is that
    quant/dequant calls go to the Lloyd-Max CUDA kernels (quant_lm_cache_paged /
    dequant_lm_cache_paged) which use precomputed Gaussian-optimal codebooks
    instead of uniform rounding. The on-disk / in-memory tensor layout is
    identical to CacheLayer_quant for the same bit-width.
    """

    def __init__(
        self,
        config: Config | None,
        attention: Attention,
        cache_id: int,
        max_num_tokens: int,
        k_bits: int,
        v_bits: int,
    ):
        """
        :param config:
            Model config; may be None (e.g. in tensor-parallel child layers).
        :param attention:
            Attention module whose K/V this layer caches. May be falsy, in
            which case the layer holds no storage (all shapes are None).
        :param cache_id:
            Identifier of the owning cache, forwarded to the base class.
        :param max_num_tokens:
            Total token capacity; must be a multiple of PAGE_SIZE.
        :param k_bits:
            Key quantization width, 2..8 bits.
        :param v_bits:
            Value quantization width, 2..8 bits.
        """
        super().__init__(config, attention, cache_id, max_num_tokens)

        assert max_num_tokens % PAGE_SIZE == 0, \
            f"max_num_tokens must be a multiple of {PAGE_SIZE}."
        assert (2 <= k_bits <= 8) and (2 <= v_bits <= 8), "quantized cache must be from 2 to 8 bits"

        self.k_bits = k_bits
        self.v_bits = v_bits

        num_pages = max_num_tokens // PAGE_SIZE
        if attention:
            self.shape = (num_pages, PAGE_SIZE, attention.num_kv_heads, attention.head_dim)
            self.token_dim = attention.num_kv_heads * attention.head_dim
            # The kernels pack values in groups of 32 (one int32 word per
            # bitplane plus one fp16 scale per group), so the per-token
            # width must divide evenly.
            assert self.token_dim % 32 == 0, \
                "num_kv_heads * head_dim must be a multiple of 32"
            self.qshape_k = (num_pages, PAGE_SIZE, self.token_dim // 32 * k_bits)
            self.qshape_v = (num_pages, PAGE_SIZE, self.token_dim // 32 * v_bits)
            self.qshape_s = (num_pages, PAGE_SIZE, self.token_dim // 32)
        else:
            # No attention module: placeholder layer with no storage.
            # (Previously attention.num_kv_heads was read unconditionally
            # and raised AttributeError here when attention was falsy.)
            self.shape = None
            self.token_dim = None
            self.qshape_k = None
            self.qshape_v = None
            self.qshape_s = None

        # Packed quantized K/V words and per-group fp16 scales; allocated
        # lazily in alloc().
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None
        self.device = None


    @override
    def alloc(self, device: torch.device):
        """Allocate quantized storage and scale tensors on *device*."""
        self.device = device
        self.qk = torch.zeros(self.qshape_k, dtype = torch.int, device = device) if self.shape else None
        self.qv = torch.zeros(self.qshape_v, dtype = torch.int, device = device) if self.shape else None
        self.sk = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None
        self.sv = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None


    @override
    def free(self):
        """Release all storage; tensors are reclaimed by the allocator."""
        self.device = None
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None


    @override
    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor):
        """Dequantize the cached pages into freshly allocated fp16 K/V tensors."""
        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
        ext.dequant_lm_cache_paged(self.qk, self.sk, k, self.qv, self.sv, v, cache_seqlens, block_table, PAGE_SIZE)
        return k, v


    @override
    def get_kv_alloc_placeholder(self):
        """Allocate uninitialized fp16 K/V tensors without dequantizing."""
        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
        return k, v


    @override
    def update_kv(
        self,
        cache_seqlens: torch.Tensor,
        block_table: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        length: int
    ):
        """Quantize *length* new tokens of K/V into the paged cache storage."""
        ext.quant_lm_cache_paged(
            k, self.qk, self.sk,
            v, self.qv, self.sv,
            cache_seqlens, block_table,
            PAGE_SIZE,
            length
        )


    @override
    def copy_page(self, source: CacheLayer_lloyd_max, from_page: int, to_page: int, num_tokens: int):
        """Copy the first *num_tokens* rows of one page from *source* into this layer."""
        assert self.qshape_k == source.qshape_k
        assert self.qshape_v == source.qshape_v
        # Async copies; caller is responsible for stream ordering.
        self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
        self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
        self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
        self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)


    @override
    def get_tensors(self):
        """Return all persistent storage tensors of this layer."""
        return [self.qk, self.qv, self.sk, self.sv]


    @override
    def storage_size(self):
        """Total bytes of persistent quantized storage (packed words + scales)."""
        return (
            np.prod(self.qshape_k) * torch.int.itemsize +
            np.prod(self.qshape_v) * torch.int.itemsize +
            2 * np.prod(self.qshape_s) * torch.half.itemsize
        )


    @override
    def overhead_size(self):
        """Transient fp16 bytes per token needed while (de)quantizing K and V."""
        return 2 * np.prod(self.shape[2:]) * torch.half.itemsize


    @override
    def tp_export(self, plan):
        """Export constructor args so a tensor-parallel worker can rebuild this layer."""
        return {
            "cls": CacheLayer_lloyd_max,
            "args": {
                "cache_id": self.cache_id,
                "max_num_tokens": self.max_num_tokens,
                "k_bits": self.k_bits,
                "v_bits": self.v_bits,
            }
        }
155 changes: 155 additions & 0 deletions exllamav3/cache/tq3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations
from typing_extensions import override
import torch
from ..constants import PAGE_SIZE
from ..model import Config
from .cache import CacheLayer
from typing import TYPE_CHECKING
from exllamav3.ext import exllamav3_ext as ext
if TYPE_CHECKING:
from ..modules import Attention
import numpy as np

class CacheLayer_tq3(CacheLayer):
    """
    TQ3 quantized KV cache using Lloyd-Max ternary codebook.

    Storage layout per 32-element block:
    - 2 bitplanes (uint32 each) for ternary encoding = 8 bytes
    - 1 fp16 scale = 2 bytes
    Total: 10 bytes per 32 values = 2.5 effective bits per value

    Compared to CacheLayer_quant at 2 bits:
    - Same bitplane count (2 bitplanes per block)
    - But uses Lloyd-Max boundaries instead of uniform thresholds
    - ~15% lower MSE on Gaussian-distributed data (post-WHT)
    """

    def __init__(
        self,
        config: Config | None,
        attention: Attention,
        cache_id: int,
        max_num_tokens: int,
    ):
        """
        :param config:
            Model config; may be None (e.g. in tensor-parallel child layers).
        :param attention:
            Attention module whose K/V this layer caches. May be falsy, in
            which case the layer holds no storage (all shapes are None).
        :param cache_id:
            Identifier of the owning cache, forwarded to the base class.
        :param max_num_tokens:
            Total token capacity; must be a multiple of PAGE_SIZE.
        """
        super().__init__(config, attention, cache_id, max_num_tokens)

        assert max_num_tokens % PAGE_SIZE == 0, \
            f"max_num_tokens must be a multiple of {PAGE_SIZE}."

        # TQ3 uses 2 bitplanes (same storage as 2-bit uniform)
        self.bits = 2

        num_pages = max_num_tokens // PAGE_SIZE
        if attention:
            self.shape = (num_pages, PAGE_SIZE, attention.num_kv_heads, attention.head_dim)
            self.token_dim = attention.num_kv_heads * attention.head_dim
            # Values are packed in groups of 32 (one uint32 word per bitplane
            # plus one fp16 scale per group), so the per-token width must
            # divide evenly.
            assert self.token_dim % 32 == 0, \
                "num_kv_heads * head_dim must be a multiple of 32"
            self.qshape = (num_pages, PAGE_SIZE, self.token_dim // 32 * self.bits)
            self.sshape = (num_pages, PAGE_SIZE, self.token_dim // 32)
        else:
            # No attention module: placeholder layer with no storage.
            # (Previously attention.num_kv_heads was read unconditionally
            # and raised AttributeError here when attention was falsy.)
            self.shape = None
            self.token_dim = None
            self.qshape = None
            self.sshape = None

        # Packed ternary bitplanes and per-group fp16 scales; allocated
        # lazily in alloc().
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None
        self.device = None


    @override
    def alloc(self, device: torch.device):
        """Allocate quantized storage and scale tensors on *device*."""
        self.device = device
        self.qk = torch.zeros(self.qshape, dtype = torch.int, device = device) if self.shape else None
        self.qv = torch.zeros(self.qshape, dtype = torch.int, device = device) if self.shape else None
        self.sk = torch.zeros(self.sshape, dtype = torch.half, device = device) if self.shape else None
        self.sv = torch.zeros(self.sshape, dtype = torch.half, device = device) if self.shape else None


    @override
    def free(self):
        """Release all storage; tensors are reclaimed by the allocator."""
        self.device = None
        self.qk = None
        self.qv = None
        self.sk = None
        self.sv = None


    @override
    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor):
        """Dequantize the cached pages into freshly allocated fp16 K/V tensors."""
        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
        ext.dequant_tq3_cache_paged(
            self.qk, self.sk, k,
            self.qv, self.sv, v,
            cache_seqlens, block_table, PAGE_SIZE
        )
        return k, v


    @override
    def get_kv_alloc_placeholder(self):
        """Allocate uninitialized fp16 K/V tensors without dequantizing."""
        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
        return k, v


    @override
    def update_kv(
        self,
        cache_seqlens: torch.Tensor,
        block_table: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        length: int
    ):
        """Quantize *length* new tokens of K/V into the paged cache storage."""
        ext.quant_tq3_cache_paged(
            k, self.qk, self.sk,
            v, self.qv, self.sv,
            cache_seqlens, block_table,
            PAGE_SIZE,
            length
        )


    @override
    def copy_page(self, source: CacheLayer_tq3, from_page: int, to_page: int, num_tokens: int):
        """Copy the first *num_tokens* rows of one page from *source* into this layer."""
        assert self.qshape == source.qshape
        # Async copies; caller is responsible for stream ordering.
        self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
        self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
        self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
        self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)


    @override
    def get_tensors(self):
        """Return all persistent storage tensors of this layer."""
        return [self.qk, self.qv, self.sk, self.sv]


    @override
    def storage_size(self):
        """Total bytes of persistent quantized storage (K + V bitplanes, plus scales)."""
        return (
            2 * np.prod(self.qshape) * torch.int.itemsize +
            2 * np.prod(self.sshape) * torch.half.itemsize
        )


    @override
    def overhead_size(self):
        """Transient fp16 bytes per token needed while (de)quantizing K and V."""
        return 2 * np.prod(self.shape[2:]) * torch.half.itemsize


    @override
    def tp_export(self, plan):
        """Export constructor args so a tensor-parallel worker can rebuild this layer."""
        return {
            "cls": CacheLayer_tq3,
            "args": {
                "cache_id": self.cache_id,
                "max_num_tokens": self.max_num_tokens,
            }
        }
14 changes: 14 additions & 0 deletions exllamav3/exllamav3_ext/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
#include "generator/cache.cuh"

#include "cache/q_cache.cuh"
#include "cache/tq3_cache.cuh"
#include "cache/lm_cache.cuh"
#include "quant/tq3_dequant.cuh"

#include "histogram.cuh"

Expand Down Expand Up @@ -130,6 +133,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("quant_cache_paged", &quant_cache_paged, "quant_cache_paged");
m.def("dequant_cache_paged", &dequant_cache_paged, "dequant_cache_paged");

m.def("quant_tq3_cache_cont", &quant_tq3_cache_cont, "quant_tq3_cache_cont");
m.def("dequant_tq3_cache_cont", &dequant_tq3_cache_cont, "dequant_tq3_cache_cont");
m.def("quant_tq3_cache_paged", &quant_tq3_cache_paged, "quant_tq3_cache_paged");
m.def("dequant_tq3_cache_paged", &dequant_tq3_cache_paged, "dequant_tq3_cache_paged");
m.def("dequant_tq3_weight", &dequant_tq3_weight, "dequant_tq3_weight");

m.def("quant_lm_cache_cont", &quant_lm_cache_cont, "quant_lm_cache_cont");
m.def("dequant_lm_cache_cont", &dequant_lm_cache_cont, "dequant_lm_cache_cont");
m.def("quant_lm_cache_paged", &quant_lm_cache_paged, "quant_lm_cache_paged");
m.def("dequant_lm_cache_paged", &dequant_lm_cache_paged, "dequant_lm_cache_paged");

m.def("count_inf_nan", &count_inf_nan, "count_inf_nan");
m.def("histogram", &histogram, "histogram");

Expand Down
Loading