Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions comfy/quant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __new__(cls, qdata, layout_type, layout_params):
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)

def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata.contiguous()
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params

Expand Down Expand Up @@ -411,7 +411,7 @@ def fp8_linear(func, args, kwargs):

try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]),
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
Expand Down Expand Up @@ -447,6 +447,43 @@ def fp8_linear(func, args, kwargs):
return torch.nn.functional.linear(input_tensor, weight, bias)


@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]

if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']

plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)

output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)

if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output

a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()

return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
Expand Down