diff --git a/comfy/ops.py b/comfy/ops.py
index 640622fd1854..93b228ce4d41 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -599,6 +599,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                 'block_size': qconfig.get("group_size", None),
             }
             if layout_params['scale'] is not None:
+                layout_params['scale'] = layout_params['scale'].to(device=device)
                 manually_loaded_keys.append(weight_scale_key)
 
             self.weight = torch.nn.Parameter(
@@ -611,7 +612,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                 _v = state_dict.pop(param_key, None)
                 if _v is None:
                     continue
-                setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                self.register_buffer(param_name, _v.to(device=device))
                 manually_loaded_keys.append(param_key)
 
             super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 1d058bece5a7..d8354cd43411 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -337,6 +337,16 @@ def generic_copy_(func, args, kwargs):
         return qt_dest
     return func(*args, **kwargs)
 
+@register_generic_util(torch.ops.aten.to.dtype)
+def generic_to_dtype(func, args, kwargs):
+    """Handle .to(dtype) calls - dtype conversion only."""
+    src = args[0]
+    if isinstance(src, QuantizedTensor):
+        # For dtype-only conversion, just change the orig_dtype, no real cast is needed
+        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
+        src._layout_params["orig_dtype"] = target_dtype
+        return src
+    return func(*args, **kwargs)
 
 @register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
 def generic_has_compatible_shallow_copy_type(func, args, kwargs):
@@ -383,10 +393,11 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
             scale = torch.tensor(scale)
         scale = scale.to(device=tensor.device, dtype=torch.float32)
 
-        tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
+        tensor_fp32 = tensor.to(torch.float32)
+        tensor_scaled = tensor_fp32 * (1.0 / scale)
         # TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
-        # lp_amax = torch.finfo(dtype).max
-        # torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+        lp_amax = torch.finfo(dtype).max
+        torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
         qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
 
         layout_params = {
diff --git a/comfy/sd.py b/comfy/sd.py
index dc0905ada4af..d7a8e07c761b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -23,6 +23,7 @@
 import yaml
 import math
 import os
+import json
 
 import comfy.utils
 
@@ -917,7 +918,20 @@ class CLIPType(Enum):
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
-        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
+        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True))
+        if type(clip_data[-1]) == tuple:
+            model, metadata = clip_data.pop()
+            if metadata is not None and "_quantization_metadata" in metadata:
+                try:
+                    quant_metadata = metadata.pop("_quantization_metadata")
+                    quant_metadata = json.loads(quant_metadata)
+                    if "layers" in quant_metadata:
+                        layer_quant_config = quant_metadata["layers"]
+                        model_options["layer_quant_config"] = layer_quant_config
+                        logging.info(f"Detected quantized text encoder: {len(layer_quant_config)} layers with quantization")
+                except Exception as e:
+                    logging.warning(f"Failed to parse quantization metadata: {e}")
+            clip_data.append(model)
 
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
 
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 3066de2d7a50..a04ba74b5582 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -109,13 +109,22 @@ def __init__(self, device="cpu", max_length=77,
 
         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
+        layer_quant_config = model_options.get("layer_quant_config", None)
 
         if operations is None:
-            scaled_fp8 = model_options.get("scaled_fp8", None)
-            if scaled_fp8 is not None:
-                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+            # Use MixedPrecisionOps if layer_quant_config is present (for FP8 text encoders)
+            if layer_quant_config is not None:
+                operations = comfy.ops.MixedPrecisionOps
+                comfy.ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
+                comfy.ops.MixedPrecisionOps._compute_dtype = dtype
+                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
             else:
-                operations = comfy.ops.manual_cast
+                # Fallback to scaled_fp8_ops for backward compatibility
+                scaled_fp8 = model_options.get("scaled_fp8", None)
+                if scaled_fp8 is not None:
+                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+                else:
+                    operations = comfy.ops.manual_cast
 
         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
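For reviewers who want to sanity-check the new metadata path outside ComfyUI, here is a minimal sketch, assuming the standard `safetensors` Python API, of how the `_quantization_metadata` header that `load_clip` now parses can be inspected directly from a checkpoint; the file name `te_fp8.safetensors` is a placeholder.

```python
import json
from safetensors import safe_open

# Hedged sketch, not part of this PR: read the safetensors header and extract the same
# "_quantization_metadata" key that load_clip() turns into model_options["layer_quant_config"].
# "te_fp8.safetensors" is a placeholder path.
with safe_open("te_fp8.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}
    quant_metadata = json.loads(metadata.get("_quantization_metadata", "{}"))
    layer_quant_config = quant_metadata.get("layers", {})
    print(f"{len(layer_quant_config)} layers carry quantization config")
```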