Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Linear Algebra Ops
:nosignatures:

dot
dot_scaled


Memory/Pointer Ops
Expand Down
9 changes: 5 additions & 4 deletions include/triton/Dialect/Triton/IR/TritonAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,16 @@ def TT_InputPrecisionAttr : I32EnumAttr<
let cppNamespace = "::mlir::triton";
}

// Type for F8F6F4 kind of floats.
def TT_F8F6F4TypeAttr : I32EnumAttr<
"F8F6F4Type", "",
// Element types supported for scaled dot operands.
def TT_ScaleDotElemTypeAttr : I32EnumAttr<
"ScaleDotElemType", "",
[
I32EnumAttrCase<"E4M3", 0, "e4m3">,
I32EnumAttrCase<"E5M2", 1, "e5m2">,
I32EnumAttrCase<"E2M3", 2, "e2m3">,
I32EnumAttrCase<"E3M2", 3, "e3m2">,
I32EnumAttrCase<"E2M1", 4, "e2m1">
I32EnumAttrCase<"E2M1", 4, "e2m1">,
I32EnumAttrCase<"BF16", 5, "bf16">

]>{
let cppNamespace = "::mlir::triton";
Expand Down
16 changes: 8 additions & 8 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -685,15 +685,15 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,

let arguments = (
ins
// inputs are integer types as they are packed types and we currently
// don't have a representation for those.
TT_IntTensor:$lhs,
TT_IntTensor:$rhs,
// inputs are floats if we have a type for them, otherwise (fp4),
// they are packed in pairs in an I8Tensor
RankedTensorOf<[TT_Float,I8]>:$lhs,
RankedTensorOf<[TT_Float,I8]>:$rhs,
TT_FloatTensor:$c,
TT_IntTensor:$lhs_scale,
Optional<TT_IntTensor>:$rhs_scale,
TT_F8F6F4TypeAttr:$lhs_type,
TT_F8F6F4TypeAttr:$rhs_type
RankedTensorOf<[I8]>:$lhs_scale,
Optional<RankedTensorOf<[I8]>>:$rhs_scale,
TT_ScaleDotElemTypeAttr:$lhs_type,
TT_ScaleDotElemTypeAttr:$rhs_type
);

let results = (outs TT_FloatTensor:$d);
Expand Down
2 changes: 1 addition & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure, DeclareOpInterfaceMethods<In
let arguments = (ins
TT_Tensor:$src,
TT_Tensor:$scale,
TT_F8F6F4TypeAttr:$fp_type);
TT_ScaleDotElemTypeAttr:$fp_type);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ LogicalResult UpcastMXFPOp::verify() {
"operands must have the same number of dimensions, at least 2");
}

