
Commit 2d025dc

jiawenliu64 authored and facebook-github-bot committed
Back out "Enable fast FP8 GEMM for memory bound" (#3588)
Summary:
Pull Request resolved: #3588
X-link: facebookresearch/FBGEMM#672
Original commit changeset: fbf34e283e94
Original Phabricator Diff: D68193920
Reviewed By: catalinii
Differential Revision: D68351266
fbshipit-source-id: ba05114f2f44264232306f522dfbb56b507b1873
1 parent 379db5f · commit 2d025dc

File tree

4 files changed: +0, -333 lines

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py (-34 lines)

@@ -716,40 +716,6 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
-class FP8LiteGemm(QuantizeOpBase):
-    """
-    FP8 lite matmul for memory bound.
-    """
-
-    def quantize(self, x, w):
-        # Quantize both input tensors.
-        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-        return xq, wq, x_scale, w_scale
-
-    def compute(self, xq, wq, x_scale, w_scale):
-        return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale * w_scale)
-
-    @property
-    def name(self) -> str:
-        return "cuda_lite"
-
-    @property
-    def hip(self) -> bool:
-        # Need to add support for better quantize kernel.
-        # Also may have an issue with cuda graphs.
-        return False
-
-    @property
-    def cuda(self) -> bool:
-        return True
-
-
 @register_quantize_op
 class TritonFP8RowwiseGemm(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_lite.cu (-263 lines)

This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp (-12 lines)

@@ -55,7 +55,6 @@ at::Tensor f8f8bf16_tensorwise(
     at::Tensor WQ,
     double scale,
     bool use_fast_accum = true);
-at::Tensor f8f8bf16_lite(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
 std::vector<at::Tensor> f8f8bf16_grouped(
     at::TensorList XQ,
     at::TensorList WQ,
@@ -188,7 +187,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
       "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
-  m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
   m.def(
       "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
@@ -270,7 +268,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
-  m.impl("f8f8bf16_lite", f8f8bf16_lite);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -298,7 +295,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
-  m.impl("f8f8bf16_lite", f8f8bf16_lite);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -419,13 +415,6 @@ at::Tensor f8f8bf16_tensorwise_meta(
   return Y;
 }
 
-at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
-  const at::SymInt M = X.sym_size(0);
-  const at::SymInt N = W.sym_size(0);
-  auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
-}
-
 at::Tensor f8i4bf16_rowwise_meta(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // INT4
@@ -544,7 +533,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
-  m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
 #endif
 }
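
Besides the schema, CUDA, and CPU bindings, the backout also removes the Meta-backend entry and the shape-only f8f8bf16_lite_meta function, which is what let the op participate in tracing and torch.compile via FakeTensor. For readers unfamiliar with that mechanism, the same shape rule (an [M, N] bf16 output computed without touching real data) can be expressed with the Python torch.library API. The sketch below uses a hypothetical op "mylib::my_gemm" that is not part of FBGEMM, and assumes PyTorch 2.4+ for register_fake.

import torch
from torch.library import Library

# Hypothetical op, used only to illustrate the meta/"fake" registration pattern
# that the deleted f8f8bf16_lite_meta implemented in C++.
_lib = Library("mylib", "DEF")
_lib.define("my_gemm(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor")

@torch.library.register_fake("mylib::my_gemm")
def _my_gemm_fake(xq, wq, scale):
    # Same shape rule as f8f8bf16_lite_meta: output is [M, N] in bfloat16.
    M = xq.size(0)
    N = wq.size(0)
    return xq.new_empty((M, N), dtype=torch.bfloat16)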

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py (-24 lines)

@@ -1110,30 +1110,6 @@ def test_quantize_zero_input(self, K) -> None:
         torch.testing.assert_close(w.shape, wq.shape)
         torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)
 
-    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is yet suported.")
-    @settings(deadline=None)
-    @given(
-        M=st.sampled_from([1, 5, 16]),
-        N=st.sampled_from([1024, 6144]),
-        K=st.sampled_from([512, 3584]),
-        CudaGraph=st.sampled_from([True, False]),
-    )
-    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
-        x = torch.randn(size=(M, K), dtype=torch.bfloat16, device="cuda") * 0.1
-        w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01
-        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-        if CudaGraph:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g.replay()
-        else:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-        zq_ref = (x @ w.T).to(torch.bfloat16)
-        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)
-
 
 
 if __name__ == "__main__":
     unittest.main()
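
The deleted test was also the only place in this file that exercised the op under CUDA graph capture. The warm-up / capture / replay structure itself is independent of f8f8bf16_lite; a minimal sketch of the same pattern with a plain bf16 matmul (so it runs without the backed-out op, assuming a CUDA device is available) looks like this.

import torch

# Same warm-up / capture / replay structure as the deleted test_fp8_lite_matmul,
# applied to an ordinary bf16 matmul instead of the backed-out op.
def capture_and_replay(M: int = 16, N: int = 1024, K: int = 512) -> torch.Tensor:
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
    z = x @ w.T  # warm-up call outside capture
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        z = x @ w.T  # captured; z now refers to graph-owned output storage
    g.replay()  # re-executes the captured kernels
    return z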
