
Commit 4593359

Add dequantize flow so that we can round-trip from fp16 weights to fp16 dequantized weights for comparison.

* WIP: Use a small test size to narrow down a numerics issue beyond the first vector store.
1 parent 4670354 commit 4593359
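For context on the commit message above, the round trip being added (group-wise fp16 -> e4m3_float8 quantization followed by dequantization back to fp16) can be sketched in plain NumPy. This is only an illustrative sketch, not code from the commit: it reuses the test's group_size=32, grouping along axis 1, and e4m3 max value of 448, but it stands in for the actual fp8 cast with a clip, so it demonstrates the scaling math rather than bit-exact e4m3 rounding or the uint32 packing the TVM kernels perform.

import numpy as np

group_size = 32
max_int_value = 448  # largest finite e4m3_float8 value, as used in the test

# Random fp16 weight, grouped along axis 1 (one scale per 32 contiguous elements).
weight = np.random.uniform(-100, 100, (64, 128)).astype("float16")
grouped = weight.astype("float32").reshape(weight.shape[0], -1, group_size)

# Per-group scale so that the largest magnitude maps to the e4m3 max.
scale = np.abs(grouped).max(axis=-1) / max_int_value  # shape (64, 128 // group_size)

# "Quantize": divide by the group scale; the real kernel casts this to e4m3_float8x4
# and packs 4 values per uint32, here we only clip to the representable range.
q = np.clip(grouped / scale[..., None], -max_int_value, max_int_value)

# Dequantize: multiply back by the group scale and restore the original layout.
w_hat = (q * scale[..., None]).reshape(weight.shape).astype("float16")

# Same tolerances as the new test's end-to-end comparison.
np.testing.assert_allclose(weight, w_hat, atol=10, rtol=5e-2)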

File tree

1 file changed: +120 -103 lines


tests/python/tir-base/test_native_fp8.py

@@ -76,6 +76,46 @@ def create_quantize_func(
     return bb.finalize()


+def create_dequantize_func(
+    packed_weight_shape,
+    scale_shape,
+    dequantized_shape,
+    model_dtype,
+    quantize_dtype,
+    storage_dtype,
+    group_size,
+    num_elem_per_storage,
+    axis,
+) -> IRModule:
+    if DataType(quantize_dtype).type_code == DataTypeCode.E4M3Float:
+        dequantize_func = dequantize_fp8x4_e4m3
+    else:
+        raise NotImplementedError()
+
+    bb = relax.BlockBuilder()  # pylint: disable=invalid-name
+    packed_weight_var = relax.Var(
+        "weight", relax.TensorStructInfo(packed_weight_shape, storage_dtype)
+    )
+    scale_var = relax.Var("scale", relax.TensorStructInfo(scale_shape, model_dtype))
+    compute_dequantize = dequantize_func(
+        packed_weight_shape,
+        scale_shape,
+        dequantized_shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_elem_per_storage,
+        axis,
+    )
+    with bb.function(name="main", params=[packed_weight_var, scale_var]):
+        with bb.dataflow():
+            lv = compute_dequantize(bb, (packed_weight_var, scale_var))
+            gv = bb.emit_output(lv)
+        bb.emit_func_output(gv)
+    return bb.finalize()
+
+
 def quantize_fp8x4_e4m3(  # pylint: disable=too-many-locals
     weight_shape: List[tir.PrimExpr],
     model_dtype,
@@ -135,9 +175,6 @@ def compute_quantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
             quantize_dtype,
         )
         # quant.show()
-        # import ipdb
-
-        # ipdb.set_trace()

         global_var = bb.add_func(quant, "quantized_weight")
         lv_quantized_weight = bb.emit(
@@ -161,6 +198,41 @@ def compute_transpose(quantized_weight: te.Tensor, scale: te.Tensor):
     return compute_scale, compute_quantize_weight, compute_transpose


+def dequantize_fp8x4_e4m3(  # pylint: disable=too-many-locals
+    packed_weight_shape: List[tir.PrimExpr],
+    scale_shape,
+    dequant_shape,
+    model_dtype,
+    quantize_dtype,
+    storage_dtype,
+    group_size,
+    num_elem_per_storage,
+    axis: int = -1,
+) -> Tuple[te.Tensor, te.Tensor]:
+    """Group dequantization for weight tensor, defined in tensor expression."""
+    axis = axis if axis >= 0 else len(dequant_shape) + axis
+
+    def compute_dequantize_weight(bb: relax.BlockBuilder, args: relax.expr.Expr):
+        dequant = dequant_fp8x4_e4m3_sm90(
+            packed_weight_shape,
+            scale_shape,
+            dequant_shape,
+            group_size,
+            axis,
+            model_dtype,
+            storage_dtype,
+            quantize_dtype,
+        )
+
+        global_var = bb.add_func(dequant, "dequantize_weight")
+        lv_dequantized_weight = bb.emit(
+            relax.call_tir(global_var, args, relax.TensorStructInfo(dequant_shape, model_dtype))
+        )
+        return lv_dequantized_weight
+
+    return compute_dequantize_weight
+
+
 def quant_and_pack_fp8x4_e4m3_sm90(
     weight_shape,
     packed_shape,
@@ -175,6 +247,7 @@ def quant_and_pack_fp8x4_e4m3_sm90(
     vec_quantized_dtype = f"{quantized_dtype}x{vector_length}"
     vec_model_dtype = f"{model_dtype}x{vector_length}"
     num_elem_per_storage = vector_length
+    # TODO(csullivan) assert on storage dtype / quantize type bytes == vector length
     assert (
         group_size % vector_length == 0
     ), f"Number of elements in a group must be divisible by fp8 vector length {vector_length}"
@@ -202,14 +275,11 @@ def quant_pack(
                     storage_dtype,
                     T.Cast(
                         vec_quantized_dtype,
-                        # Note: Using the colon here is a sugared way of writing T.ramp(v_i1, 1, vector_length)
-                        # ie a vector load of A
-                        A[v_i0, v_i1 : v_i1 + vector_length]
+                        A[v_i0, T.ramp(v_i1 * vector_length, 1, vector_length)]
                         / scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)],
                     ),
                 )

-    quant_pack.show()
     return quant_pack


@@ -251,7 +321,6 @@ def dequant(
                     scale[v_i0, v_i1 * T.int64(vector_length) // T.int64(group_size)], vector_length
                 )

-    dequant.show()
     return dequant


@@ -543,100 +612,21 @@ def add(
     tvm.testing.assert_allclose(c.numpy(), c_expected, atol=1e-5, rtol=1e-5)


-@tvm.testing.requires_cuda_compute_version(8)
-def test_weight_scale():
-    weight_shape = [32000, 4096]
-    group_size = 32
-    axis = 1
-    scale_shape = [d // group_size if axis == i else d for i, d in enumerate(weight_shape)]
-    model_dtype = "float16"
-    storage_dtype = "uint32"
-    quantized_dtype = "e4m3_float8"
-
-    # q_weight = fp8(weight_f16 / scale_f16)
-    # q_weight = fp8x4(weight_f16x4 / scale_f16x4)
-    vector_length = 4
-    vec_quantized_dtype = "e4m3_float8x4"
-    vec_model_dtype = "float16x4"
-    num_el_per_storage = 4
-
-    @T.prim_func
-    def vectorized(
-        A: T.Buffer(weight_shape, model_dtype),
-        scale: T.Buffer(scale_shape, model_dtype),
-        compute: T.Buffer(
-            (T.int64(weight_shape[0]), T.int64(weight_shape[1] // num_el_per_storage)),
-            storage_dtype,
-        ),
-    ):
-        T.func_attr({"tir.noalias": T.bool(True)})
-        # with T.block("root"):
-        # test = T.alloc_buffer(1, dtype=vec_model_dtype, scope="local")
-        for i0, i1 in T.grid(T.int64(weight_shape[0]), T.int64(weight_shape[1])):
-            with T.block("compute"):
-                v_i0 = T.axis.spatial(T.int64(weight_shape[0]), i0)
-                v_i1 = T.axis.spatial(T.int64(weight_shape[1] // vector_length), i1)
-                T.reads(
-                    A[v_i0, v_i1 : v_i1 + vector_length], scale[v_i0, v_i1 // T.int64(group_size)]
-                )
-                T.writes(compute[v_i0, v_i1 * vector_length])
-                compute[v_i0, v_i1 * vector_length] = T.reinterpret(
-                    storage_dtype,
-                    T.Cast(
-                        vec_quantized_dtype,
-                        # Note: Using the colon here is a sugared way of writing T.ramp(v_i1, 1, vector_length)
-                        # ie a vector load of A
-                        A[v_i0, v_i1 : v_i1 + vector_length]
-                        / scale[v_i0, v_i1 // T.int64(group_size)],
-                    ),
-                )
-
-    sch = tvm.tir.Schedule(vectorized)
-    block = sch.get_block("compute")
-    loops = sch.get_loops(block)
-    txo, txi = sch.split(loops[0], factors=[None, 256])
-    sch.bind(loops[1], "blockIdx.x")
-    sch.bind(txi, "threadIdx.x")
-    sch.mod.show()
-
-    # sch = tvm.tir.Schedule(main)
-    # block = sch.get_block("compute")
-    # loops = sch.get_loops(block)
-    # bx, tx, lanes = sch.split(loops[-1], factors=[None, 32, 4])
-    # w_l = sch.cache_read(block, 0, storage_scope="local")
-    # # s_l = sch.cache_read(block, 1, storage_scope="local")
-    # sch.compute_at(block=w_l, loop=tx)
-    # # sch.compute_at(block=s_l, loop=tx)
-    # sch.bind(bx, "blockIdx.x")
-    # sch.bind(tx, "threadIdx.x")
-    # # sch.vectorize(lanes)
-    # sch.mod.show()
-
-    import ipdb
-
-    ipdb.set_trace()
-    target = "cuda"
-    f = tvm.build(sch.mod, target=target)
-    print(f.imported_modules[0].get_source())
-
-
-weight_shape = tvm.testing.parameter([32000, 4096], [4096, 14336])
+weight_shape = tvm.testing.parameter((32000, 4096), (4096, 14336))


 @tvm.testing.requires_cuda_compute_version(8)
-def test_fp8_e4_quant_weight(weight_shape):
+def test_fp8e4x4_quant_dequant_weight(weight_shape):
     group_size = 32
     axis = 1
     scale_shape = [d // group_size if axis == i else d for i, d in enumerate(weight_shape)]
     model_dtype = "float16"
     storage_dtype = "uint32"
     quantize_dtype = "e4m3_float8"
     num_el_per_storage = 4
+    max_int_value = 448

-    # TODO(csullivan): check this
-    max_int_value = 448 if "e4m3" in quantize_dtype else 57344
-
-    mod = create_quantize_func(
+    quant_mod = create_quantize_func(
         weight_shape,
         model_dtype,
         quantize_dtype,
@@ -647,33 +637,60 @@ def test_fp8_e4_quant_weight(weight_shape):
         axis,
         output_transpose=False,
     )
+    # quant_mod.show()

     target_str = "cuda"
     target = tvm.target.Target(target_str)
     dev = tvm.device(target_str, 0)
     with target:
-        mod = dl.ApplyDefaultSchedule(
+        quant_mod = dl.ApplyDefaultSchedule(
             dl.gpu.Reduction(),
             dl.gpu.GeneralReduction(),
             dl.gpu.Fallback(),
-        )(mod)
-
-    mod.show()
-
-    f = tvm.build(mod["compute_scale"], target=target)
-    cuda_src = f.imported_modules[0].get_source()
-    print(cuda_src)
+        )(quant_mod)
+    ex = relax.build(quant_mod, target=target)
+    vm = relax.VirtualMachine(ex, dev)

-    ex = relax.build(mod, target=target)
-
-    vm = relax.VirtualMachine(ex, dev)  # pylint: disable=invalid-name
+    def print_cuda(target, mod, name=None):
+        if name:
+            mod = mod[name]
+        f = tvm.build(mod, target=target)
+        cuda_src = f.imported_modules[0].get_source()
+        print(cuda_src)

     weight_np = np.random.uniform(-100, 100, weight_shape).astype(model_dtype)
     weight = tvm.nd.array(weight_np, device=dev)
     quant_weight, scales = vm["main"](weight)
     quant_weight_np, scales_np = quant_weight.numpy(), scales.numpy()

-    print(quant_weight_np, scales_np)
+    dequant_mod = create_dequantize_func(
+        quant_weight.shape,
+        scales.shape,
+        weight.shape,
+        model_dtype,
+        quantize_dtype,
+        storage_dtype,
+        group_size,
+        num_el_per_storage,
+        axis,
+    )
+    # dequant_mod.show()
+
+    with target:
+        dequant_mod = dl.ApplyDefaultSchedule(
+            dl.gpu.Reduction(),
+            dl.gpu.GeneralReduction(),
+            dl.gpu.Fallback(),
+        )(dequant_mod)
+    dequant_mod.show()
+
+    print_cuda(target, dequant_mod, name="dequant")
+
+    ex = relax.build(dequant_mod, target=target)
+    vm = relax.VirtualMachine(ex, dev)
+    dequant_weight = vm["main"](quant_weight, scales)
+    dequant_weight_np = dequant_weight.numpy()
+    tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2)


 if __name__ == "__main__":
