[Bug] error if run fp8 quantization, Check failed: (value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) is false: Bitcast requires size match uint8 vs float16 #3151

@Howave

Description

🐛 Bug

To Reproduce

Steps to reproduce the behavior with Llama-2-7b-chat-hf:

  1. Generate the config:
     python3 -m mlc_llm gen_config $model_dir --quantization e4m3_e4m3_f16_max_calibrate --conv-template llama-2 -o $out_dir
  2. Convert the weights:
     python3 -m mlc_llm convert_weight $model_dir --device cuda --quantization e4m3_e4m3_f16_max_calibrate -o $out_dir

Error log and stack trace:

2025-03-04 14:18:53,084 [INFO] [huggingface_loader.py:187] Loading HF parameters from: /share-global/device-engineering/ai_compiler/Llama-2-7b-chat-hf/pytorch_model-00002-of-00002.bin
  0%|                                                                                                                                                                       | 0/195 [00:00<?, ?it/s]/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/loader/utils.py:43: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  for name, param in torch.load(path, map_location=torch.device("cpu")).items():
lm_head.weight [[-0.003601  0.002655 -0.007385 ...  0.003937 -0.00842   0.00647 ]
 [-0.03113   0.04492  -0.00293  ... -0.02283   0.01471   0.03198 ]
 [-0.01245   0.001419  0.0188   ... -0.02637   0.015564 -0.007263]
 ...
 [-0.02942  -0.01721  -0.002853 ...  0.01404  -0.0116   -0.02344 ]
 [ 0.02039   0.02393   0.02722  ...  0.00479  -0.009705 -0.00641 ]
 [ 0.00806  -0.005737  0.00824  ... -0.0282   -0.01636   0.03113 ]]
2025-03-04 14:19:11,413 [INFO] [huggingface_loader.py:177] [Not quantized] Parameter: "lm_head.weight", shape: (32000, 4096), dtype: float16
model.layers.24.input_layernorm.weight [0.4902 0.5156 0.504  ... 0.4824 0.508  0.5   ]
2025-03-04 14:19:13,304 [INFO] [huggingface_loader.py:177] [Not quantized] Parameter: "model.layers.24.input_layernorm.weight", shape: (4096,), dtype: float16
model.layers.24.mlp.down_proj.weight [[-4.8218e-03 -1.2695e-02 -3.2715e-02 ... -7.2956e-05  4.2114e-03
  -1.9653e-02]
 [ 2.7588e-02 -6.9885e-03  1.6357e-02 ... -3.1433e-03  7.9956e-03
  -4.6875e-02]
 [ 2.0264e-02  1.7700e-02 -3.1738e-02 ...  1.3580e-03  2.1973e-02
   6.5002e-03]
 ...
 [-7.1411e-03  9.2773e-03 -1.8799e-02 ... -3.3447e-02 -1.2573e-02
  -2.3193e-02]
 [ 1.0376e-02  2.5513e-02  6.1035e-03 ...  2.2949e-02 -2.6123e-02
   1.8066e-02]
 [-2.1851e-02 -7.3853e-03  5.4321e-03 ...  1.2268e-02 -6.9885e-03
  -3.1494e-02]]
2025-03-04 14:19:13,360 [INFO] [per_tensor_quantization.py:213] Compiling quantize function for key: ((4096, 11008), float16, cuda
e4m3_float8 e4m3_float8
elem_storage_dtype: e4m3_float8
self.storage_dtype: e4m3_float8
  1%|█▋                                                                                                                                                             | 2/195 [00:21<33:53, 10.53s/it]
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/__main__.py", line 64, in <module>
    main()
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/__main__.py", line 37, in main
    cli.main(sys.argv[2:])
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/cli/convert_weight.py", line 88, in main
    convert_weight(
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/interface/convert_weight.py", line 181, in convert_weight
    _convert_args(args)
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/interface/convert_weight.py", line 145, in _convert_args
    tvmjs.dump_ndarray_cache(
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/python/tvm/contrib/tvmjs.py", line 273, in dump_ndarray_cache
    for k, origin_v in param_generator:
                       ^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/interface/convert_weight.py", line 129, in _param_generator
    for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/loader/huggingface_loader.py", line 121, in load
    for name, loader_param in self._load_or_quantize(mlc_name, param, device):
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/loader/huggingface_loader.py", line 165, in _load_or_quantize
    q_params = self.quantize_param_map.map_func[mlc_name](param)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/quantization/per_tensor_quantization.py", line 214, in quantize_weight
    quantize_func = compile_quantize_func(_create_quantize_func(), device)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/python/mlc_llm/quantization/utils.py", line 80, in compile_quantize_func
    ex = relax.build(mod, target=target)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/python/tvm/relax/vm_build.py", line 400, in build
    return _vmlink(
           ^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/python/tvm/relax/vm_build.py", line 261, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/python/tvm/driver/build_module.py", line 423, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/python/tvm/_ffi/base.py", line 482, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  32: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  31: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  30: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  29: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  28: tvm::transform::Pass::operator()(tvm::IRModule) const
  27: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  26: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  25: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  24: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: _ZN3tvm7runtime13PackedFuncObj
  22: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::FP8StorageLegalize()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::FP8StorageLegalize()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  21: tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
  20: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  19: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  18: tvm::tir::StorageLegalizer::VisitStmt_(tvm::tir::DeclBufferNode const*)
  17: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  16: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  15: tvm::tir::StorageLegalizer::VisitStmt_(tvm::tir::DeclBufferNode const*)
  14: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  13: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  12: tvm::tir::StorageLegalizer::VisitStmt_(tvm::tir::DeclBufferNode const*)
  11: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  10: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  9: tvm::tir::StorageLegalizer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  8: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  6: tvm::tir::StorageLegalizer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  4: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime
  3: tvm::tir::StorageLegalizer::VisitStmt_(tvm::tir::BufferStoreNode const*)
  2: _ZZN3tvm3tir11ExprFunctorIFNS_8PrimExprERKS2_EE10InitVTableEvENUlRKNS_7runt
  1: tvm::tir::StorageLegalizer::VisitExpr_(tvm::tir::CallNode const*)
  0: tvm::reinterpret(tvm::runtime::DataType const&, tvm::PrimExpr, tvm::Span)
  File "/home/howell.zeng/work/laser/mlc-llm/3rdparty/tvm/src/tir/op/op.cc", line 348
LaserError:
---------------------------------------------------------------
An error occurred during the execution of Laser.
---------------------------------------------------------------
  Check failed: (value.dtype().bits() * value.dtype().lanes() == t.bits() * t.lanes()) is false: Bitcast requires size match uint8 vs float16
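
For context, the failed check enforces that a bitcast keeps the element size unchanged: value.dtype().bits() * value.dtype().lanes() must equal t.bits() * t.lanes(). Here the FP8 storage legalization ends up reinterpreting a uint8 value (8 bits * 1 lane = 8) as float16 (16 bits * 1 lane = 16), so the sizes differ and the check fires. The following is a minimal NumPy sketch of the same size rule, purely illustrative and not TVM or mlc_llm code:

import numpy as np

# A bitcast/reinterpret over a buffer only succeeds when byte sizes line up.
raw = np.zeros(4, dtype=np.uint8)   # four 8-bit elements
print(raw.view(np.int8).dtype)      # int8: 8 -> 8 bits, element-wise bitcast is fine
print(raw.view(np.float16).shape)   # (2,): 8 -> 16 bits only works by pairing bytes;
                                    # an element-wise reinterpret cannot do this, hence
                                    # "Bitcast requires size match uint8 vs float16"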

Expected behavior

Weight conversion with the e4m3_e4m3_f16_max_calibrate quantization should complete without errors.

Environment

  • Platform: A100 CUDA
  • Operating system: Ubuntu
