@@ -272,8 +272,9 @@ def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Call
272272 return tensors
273273
274274 def dequant_model(self):
275- if self._is_nvfp4:
276- return # NVFP4 weights are repacked in _generate_nvfp4_tensors
275+ # If all quantized tensors were already handled (e.g. pure NVFP4), skip
276+ if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
277+ return
277278
278279 tensors_to_remove: list[str] = []
279280 new_tensors: dict[str, Callable[[], Tensor]] = {}
@@ -474,7 +475,20 @@ def dequant_packed(w: Tensor, scale: Tensor, shape_tensor: Tensor, zero_point: T
474475 tensors_to_remove.append(base_name + "_zero_point")
475476 else:
476477 raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
477- else:
478+ elif quant_method == "modelopt":
479+ # Mixed-precision ModelOpt models: NVFP4 tensors are handled by
480+ # _generate_nvfp4_tensors; FP8 tensors have 1D weight_scale and
481+ # are dequantized here. input_scale tensors are unused.
482+ for name in self.model_tensors.keys():
483+ if name.endswith(".weight_scale"):
484+ weight_name = name.removesuffix("_scale")
485+ w = self.model_tensors[weight_name]
486+ s = self.model_tensors[name]
487+ self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
488+ tensors_to_remove.append(name)
489+ if name.endswith((".input_scale", ".k_scale", ".v_scale")):
490+ tensors_to_remove.append(name)
491+ elif quant_method is not None:
478492 raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
479493
480494 for name in tensors_to_remove:
@@ -520,12 +534,6 @@ def set_gguf_parameters(self):
520534 raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
521535
522536 def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
523- # skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
524- if self._is_nvfp4:
525- if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
526- return []
527- if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
528- return []
529537
530538 new_name = self.map_tensor_name(name)
531539
@@ -609,6 +617,7 @@ def _generate_nvfp4_tensors(self):
609617 expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
610618 expert_shapes: dict[tuple[int, str], list[int]] = {}
611619 n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
620+ consumed: list[str] = []
612621
613622 for name in list(self.model_tensors.keys()):
614623 if not name.endswith(".weight"):
@@ -620,8 +629,18 @@ def _generate_nvfp4_tensors(self):
620629 # Force eager materialization of lazy tensors
621630 weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
622631 scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
632+
633+ # Skip non-NVFP4 tensors (e.g. FP8 with per-channel 1D scales)
634+ if scale.ndim < 2:
635+ continue
636+
623637 scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
624638
639+ # Mark tensors for removal from model_tensors (already written to gguf)
640+ consumed.extend([name, scale_name])
641+ if scale2_name in self.model_tensors:
642+ consumed.append(scale2_name)
643+
625644 # Check if this is a per-expert tensor
626645 m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
627646 if m:
@@ -652,6 +671,15 @@ def _generate_nvfp4_tensors(self):
652671 for (bid, proj_type) in list(expert_blocks.keys()):
653672 self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
654673
674+ # Remove consumed tensors so get_tensors/modify_tensors won't see them
675+ for name in consumed:
676+ self.model_tensors.pop(name, None)
677+
678+ # Remove unused auxiliary tensors (input_scale, k_scale, v_scale)
679+ for name in list(self.model_tensors.keys()):
680+ if name.endswith((".input_scale", ".k_scale", ".v_scale")):
681+ del self.model_tensors[name]
682+
655683 def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
656684 experts = expert_blocks.pop(key)
657685 scales = expert_scales.pop(key)
@@ -677,20 +705,31 @@ def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes,
677705 def prepare_tensors(self):
678706 # detect NVFP4 quantization (ModelOpt format)
679707 quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
708+ quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
680709 quant_config_file = self.dir_model / "hf_quant_config.json"
681710
682- if not quant_algo and quant_config_file.is_file():
711+ if ( not quant_algo or not quant_layers) and quant_config_file.is_file():
683712 with open(quant_config_file, "r", encoding="utf-8") as f:
684- quant_algo = (json.load(f).get("quantization") or {}).get("quant_algo")
713+ quant_config = json.load(f).get("quantization") or {}
714+ quant_algo = quant_config.get("quant_algo", quant_algo)
715+ quant_layers = quant_config.get("quantized_layers", quant_layers) or {}
685716
686- self._is_nvfp4 = quant_algo == "NVFP4"
717+ # Some models use per-tensor quant_algo (e.g. "MIXED_PRECISION" with
718+ # per-layer NVFP4/FP8) instead of a single global "NVFP4" value.
719+ if quant_algo != "NVFP4":
720+ if any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
721+ quant_algo = "NVFP4"
687722
688- self.dequant_model()
723+ self._is_nvfp4 = quant_algo == "NVFP4"
689724
690- # NVFP4 weights are repacked and written directly to gguf_writer
725+ # NVFP4 weights are repacked and written directly to gguf_writer.
726+ # This must run before dequant_model so NVFP4 tensors are removed
727+ # from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
691728 if self._is_nvfp4:
692729 self._generate_nvfp4_tensors()
693730
731+ self.dequant_model()
732+
694733 # Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
695734 if self.tensor_map.mapping:
696735 max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
0 commit comments