if (!(fpType == F8F6F4Type::E2M1 || fpType == F8F6F4Type::E4M3 ||
fpType == F8F6F4Type::E5M2)) {
if (!(fpType == ScaleDotElemType::E2M1 || fpType == ScaleDotElemType::E4M3 ||
fpType == ScaleDotElemType::E5M2)) {
return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
}

// Change to support fp8 types
const auto elems_packed = fpType == F8F6F4Type::E2M1 ? 2 : 1;
const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 2 : 1;

if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
return emitOpError("last dimension of first operand must be 16 times "
Expand Down Expand Up @@ -93,7 +93,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
return emitOptionalError(loc, "expected a dotOperand encoding");
}

if (typeEncoded == F8F6F4Type::E2M1) {
if (typeEncoded == ScaleDotElemType::E2M1) {
auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
Expand Down
61 changes: 26 additions & 35 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -415,22 +415,12 @@ class ScaledBlockedToMMAv2
auto aType = dotOp.getLhsType();
auto bType = dotOp.getRhsType();

auto enumToType = [&rewriter](F8F6F4Type type) {
switch (type) {
case F8F6F4Type::E4M3:
return rewriter.getFloat8E4M3FNType();
case F8F6F4Type::E5M2:
return rewriter.getFloat8E5M2Type();
default:
llvm_unreachable("unexpected type");
}
};

assert((aType == F8F6F4Type::E4M3 || aType == F8F6F4Type::E5M2 ||
aType == F8F6F4Type::E2M1) &&
assert((aType == ScaleDotElemType::E4M3 ||
aType == ScaleDotElemType::E5M2 ||
aType == ScaleDotElemType::E2M1) &&
"NYI: lhs supports fp4 or fp8");
assert(bType == F8F6F4Type::E4M3 ||
bType == F8F6F4Type::E5M2 && "NYI: rhs supports fp8");
assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 ||
bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16");

// TODO run accelerate matmul on A and B first to choose their layouts
// Set return type
Expand All @@ -454,11 +444,12 @@ class ScaledBlockedToMMAv2
auto newAcc =
rewriter.create<ConvertLayoutOp>(oldAcc.getLoc(), newRetType, oldAcc);

auto toMMABf16 = [&newRetType, &rewriter, &ctx, &enumToType](
TypedValue<RankedTensorType> v, int idx,
F8F6F4Type type) -> TypedValue<RankedTensorType> {
auto toMMABf16 =
[&newRetType, &rewriter,
&ctx](TypedValue<RankedTensorType> v, int idx,
ScaleDotElemType type) -> TypedValue<RankedTensorType> {
auto vType = v.getType();
if (type == F8F6F4Type::E2M1) {
if (type == ScaleDotElemType::E2M1) {
// A bit too dynamically typed...
// perhaps return ints in both cases?

Expand All @@ -469,23 +460,23 @@ class ScaledBlockedToMMAv2
vType.getShape(), vType.getElementType(), newVEncoding);
return rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
} else {
assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
assert(type == ScaleDotElemType::E5M2 ||
type == ScaleDotElemType::E4M3 ||
type == ScaleDotElemType::BF16);
auto newVEncoding = DotOperandEncodingAttr::get(
ctx, idx, newRetType.getEncoding(), /*kWidth=*/8);
auto newVType = RankedTensorType::get(
vType.getShape(), vType.getElementType(), newVEncoding);
v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);

// Bitcast
auto vTypeFp8 = RankedTensorType::get(vType.getShape(),
enumToType(type), newVEncoding);
v = cast<TypedValue<RankedTensorType>>(
rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());

// Convert to bf16
auto vTypeBf16 = RankedTensorType::get(
vType.getShape(), rewriter.getBF16Type(), newVEncoding);
return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
if (type == ScaleDotElemType::BF16) {
return v;
} else {
// Convert to bf16
auto vTypeBf16 = RankedTensorType::get(
vType.getShape(), rewriter.getBF16Type(), newVEncoding);
return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
}
}
};
a = toMMABf16(a, 0, aType);
Expand Down Expand Up @@ -515,11 +506,11 @@ class ScaledBlockedToMMAv2
auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, CTALayout);

auto newScaleType = RankedTensorType::get(scale.getType().getShape(),
scale.getType().getElementType(),
newScaleEncoding);
scale =
rewriter.create<ConvertLayoutOp>(scale.getLoc(), newScaleType, scale);
auto newScaleDotElemType = RankedTensorType::get(
scale.getType().getShape(), scale.getType().getElementType(),
newScaleEncoding);
scale = rewriter.create<ConvertLayoutOp>(scale.getLoc(),
newScaleDotElemType, scale);

auto scaledA = rewriter.create<triton::gpu::UpcastMXFPOp>(
dotOp.getLoc(), a, scale, dotOp.getLhsType());
Expand Down
19 changes: 10 additions & 9 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,13 @@ void init_triton_ir(py::module &&m) {
.value("IEEE", InputPrecision::IEEE)
.export_values();

py::enum_<F8F6F4Type>(m, "F8F6F4TY", py::module_local())
.value("E4M3", F8F6F4Type::E4M3)
.value("E5M2", F8F6F4Type::E5M2)
.value("E2M3", F8F6F4Type::E2M3)
.value("E3M2", F8F6F4Type::E3M2)
.value("E2M1", F8F6F4Type::E2M1)
py::enum_<ScaleDotElemType>(m, "ScaleDotElemTypeTY", py::module_local())
.value("E4M3", ScaleDotElemType::E4M3)
.value("E5M2", ScaleDotElemType::E5M2)
.value("E2M3", ScaleDotElemType::E2M3)
.value("E3M2", ScaleDotElemType::E3M2)
.value("E2M1", ScaleDotElemType::E2M1)
.value("BF16", ScaleDotElemType::BF16)
.export_values();

py::class_<MLIRContext>(m, "context", py::module_local())
Expand Down Expand Up @@ -1423,9 +1424,9 @@ void init_triton_ir(py::module &&m) {
})
.def("create_dot_scaled",
[](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale,
F8F6F4Type lhs_format, mlir::Value &rhs,
std::optional<mlir::Value> &rhs_scale, F8F6F4Type rhs_format,
mlir::Value &c) -> mlir::Value {
ScaleDotElemType lhs_format, mlir::Value &rhs,
std::optional<mlir::Value> &rhs_scale,
ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value {
return self.create<DotScaledOp>(
c.getType(), lhs, rhs, c, lhs_scale,
rhs_scale.value_or(Value()), lhs_format, rhs_format);
Expand Down
31 changes: 16 additions & 15 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3327,7 +3327,7 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
for col_a, col_b in itertools.product([True, False], repeat=2)
for type_a in ["e2m1", "e4m3", "e5m2"]
for type_b in ["e4m3", "e5m2"]
for type_b in ["e4m3", "e5m2", "bf16"]
for mma in ([32, 16] if is_hip() else [16])
for kpack in ([1, 2] if is_hip() else [1])])
def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
Expand All @@ -3345,7 +3345,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, type_a: tl.constexpr,
type_b: tl.constexpr):
tl.static_assert(type_b == "e4m3" or type_b == "e5m2", "type_b must be fp8")
tl.static_assert((type_b == "e4m3" or type_b == "e5m2") or type_b == "bf16", "type_b must be fp8 or bf16")
IS_FP8: tl.constexpr = type_a == "e4m3" or type_a == "e5m2"
DIV_FACTOR: tl.constexpr = 1 if IS_FP8 else 2
PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR
Expand Down Expand Up @@ -3436,7 +3436,7 @@ def mxfp_to_bf16_kernel(

def dot_scale_ref(x, scale, y, type_x, type_y):
e_bits, m_bits = {"e2m1": (2, 1), "e4m3": (4, 3), "e5m2": (5, 2)}[type_x]
type_fp8_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2}[type_y]
type_y = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2, "bf16": torch.bfloat16}[type_y]

comp_dtype = torch.bfloat16

Expand All @@ -3449,7 +3449,7 @@ def dot_scale_ref(x, scale, y, type_x, type_y):
mxfp_to_bf16_kernel[grid](x, scale, x_upcast, scale.numel(), e_bits, m_bits, BLOCK_SIZE, num_warps=num_warps)
assert x_upcast.isfinite().all()

y_upcast = y.view(type_fp8_y).to(comp_dtype)
y_upcast = y.view(type_y).to(comp_dtype)

class AccumulateInFp32:

Expand All @@ -3461,28 +3461,30 @@ def __exit__(self, exc_type, exc_val, exc_tb):
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = self.prev_value

with AccumulateInFp32():
return torch.matmul(x_upcast.to(comp_dtype), y_upcast.to(comp_dtype))
return torch.matmul(x_upcast, y_upcast)

torch.manual_seed(0)

def create_uint8(shape, col_major=False, max_val=255):
def make_arg(shape, ty, col_major=False, max_val=255):
if col_major:
shape = shape[:-2] + (shape[-1], shape[-2])
ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
if ty == "bf16":
ret = torch.randn(shape, dtype=torch.bfloat16, device=device)
# Clamp to avoid relative error issues
ret.clamp_(-2**15, 2**15 - 1)
else:
ret = torch.randint(max_val + 1, shape, dtype=torch.uint8, device=device)
if col_major:
ret = ret.mT
return ret

DIV_FACTOR = 2 if type_a == "e2m1" else 1
x = create_uint8((M, K // DIV_FACTOR), col_major=col_a)
y = create_uint8((K, N), col_major=col_b)
x = make_arg((M, K // DIV_FACTOR), type_a, col_major=col_a)
y = make_arg((K, N), type_b, col_major=col_b)

# sample scales that don't overflow as otherwise it's implementation defined (underflowing is alright)
# We subtract a reasonably high number (64) so that the sum of all the mxfp elements does not overflow
m_bytes = int(type_a[1])
bias_type_a = 1 << (m_bytes - 1) - 1
max_exponent_type_a = (1 << m_bytes) - 1 - bias_type_a
scale_x = create_uint8((M, K // 32), max_val=255 - max_exponent_type_a - 64)
# Max scale= 2**15
scale_x = make_arg((M, K // 32), "e8m0", max_val=127 + 15)

def make_finite(x, dtype):
# e5m2 has too many non-finite values when sampled uniformly (1 / 32) and
Expand All @@ -3507,7 +3509,6 @@ def make_finite(x, dtype):

z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)

# generous rtol as we are sampling the whole range of floats
torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)

# make sure ld/st are vectorized
Expand Down
14 changes: 8 additions & 6 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,15 +1555,17 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
lhs and rhs use microscaling formats described here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
:param lhs: The first tensor to be multiplied.
:type lhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
:type lhs: 2D tensor of fp4 or fp8 elements. fp4 inputs are packed in pairs into uint8; fp8 inputs are given either as uint8 or as the corresponding fp8 type.
:param lhs_scale: Scale factor for lhs tensor.
:type lhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
:param lhs_format: format of the lhs tensor, available formats: {:code:`e4m3`, :code:`e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
:type lhs_scale: e8m0 type represented as a uint8 tensor.
:param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`}.
:type lhs_format: str
:param rhs: The second tensor to be multiplied.
:type rhs: 2D tensor of f8, f6 or f4 format packed in int32 format.
:type rhs: 2D tensor of fp8 or bf16 elements. fp8 inputs are given either as uint8 or as the corresponding fp8 type; bf16 inputs are given as bf16.
:param rhs_scale: Scale factor for rhs tensor.
:type rhs_scale: ue8m0 float8 type (currently represented as an int8 tensor).
:param rhs_format: format of the rhs tensor, available formats: {:code:`e4m3`, :code:`e5m2`, :code:`e2m3`, :code:`e3m2`, :code:`e2m1`}.
:type rhs_scale: e8m0 type represented as a uint8 tensor.
:param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code:`e5m2`, :code:`bf16`}.
:type rhs_format: str
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
"""
out_dtype = _constexpr_to_value(out_dtype)
Expand Down
49 changes: 32 additions & 17 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,33 +1527,48 @@ def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, input_precision: Optiona
ret_ty)


def _str_to_fp_type(float_format: Optional[str]):
if float_format == 'e4m3':
return ir.F8F6F4TY.E4M3
if float_format == 'e5m2':
return ir.F8F6F4TY.E5M2
if float_format == 'e2m3':
return ir.F8F6F4TY.E2M3
if float_format == 'e3m2':
return ir.F8F6F4TY.E3M2
if float_format == 'e2m1':
return ir.F8F6F4TY.E2M1
raise ValueError(f"Invalid float format: {float_format}.")


def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
rhs_format, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
def _str_to_fp_type(float_format: str):
ty_enum = getattr(ir.ScaleDotElemTypeTY, float_format.upper(), None)
if ty_enum is None:
raise ValueError(f"Invalid float format: {float_format}.")
return ty_enum


def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder):
"""
If float_format is subbyte, make sure it's packed as uint8 and return it.
Otherwise, return a tensor (perhaps bitcasting) of the specified float format.
"""
triton_ty = {"e5m2": tl.float8e5, "e4m3": tl.float8e4nv, "bf16": tl.bfloat16}.get(float_format)
if triton_ty is None:
assert float_format == "e2m1", f"Internal Error: Unexpected float format: {float_format}"
assert val.dtype == tl.uint8, f"e2m1 format must be packed as uint8. Got {val.dtype}"
return val
if val.dtype == triton_ty:
return val
else:
unsigned_ty = {"e5m2": tl.uint8, "e4m3": tl.uint8, "bf16": tl.uint16}[float_format]
assert val.dtype == unsigned_ty, f"Unexpected dtype for {float_format}. Got {val.dtype}"
return bitcast(val, triton_ty, builder)


def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor],
rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
assert lhs.type.is_block() and rhs.type.is_block()
#TODO: validate types.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
lhs_format: str = lhs_format.value
rhs_format: str = rhs_format.value
lhs_format_enum = _str_to_fp_type(lhs_format)
rhs_format_enum = _str_to_fp_type(rhs_format)
assert lhs_format in ("e2m1", "e4m3", "e5m2"), f"NYI: lhs_format {lhs_format}"
assert rhs_format in ("e4m3", "e5m2"), f"NYI: rhs_format {rhs_format}"
assert rhs_format in ("e4m3", "e5m2", "bf16"), f"NYI: rhs_format {rhs_format}"
rhs_scale_is_none = isinstance(rhs_scale, tl.constexpr) and rhs_scale.value is None
assert rhs_scale_is_none, "NYI: rhs_scale not supported"
lhs = _bitcast_to_fp_type(lhs, lhs_format, builder)
rhs = _bitcast_to_fp_type(rhs, rhs_format, builder)

M = lhs.type.shape[-2]
K, N = rhs.type.shape[-2:]
Expand Down
Loading