67 changes: 53 additions & 14 deletions convert_hf_to_gguf.py
@@ -272,8 +272,9 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
return tensors

def dequant_model(self):
if self._is_nvfp4:
return # NVFP4 weights are repacked in _generate_nvfp4_tensors
# If all quantized tensors were already handled (e.g. pure NVFP4), skip
if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
return

tensors_to_remove: list[str] = []
new_tensors: dict[str, Callable[[], Tensor]] = {}
@@ -474,7 +475,20 @@ def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: T
tensors_to_remove.append(base_name + "_zero_point")
else:
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
else:
elif quant_method == "modelopt":
# Mixed-precision ModelOpt models: NVFP4 tensors are handled by
# _generate_nvfp4_tensors; FP8 tensors have 1D weight_scale and
# are dequantized here. input_scale tensors are unused.
for name in self.model_tensors.keys():
if name.endswith(".weight_scale"):
weight_name = name.removesuffix("_scale")
w = self.model_tensors[weight_name]
s = self.model_tensors[name]
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
tensors_to_remove.append(name)
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
tensors_to_remove.append(name)
elif quant_method is not None:
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")

for name in tensors_to_remove:
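As a reading aid for the "modelopt" FP8 branch above: `dequant_simple` is defined elsewhere in convert_hf_to_gguf.py and is not part of this diff, so the snippet below is only a minimal sketch of what a per-channel FP8 dequantization amounts to, assuming the 1D weight_scale holds one scale per output row (that layout is an assumption, not something stated in this PR).

import torch

def dequant_fp8_per_channel_sketch(weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    # FP8 storage dtypes don't support arithmetic directly, so upcast first,
    # then broadcast the assumed per-row scale across the input dimension.
    w = weight.to(torch.float32)                        # (out_features, in_features)
    s = weight_scale.to(torch.float32).reshape(-1, 1)   # (out_features, 1), assumed layout
    return w * s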
@@ -520,12 +534,6 @@ def set_gguf_parameters(self):
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
if self._is_nvfp4:
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
return []
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
return []

new_name = self.map_tensor_name(name)

@@ -609,6 +617,7 @@ def _generate_nvfp4_tensors(self):
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
expert_shapes: dict[tuple[int, str], list[int]] = {}
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
consumed: list[str] = []

for name in list(self.model_tensors.keys()):
if not name.endswith(".weight"):
@@ -620,8 +629,18 @@
# Force eager materialization of lazy tensors
weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())

# Skip non-NVFP4 tensors (e.g. FP8 with per-channel 1D scales)
if scale.ndim < 2:
continue

scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())

# Mark tensors for removal from model_tensors (already written to gguf)
consumed.extend([name, scale_name])
if scale2_name in self.model_tensors:
consumed.append(scale2_name)

# Check if this is a per-expert tensor
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
if m:
@@ -652,6 +671,15 @@
for (bid, proj_type) in list(expert_blocks.keys()):
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)

# Remove consumed tensors so get_tensors/modify_tensors won't see them
for name in consumed:
self.model_tensors.pop(name, None)

# Remove unused auxiliary tensors (input_scale, k_scale, v_scale)
for name in list(self.model_tensors.keys()):
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
del self.model_tensors[name]

def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
experts = expert_blocks.pop(key)
scales = expert_scales.pop(key)
@@ -677,20 +705,31 @@ def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes,
def prepare_tensors(self):
# detect NVFP4 quantization (ModelOpt format)
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
quant_config_file = self.dir_model / "hf_quant_config.json"

if not quant_algo and quant_config_file.is_file():
if (not quant_algo or not quant_layers) and quant_config_file.is_file():
with open(quant_config_file, "r", encoding="utf-8") as f:
quant_algo = (json.load(f).get("quantization") or {}).get("quant_algo")
quant_config = json.load(f).get("quantization") or {}
quant_algo = quant_config.get("quant_algo", quant_algo)
quant_layers = quant_config.get("quantized_layers", quant_layers) or {}

self._is_nvfp4 = quant_algo == "NVFP4"
# Some models use per-tensor quant_algo (e.g. "MIXED_PRECISION" with
# per-layer NVFP4/FP8) instead of a single global "NVFP4" value.
if quant_algo != "NVFP4":
if any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
quant_algo = "NVFP4"

self.dequant_model()
self._is_nvfp4 = quant_algo == "NVFP4"

# NVFP4 weights are repacked and written directly to gguf_writer
# NVFP4 weights are repacked and written directly to gguf_writer.
# This must run before dequant_model so NVFP4 tensors are removed
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
if self._is_nvfp4:
self._generate_nvfp4_tensors()

self.dequant_model()

# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
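For the quant-config detection in prepare_tensors above, a hypothetical MIXED_PRECISION config illustrates why quantized_layers has to be consulted as well. The sketch below shows the "quantization" section as the dict json.load would return; the layer names are made up for illustration and are not taken from any real checkpoint.

# Hypothetical contents of the "quantization" section of hf_quant_config.json:
quant_config = {
    "quant_algo": "MIXED_PRECISION",
    "quantized_layers": {
        "model.layers.0.mlp.gate_proj": {"quant_algo": "NVFP4"},
        "model.layers.0.self_attn.o_proj": {"quant_algo": "FP8"},
    },
}

# Mirrors the detection above: one NVFP4 entry among the per-layer configs is
# enough to treat the model as NVFP4 and run _generate_nvfp4_tensors first,
# leaving only the FP8 tensors for dequant_model.
is_nvfp4 = quant_config.get("quant_algo") == "NVFP4" or any(
    isinstance(v, dict) and v.get("quant_algo") == "NVFP4"
    for v in (quant_config.get("quantized_layers") or {}).values()
)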
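One note on the scale.ndim < 2 check in _generate_nvfp4_tensors: it separates the two scale layouts that can coexist in a mixed-precision ModelOpt checkpoint. The shapes below are assumptions consistent with that check and with NVFP4's 16-element block size; they are not spelled out in this diff.

# For a weight of shape (out_features, in_features):
#   NVFP4:           weight_scale   -> 2D per-block scales, roughly (out_features, in_features // 16)
#                    weight_scale_2 -> scalar second-level scale
#   FP8 per-channel: weight_scale   -> 1D, (out_features,); skipped here and left for dequant_model
# Everything repacked by _generate_nvfp4_tensors is recorded in `consumed` and popped
# from model_tensors, so dequant_model and modify_tensors only see the remaining tensors.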