diff --git a/exllamav3/exllamav3_ext/bindings.cpp b/exllamav3/exllamav3_ext/bindings.cpp
index b0a55c97..0fa3dd1d 100644
--- a/exllamav3/exllamav3_ext/bindings.cpp
+++ b/exllamav3/exllamav3_ext/bindings.cpp
@@ -44,6 +44,7 @@
 #include "parallel/barrier.cuh"
 #include "parallel/gather.cuh"
 #include "parallel/all_reduce.cuh"
+#include "parallel/tq3_all_reduce.cuh"
 #include "libtorch/gated_delta_net.h"
 #include "libtorch/linear.h"
@@ -78,6 +79,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("pg_gather", &pg_gather, "pg_gather");
     m.def("pg_all_reduce", &pg_all_reduce, "pg_all_reduce");
     m.def("pg_all_reduce_cpu", &pg_all_reduce_cpu, "pg_all_reduce_cpu");
+    m.def("tq3_all_reduce", &tq3_all_reduce, "tq3_all_reduce");
     m.def("run_cpu_reduce_jobs", &run_cpu_reduce_jobs, "run_cpu_reduce_jobs");
     m.def("end_cpu_reduce_jobs", &end_cpu_reduce_jobs, "end_cpu_reduce_jobs");
diff --git a/exllamav3/exllamav3_ext/parallel/tq3_all_reduce.cu b/exllamav3/exllamav3_ext/parallel/tq3_all_reduce.cu
new file mode 100644
index 00000000..d3f8e73c
--- /dev/null
+++ b/exllamav3/exllamav3_ext/parallel/tq3_all_reduce.cu
@@ -0,0 +1,258 @@
+// NOTE(review): the <...> include targets below were stripped by a sanitizer;
+// they are reconstructed from usage (at::Tensor / TORCH_CHECK, cudaLaunch*,
+// half, cooperative_groups) — confirm against the original sources.
+#include <torch/extension.h>
+#include "tq3_all_reduce.cuh"
+#include "tq3_compress.cuh"
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cooperative_groups.h>
+namespace cg = cooperative_groups;
+#include "../util.h"
+#include "../util.cuh"
+#include "../ptx.cuh"
+#include "context.cuh"
+#include "timeout.cuh"
+#include "barrier_inner.cuh"
+
+// ---------------------------------------------------------------------------
+// Shared buffer layout for TQ3 all-reduce (all-gather + local-sum pattern)
+//
+// The pinned shared buffer (shm_b, 16 MB by default) is sliced into
+// num_ranks equal slots.
Each slot holds the TQ3-compressed representation
+// of one rank's fp16 tensor:
+//
+//     slot_bytes = num_tq3_blocks * 10
+//     10 bytes per block = 4 (bp0) + 4 (bp1) + 2 (scale, fp16)
+//
+// Slot for rank r starts at: shbuf_ptr + r * slot_bytes
+//
+// Algorithm
+//   Phase 1 — every thread processes one or more TQ3 blocks from its rank's
+//       local fp16 data and stores compressed output to shbuf[this_rank].
+//   Phase 2 — __threadfence_system() + cooperative grid.sync() + barrier_inner
+//       ensures all GPUs see each other's writes.
+//   Phase 3 — every thread accumulates ALL ranks' compressed data for its
+//       assigned blocks into a float accumulator, then stores fp16 result
+//       back to the data tensor in-place.
+//   Phase 4 — second barrier to make results visible before the kernel exits.
+// ---------------------------------------------------------------------------
+
+#define TQ3_AR_MAX_THREADS 1024
+
+// ---------------------------------------------------------------------------
+// Kernel
+// ---------------------------------------------------------------------------
+__global__ __launch_bounds__(TQ3_AR_MAX_THREADS)
+void tq3_all_reduce_kernel
+(
+    PGContext* __restrict__ ctx,
+    const uint32_t device_mask,
+    int this_device,
+    int master_device,
+    half* __restrict__ data_ptr,      // fp16 input/output on this GPU
+    uint8_t* __restrict__ shbuf_ptr,  // pinned shared ring buffer
+    const size_t num_elements,        // number of fp16 elements
+    const size_t slot_bytes,          // bytes per rank slot in shbuf
+    bool contribution,                // false → treat local data as zeros
+    uint32_t* __restrict__ abort_flag
+)
+{
+    auto grid = cg::this_grid();
+
+    const int num_ranks = __popc(device_mask);
+    const int this_rank = __popc(device_mask & ((1 << this_device) - 1));
+
+    const size_t num_blocks = (num_elements + TQ3_BLOCK_SIZE - 1) / TQ3_BLOCK_SIZE;
+    const size_t block_bytes = 10u;  // bp0(4) + bp1(4) + scale(2)
+
+    // Pointer to this rank's write slot
+    uint8_t* my_slot = shbuf_ptr + (size_t)this_rank * slot_bytes;
+
+    // ------------------------------------------------------------------
+    // Phase 1: compress local fp16 data → TQ3, write to our shbuf slot
+    // ------------------------------------------------------------------
+    {
+        int global_tid = (int)(blockIdx.x * blockDim.x + threadIdx.x);
+        int stride = (int)(gridDim.x * blockDim.x);
+
+        for (size_t blk = (size_t)global_tid; blk < num_blocks; blk += (size_t)stride)
+        {
+            size_t elem_off = blk * TQ3_BLOCK_SIZE;
+            size_t elems_this_block = min((size_t)TQ3_BLOCK_SIZE, num_elements - elem_off);
+
+            // Tail blocks (and non-contributing ranks) are zero-padded so the
+            // compressed stream is always exactly num_blocks * block_bytes.
+            half src_buf[TQ3_BLOCK_SIZE];
+            #pragma unroll
+            for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+            {
+                if ((size_t)i < elems_this_block && contribution)
+                    src_buf[i] = data_ptr[elem_off + (size_t)i];
+                else
+                    src_buf[i] = __float2half(0.0f);
+            }
+
+            uint32_t bp0, bp1;
+            half scale;
+            tq3_compress_block(src_buf, &bp0, &bp1, &scale);
+
+            uint8_t* dst = my_slot + blk * block_bytes;
+            *reinterpret_cast<uint32_t*>(dst + 0) = bp0;
+            *reinterpret_cast<uint32_t*>(dst + 4) = bp1;
+            *reinterpret_cast<half*>    (dst + 8) = scale;
+        }
+    }
+
+    // Make writes visible to peer GPUs before the barrier
+    __threadfence_system();
+
+    // ------------------------------------------------------------------
+    // Phase 2: global barrier — wait for all ranks to finish writing
+    // ------------------------------------------------------------------
+    grid.sync();
+    pg_barrier_inner(ctx, device_mask, this_device, master_device, abort_flag);
+    if (*abort_flag) return;
+
+    // ------------------------------------------------------------------
+    // Phase 3: accumulate all rank slots → write result in-place
+    // ------------------------------------------------------------------
+    {
+        int global_tid = (int)(blockIdx.x * blockDim.x + threadIdx.x);
+        int stride = (int)(gridDim.x * blockDim.x);
+
+        for (size_t blk = (size_t)global_tid; blk < num_blocks; blk += (size_t)stride)
+        {
+            size_t elem_off = blk * TQ3_BLOCK_SIZE;
+            size_t elems_this_block = min((size_t)TQ3_BLOCK_SIZE, num_elements - elem_off);
+
+            // Accumulate in float for numerical accuracy
+            float acc[TQ3_BLOCK_SIZE];
+            #pragma unroll
+            for (int i = 0; i < TQ3_BLOCK_SIZE; ++i) acc[i] = 0.0f;
+
+            for (int r = 0; r < num_ranks; ++r)
+            {
+                const uint8_t* src = shbuf_ptr + (size_t)r * slot_bytes + blk * block_bytes;
+
+                uint32_t bp0 = *reinterpret_cast<const uint32_t*>(src + 0);
+                uint32_t bp1 = *reinterpret_cast<const uint32_t*>(src + 4);
+                half scale   = *reinterpret_cast<const half*>    (src + 8);
+                float fscale = __half2float(scale);
+
+                #pragma unroll
+                for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+                {
+                    uint32_t mag  = (bp1 >> i) & 1u;  // 1 = non-zero
+                    uint32_t sign = (bp0 >> i) & 1u;  // 1 = negative
+                    if (mag) acc[i] += sign ? -fscale : fscale;
+                }
+            }
+
+            // Write results back to the data tensor
+            #pragma unroll
+            for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+            {
+                if ((size_t)i < elems_this_block)
+                    data_ptr[elem_off + (size_t)i] = __float2half(acc[i]);
+            }
+        }
+    }
+
+    // ------------------------------------------------------------------
+    // Phase 4: second barrier so all ranks have finished reading before
+    // the caller re-uses or frees the shbuf slot
+    // ------------------------------------------------------------------
+    __threadfence_system();
+    grid.sync();
+    pg_barrier_inner(ctx, device_mask, this_device, master_device, abort_flag);
+}
+
+
+// ---------------------------------------------------------------------------
+// Host-side launcher
+// ---------------------------------------------------------------------------
+void tq3_all_reduce
+(
+    const at::Tensor& data,
+    uintptr_t ctx_ptr,
+    std::vector<uintptr_t> devices,
+    int this_device,
+    int master_device,
+    uintptr_t shbuf,
+    size_t shbuf_size,
+    bool contribution
+)
+{
+    TORCH_CHECK(data.scalar_type() == at::kHalf,
+        "tq3_all_reduce: input tensor must be fp16 (torch.float16)");
+    TORCH_CHECK(data.is_contiguous(),
+        "tq3_all_reduce: input tensor must be contiguous");
+    TORCH_CHECK(data.is_cuda(),
+        "tq3_all_reduce: input tensor must be on a CUDA device");
+
+    const at::cuda::OptionalCUDAGuard device_guard(this_device);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
+    pg_check_timeout(ctx_ptr);
+    PGContext* ctx = reinterpret_cast<PGContext*>(ctx_ptr);
+
+    // Build device mask
+    uint32_t device_mask = 0;
+    for (uintptr_t d : devices) device_mask |= (1u << (int)d);
+
+    int num_ranks = __builtin_popcount(device_mask);
+    if (num_ranks <= 1) return;
+
+    size_t num_elements = (size_t) data.numel();
+    size_t num_tq3_blocks = (num_elements + TQ3_BLOCK_SIZE - 1) / TQ3_BLOCK_SIZE;
+    size_t slot_bytes = num_tq3_blocks * 10u;  // 10 bytes per TQ3 block
+    size_t total_needed = slot_bytes * (size_t)num_ranks;
+
+    TORCH_CHECK(total_needed <= shbuf_size,
+        "tq3_all_reduce: tensor too large for shared buffer. "
+        "Need ", total_needed, " bytes, have ", shbuf_size);
+
+    uint8_t* shbuf_ptr = reinterpret_cast<uint8_t*>(shbuf);
+
+    // Thread count: one thread handles one TQ3 block at a time (strided loop).
+    // Cap at TQ3_AR_MAX_THREADS, round up to warp boundary.
+    int threads = (int) min((size_t) TQ3_AR_MAX_THREADS, num_tq3_blocks);
+    if (threads < 1) threads = 1;
+    threads = ((threads + 31) / 32) * 32;
+    if (threads > TQ3_AR_MAX_THREADS) threads = TQ3_AR_MAX_THREADS;
+
+    // Single cooperative block per GPU (grid.sync() requires cooperative launch;
+    // one block is sufficient — threads loop over all TQ3 blocks internally).
+    dim3 grid_dim(1);
+    dim3 block_dim(threads);
+
+    // Per-call abort flag — small temporary device tensor
+    at::Tensor abort_tensor = torch::zeros(
+        {1},
+        at::TensorOptions().dtype(torch::kInt32).device(data.device())
+    );
+    uint32_t* abort_flag_ptr = reinterpret_cast<uint32_t*>(abort_tensor.data_ptr());
+
+    half* data_dev_ptr = reinterpret_cast<half*>(data.data_ptr());
+
+    void* kernelArgs[] =
+    {
+        (void*) &ctx,
+        (void*) &device_mask,
+        (void*) &this_device,
+        (void*) &master_device,
+        (void*) &data_dev_ptr,
+        (void*) &shbuf_ptr,
+        (void*) &num_elements,
+        (void*) &slot_bytes,
+        (void*) &contribution,
+        (void*) &abort_flag_ptr
+    };
+
+    cudaLaunchCooperativeKernel(
+        (void*) tq3_all_reduce_kernel,
+        grid_dim,
+        block_dim,
+        kernelArgs,
+        0,
+        stream
+    );
+
+    cuda_check(cudaPeekAtLastError());
+}
diff --git a/exllamav3/exllamav3_ext/parallel/tq3_all_reduce.cuh b/exllamav3/exllamav3_ext/parallel/tq3_all_reduce.cuh
new file mode 100644
index 00000000..ed81ae64
--- /dev/null
+++ b/exllamav3/exllamav3_ext/parallel/tq3_all_reduce.cuh
@@ -0,0 +1,33 @@
+#pragma once
+
+#include <torch/extension.h>
+
+// TQ3-compressed all-reduce over the native parallel-group shared memory fabric.
+//
+// Uses an all-gather + local-reduce pattern:
+//   1. Each rank TQ3-compresses its fp16 tensor (6.4× smaller) and writes the
+//      result into its dedicated slot inside the pinned shared buffer.
+//   2. A cross-GPU barrier ensures every rank has finished writing.
+//   3. Each rank decompresses all slots and accumulates them locally (in-place).
+
+// Parameters
+//   data         — fp16 tensor on the calling GPU (modified in-place)
+//   ctx_ptr      — uintptr_t of the process-group's pinned PGContext block
+//   devices      — ordered list of participating GPU indices (same as pg_all_reduce)
+//   this_device  — GPU index of the calling process
+//   master_device— coordinator GPU index (lowest rank, used by barrier_inner)
+//   shbuf        — uintptr_t of the pinned shared ring buffer (shm_b)
+//   shbuf_size   — total size of the shared buffer in bytes
+//   contribution — if false this rank is a non-contributing observer: it writes
+//                  all-zeros into its slot but still participates in barriers
+void tq3_all_reduce
+(
+    const at::Tensor& data,
+    uintptr_t ctx_ptr,
+    std::vector<uintptr_t> devices,
+    int this_device,
+    int master_device,
+    uintptr_t shbuf,
+    size_t shbuf_size,
+    bool contribution
+);
diff --git a/exllamav3/exllamav3_ext/parallel/tq3_compress.cuh b/exllamav3/exllamav3_ext/parallel/tq3_compress.cuh
new file mode 100644
index 00000000..70cfa309
--- /dev/null
+++ b/exllamav3/exllamav3_ext/parallel/tq3_compress.cuh
@@ -0,0 +1,124 @@
+#pragma once
+
+#include <cuda_fp16.h>
+#include <cstdint>
+
+// TQ3: 3-level ternary quantization for fp16 tensors.
+//
+// Each block of TQ3_BLOCK_SIZE (32) fp16 values is compressed to 10 bytes:
+//   - 1x fp16 scale (2 bytes) — max absolute value in the block
+//   - 2x uint32_t (8 bytes) — two one-bit planes, one bit per element;
+//     element i occupies bit i (LSB-first) in both planes
+//
+// Encoding: for each element x = v / scale
+//   |x| <  TQ3_BOUNDARY  → ternary  0 (mag bit = 0, sign bit = 0)
+//   x  >= +TQ3_BOUNDARY  → ternary +1 (mag bit = 1, sign bit = 0)
+//   x  <= -TQ3_BOUNDARY  → ternary -1 (mag bit = 1, sign bit = 1)
+//
+// bp0 holds the sign bits      (1 = negative non-zero)
+// bp1 holds the magnitude bits (1 = non-zero)
+//
+// Decompression: dst[i] = scale * (bp1_i ? (bp0_i ? -1.0h : +1.0h) : 0.0h)
+
+#define TQ3_BLOCK_SIZE 32
+#define TQ3_BOUNDARY   0.5f
+
+// ---------------------------------------------------------------------------
+// tq3_compress_block
+//
+// Compress 32 fp16 values from src[] into a (bp0, bp1, scale) triplet.
+// All three output pointers must be writable by the calling thread.
+// This is a pure device function — call once per block of 32 elements.
+// ---------------------------------------------------------------------------
+__device__ __forceinline__ void tq3_compress_block
+(
+    const half* __restrict__ src,
+    uint32_t* __restrict__ bp0,   // sign-bit plane (1 = negative non-zero)
+    uint32_t* __restrict__ bp1,   // magnitude-bit plane (1 = non-zero)
+    half* __restrict__ scale      // max abs value of the block
+)
+{
+    // Pass 1: find max absolute value (in float for precision)
+    float max_abs = 0.0f;
+    #pragma unroll
+    for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+    {
+        float v = __half2float(src[i]);
+        float av = (v < 0.0f) ? -v : v;
+        if (av > max_abs) max_abs = av;
+    }
+
+    // Store scale (fp16); guard against zero denominator
+    *scale = __float2half(max_abs);
+    float inv_scale = (max_abs > 0.0f) ? (1.0f / max_abs) : 0.0f;
+
+    // Pass 2: quantize → pack into two uint32 bit planes (bit 0 = element 0)
+    uint32_t b0 = 0u;
+    uint32_t b1 = 0u;
+    #pragma unroll
+    for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+    {
+        float v = __half2float(src[i]) * inv_scale;
+        float av = (v < 0.0f) ? -v : v;
+        if (av >= TQ3_BOUNDARY)
+        {
+            b1 |= (1u << i);                 // non-zero
+            if (v < 0.0f) b0 |= (1u << i);   // negative
+        }
+    }
+
+    *bp0 = b0;
+    *bp1 = b1;
+}
+
+// ---------------------------------------------------------------------------
+// tq3_decompress_block
+//
+// Reconstruct 32 fp16 values from (bp0, bp1, scale) into dst[].
+// ---------------------------------------------------------------------------
+__device__ __forceinline__ void tq3_decompress_block
+(
+    uint32_t bp0,
+    uint32_t bp1,
+    half scale,
+    half* __restrict__ dst
+)
+{
+    float fscale = __half2float(scale);
+    #pragma unroll
+    for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+    {
+        // bp1 bit i = non-zero magnitude, bp0 bit i = negative sign
+        uint32_t mag = (bp1 >> i) & 1u;
+        uint32_t sign = (bp0 >> i) & 1u;
+        float v = 0.0f;
+        if (mag) v = sign ? -fscale : fscale;
+        dst[i] = __float2half(v);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// tq3_decompress_add_block
+//
+// Fused decompress + accumulate: dst[i] += decompressed[i]
+// ---------------------------------------------------------------------------
+__device__ __forceinline__ void tq3_decompress_add_block
+(
+    uint32_t bp0,
+    uint32_t bp1,
+    half scale,
+    half* __restrict__ dst
+)
+{
+    float fscale = __half2float(scale);
+    #pragma unroll
+    for (int i = 0; i < TQ3_BLOCK_SIZE; ++i)
+    {
+        uint32_t mag = (bp1 >> i) & 1u;
+        uint32_t sign = (bp0 >> i) & 1u;
+        if (mag)
+        {
+            float contrib = sign ? -fscale : fscale;
+            // Accumulates via an fp16 read-modify-write: dst rounds back to
+            // half after every add, so rounding error grows with the number
+            // of accumulated contributions.
+            dst[i] = __float2half(__half2float(dst[i]) + contrib);
+        }
+    }
+}
diff --git a/exllamav3/model/model_tp_backend.py b/exllamav3/model/model_tp_backend.py
index adbfecfc..c1771892 100644
--- a/exllamav3/model/model_tp_backend.py
+++ b/exllamav3/model/model_tp_backend.py
@@ -36,6 +36,7 @@ def __init__(
         master: bool,
         uuid: str,
         shbuf_size: int = SHBUF_SIZE,
+        tq3_compress: bool = False,
     ):
         self.device = device
         if device < 0:
diff --git a/tests/bench_tq3_allreduce.py b/tests/bench_tq3_allreduce.py
new file mode 100644
index 00000000..4afb718e
--- /dev/null
+++ b/tests/bench_tq3_allreduce.py
@@ -0,0 +1,144 @@
+"""
+Benchmark TQ3 compressed all-reduce overhead.
+
+Break-even: TQ3 wins when compress+decompress time < bandwidth savings.
+    PCIe Gen4 x16 (32 GB/s):    break-even at ~26 ns/byte overhead
+    InfiniBand HDR (25 GB/s):   break-even at ~34 ns/byte overhead
+    Ethernet 100G (12.5 GB/s):  break-even at ~67 ns/byte overhead
+"""
+import torch
+import time
+import sys
+import os
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def tq3_compress_py(x):
+    """Pure PyTorch TQ3 compression fallback.
+
+    Matches the CUDA bit layout in tq3_compress.cuh:
+    bp0 = sign bits (1 = negative non-zero), bp1 = magnitude bits (1 = non-zero),
+    element i at bit i of each plane.
+    """
+    blocks = x.float().view(-1, 32)
+    scales = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-10)
+    normalized = blocks / scales
+    nonzero = (normalized.abs() >= 0.5).int()
+    negative = ((normalized < 0) & (nonzero == 1)).int()
+    bit_idx = torch.arange(32, device=x.device).unsqueeze(0)
+    bp0 = (negative << bit_idx).sum(dim=1).to(torch.int32)
+    bp1 = (nonzero << bit_idx).sum(dim=1).to(torch.int32)
+    return bp0, bp1, scales.squeeze(1).half()
+
+
+def tq3_decompress_py(bp0, bp1, scales, numel):
+    """Pure PyTorch TQ3 decompression fallback (bp0 = sign, bp1 = magnitude)."""
+    bit_idx = torch.arange(32, device=bp0.device).unsqueeze(0)
+    sign = ((bp0.unsqueeze(1) >> bit_idx) & 1).float()
+    mag = ((bp1.unsqueeze(1) >> bit_idx) & 1).float()
+    ternary = mag * (1.0 - 2.0 * sign)
+    result = (ternary * scales.float().unsqueeze(1)).reshape(-1)
+    return result[:numel].half()
+
+
+def bench_tq3_compress_decompress():
+    """Time TQ3 compress/decompress and report the break-even interconnect BW."""
+    if not torch.cuda.is_available():
+        print("CUDA not available, skipping benchmark")
+        return
+
+    print("=" * 70)
+    print("TQ3 Compressed All-Reduce Bandwidth Analysis")
+    print("=" * 70)
+
+    # Attempt to load the CUDA extension once and report which path is active.
+    use_ext = False
+    try:
+        from exllamav3.ext import exllamav3_ext as ext
+        # Probe for the expected symbols so we fail fast rather than at loop time.
+        _ = ext.quant_tq3_cache_cont
+        _ = ext.dequant_tq3_cache_cont
+        use_ext = True
+        print("Backend: CUDA extension (exllamav3_ext)")
+    except (ImportError, AttributeError):
+        print("Backend: pure PyTorch fallback")
+
+    sizes = [4096, 8192, 16384, 32768, 65536, 131072]
+    warmup = 20
+    iters = 100
+
+    print(f"\n{'Size':>8} | {'Compress':>10} | {'Decompress':>10} | {'Ratio':>6} | {'Break-even BW':>14}")
+    print("-" * 70)
+
+    for size in sizes:
+        x = torch.randn(size, dtype=torch.float16, device='cuda')
+        num_blocks = size // 32
+
+        if use_ext:
+            from exllamav3.ext import exllamav3_ext as ext
+            packed = torch.empty(num_blocks * 2, dtype=torch.int32, device='cuda')
+            scales = torch.empty(num_blocks, dtype=torch.float16, device='cuda')
+            output = torch.empty_like(x)
+
+            # Warmup
+            for _ in range(warmup):
+                ext.quant_tq3_cache_cont(x, packed, scales)
+                ext.dequant_tq3_cache_cont(packed, scales, output)
+            torch.cuda.synchronize()
+
+            # Compress
+            t0 = time.perf_counter()
+            for _ in range(iters):
+                ext.quant_tq3_cache_cont(x, packed, scales)
+            torch.cuda.synchronize()
+            compress_ms = (time.perf_counter() - t0) / iters * 1000
+
+            # Decompress
+            t0 = time.perf_counter()
+            for _ in range(iters):
+                ext.dequant_tq3_cache_cont(packed, scales, output)
+            torch.cuda.synchronize()
+            decompress_ms = (time.perf_counter() - t0) / iters * 1000
+
+        else:
+            # PyTorch fallback — benchmark the pure-Python path.
+            # Warmup
+            for _ in range(warmup):
+                bp0, bp1, sc = tq3_compress_py(x)
+                _ = tq3_decompress_py(bp0, bp1, sc, size)
+            torch.cuda.synchronize()
+
+            # Compress
+            t0 = time.perf_counter()
+            for _ in range(iters):
+                bp0, bp1, sc = tq3_compress_py(x)
+            torch.cuda.synchronize()
+            compress_ms = (time.perf_counter() - t0) / iters * 1000
+
+            # Decompress (reuse last compressed result)
+            t0 = time.perf_counter()
+            for _ in range(iters):
+                _ = tq3_decompress_py(bp0, bp1, sc, size)
+            torch.cuda.synchronize()
+            decompress_ms = (time.perf_counter() - t0) / iters * 1000
+
+        fp16_bytes = size * 2
+        tq3_bytes = num_blocks * 10  # 2 x int32 (4B each) + 1 x fp16 (2B) per block
+        ratio = fp16_bytes / tq3_bytes
+
+        total_ms = compress_ms + decompress_ms
+        saved_bytes = fp16_bytes - tq3_bytes
+        if total_ms > 0:
+            breakeven_bw_gbs = saved_bytes / (total_ms / 1000) / 1e9
+        else:
+            breakeven_bw_gbs = float('inf')
+
+        print(
+            f"{size:>8} | {compress_ms:>8.3f}ms | {decompress_ms:>8.3f}ms"
+            f" | {ratio:>5.1f}x | {breakeven_bw_gbs:>10.1f} GB/s"
+        )
+
+    print(f"\nInterpretation:")
+    print(f"  TQ3 wins when your interconnect is SLOWER than the break-even BW.")
+    print(f"  PCIe Gen4 x16:  32 GB/s  -> TQ3 wins if break-even > 32 GB/s")
+    print(f"  InfiniBand HDR: 25 GB/s  -> TQ3 wins if break-even > 25 GB/s")
+    print(f"  NVLink (A100):  600 GB/s -> TQ3 unlikely to win")
+
+
+if __name__ == "__main__":
+    bench_tq3_compress_decompress()
diff --git a/tests/test_tq3_allreduce.py b/tests/test_tq3_allreduce.py
new file mode 100644
index 00000000..32a044c6
--- /dev/null
+++ b/tests/test_tq3_allreduce.py
@@ -0,0 +1,132 @@
+"""
+Tests for TQ3 compressed all-reduce.
+
+Since actual multi-GPU all-reduce requires multiple processes,
+these tests simulate the compression quality and verify that
+TQ3 compress->sum->decompress produces acceptable results.
+""" +import torch +import pytest +import math +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def sqnr(original, reconstructed): + sig = (original.float() ** 2).mean() + noise = ((original.float() - reconstructed.float()) ** 2).mean() + if noise < 1e-20: + return float('inf') + return 10 * math.log10(sig / noise) + + +def tq3_compress_py(x): + """Pure PyTorch TQ3 compression (no CUDA ext needed).""" + assert x.numel() % 32 == 0 + blocks = x.float().view(-1, 32) + scales = blocks.abs().max(dim=1, keepdim=True).values.clamp(min=1e-10) + normalized = blocks / scales + nonzero = (normalized.abs() >= 0.5).int() + positive = ((normalized > 0) & (nonzero == 1)).int() + bit_idx = torch.arange(32, device=x.device).unsqueeze(0) + bp0 = (nonzero << bit_idx).sum(dim=1).to(torch.int32) + bp1 = (positive << bit_idx).sum(dim=1).to(torch.int32) + return bp0, bp1, scales.squeeze(1).half() + + +def tq3_decompress_py(bp0, bp1, scales, numel): + """Pure PyTorch TQ3 decompression.""" + bit_idx = torch.arange(32, device=bp0.device).unsqueeze(0) + nz = ((bp0.unsqueeze(1) >> bit_idx) & 1).float() + pos = ((bp1.unsqueeze(1) >> bit_idx) & 1).float() + ternary = nz * (2.0 * pos - 1.0) + result = (ternary * scales.float().unsqueeze(1)).reshape(-1) + return result[:numel].half() + + +class TestTQ3CompressedAllReduceSimulation: + + def test_single_rank_roundtrip(self): + torch.manual_seed(42) + x = torch.randn(4096, dtype=torch.float16, device='cuda' if torch.cuda.is_available() else 'cpu') + bp0, bp1, scales = tq3_compress_py(x) + recovered = tq3_decompress_py(bp0, bp1, scales, x.numel()) + ratio = sqnr(x, recovered) + assert ratio >= 6.0, f"SQNR {ratio:.2f} dB < 6 dB" + + def test_simulated_4rank_allreduce(self): + torch.manual_seed(42) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + num_ranks = 4 + tensors = [torch.randn(4096, dtype=torch.float16, device=dev) for _ in range(num_ranks)] + exact_sum = sum(t.float() for t 
in tensors).half() + + compressed_sum = torch.zeros(4096, dtype=torch.float32, device=dev) + for t in tensors: + bp0, bp1, scales = tq3_compress_py(t) + decompressed = tq3_decompress_py(bp0, bp1, scales, t.numel()) + compressed_sum += decompressed.float() + compressed_sum = compressed_sum.half() + + ratio = sqnr(exact_sum, compressed_sum) + assert ratio >= 4.0, f"4-rank compressed SQNR {ratio:.2f} dB" + print(f"4-rank compressed all-reduce SQNR: {ratio:.2f} dB") + + def test_simulated_8rank_allreduce(self): + torch.manual_seed(42) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + num_ranks = 8 + tensors = [torch.randn(8192, dtype=torch.float16, device=dev) for _ in range(num_ranks)] + exact_sum = sum(t.float() for t in tensors).half() + + compressed_sum = torch.zeros(8192, dtype=torch.float32, device=dev) + for t in tensors: + bp0, bp1, scales = tq3_compress_py(t) + decompressed = tq3_decompress_py(bp0, bp1, scales, t.numel()) + compressed_sum += decompressed.float() + compressed_sum = compressed_sum.half() + + ratio = sqnr(exact_sum, compressed_sum) + assert ratio >= 3.0, f"8-rank compressed SQNR {ratio:.2f} dB" + print(f"8-rank compressed all-reduce SQNR: {ratio:.2f} dB") + + def test_bandwidth_ratio(self): + """Verify TQ3 achieves expected compression ratio.""" + numel = 8192 + fp16_bytes = numel * 2 + num_blocks = numel // 32 + tq3_bytes = num_blocks * (4 + 4 + 2) # 2 uint32 + 1 fp16 per block + ratio = fp16_bytes / tq3_bytes + assert ratio >= 6.0, f"Compression ratio {ratio:.2f}x < 6x" + print(f"TQ3 compression ratio: {ratio:.2f}x ({fp16_bytes} -> {tq3_bytes} bytes)") + + def test_zeros(self): + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + x = torch.zeros(1024, dtype=torch.float16, device=dev) + bp0, bp1, scales = tq3_compress_py(x) + recovered = tq3_decompress_py(bp0, bp1, scales, x.numel()) + assert torch.allclose(recovered.float(), x.float(), atol=1e-3) + + def test_sign_preservation(self): + dev = 'cuda' if torch.cuda.is_available() else 
'cpu' + x = torch.tensor([1.0, -1.0, 0.5, -0.5] * 8, dtype=torch.float16, device=dev) + bp0, bp1, scales = tq3_compress_py(x) + recovered = tq3_decompress_py(bp0, bp1, scales, x.numel()) + # Signs should be preserved for large values + for i in [0, 1]: + assert (recovered[i] > 0) == (x[i] > 0), f"Sign mismatch at index {i}" + + def test_large_tensor(self): + torch.manual_seed(42) + dev = 'cuda' if torch.cuda.is_available() else 'cpu' + x = torch.randn(131072, dtype=torch.float16, device=dev) # 128K values + bp0, bp1, scales = tq3_compress_py(x) + recovered = tq3_decompress_py(bp0, bp1, scales, x.numel()) + ratio = sqnr(x, recovered) + assert ratio >= 5.0, f"128K SQNR {ratio:.2f} dB < 5 dB" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"])