[Bug] TVMError: Data types float16 and float32 must be equal for binary operators #816

@arjunzo

Description

I used the command below:
python -m mlc_llm.build --model ../releases/wm.model.llama-2-7b.32 --target webgpu --quantization q4f32_0

If I set --quantization to q4f16_1, compilation succeeds, but at runtime the browser reports: "error: Extension f16 is not allowed on the Device." Chrome apparently does not enable the f16 extension on my device, so I tried q4f32_0 to keep everything in f32 — but that build fails with the error below.
Am I missing something? How did you successfully run your model on WebGPU?
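For anyone hitting the same browser-side error: a minimal sketch (the function name is my own) for probing whether the current WebGPU adapter advertises the standard "shader-f16" feature, which f16 shaders require. The `gpu` parameter is normally `navigator.gpu` in a browser; it is an argument here only so the check can also be exercised outside one.

```javascript
// Returns true iff the WebGPU adapter exposes the "shader-f16" feature,
// which q4f16_* WebGPU builds need at runtime.
// `gpu` is normally `navigator.gpu`; passed as a parameter for testability.
async function supportsShaderF16(gpu) {
  if (!gpu) return false; // WebGPU not available in this environment
  const adapter = await gpu.requestAdapter();
  // adapter.features is a GPUSupportedFeatures set-like object
  return adapter !== null && adapter.features.has("shader-f16");
}
```

In a browser console: `supportsShaderF16(navigator.gpu).then(ok => console.log(ok))`. If this prints `false`, a q4f16_* build cannot run on that device and an f32 quantization is the only option.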

mlc-llm is a great project and I like it a lot — I successfully compiled and ran my own model with the Vulkan target.

The error thrown is:
Using path "../releases/wm.model.llama-2-7b.32" for model "wm.model.llama-2-7b.32"
Target configured: webgpu -keys=webgpu,gpu -max_num_threads=256
Automatically using target for weight quantization: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -max_shared_memory_per_block=49152 -max_threads_per_block=1024 -registers_per_block=65536 -thread_warp_size=32
Start computing and quantizing weights... This may take a while.
Finish computing and quantizing weights.
Total param size: 3.9250640869140625 GB
Start storing to cache dist/wm.model.llama-2-7b.32-q4f32_0/params
[0327/0327] saving param_326
All finished, 132 total shards committed, record saved to dist/wm.model.llama-2-7b.32-q4f32_0/params/ndarray-cache.json
Finish exporting chat config to dist/wm.model.llama-2-7b.32-q4f32_0/params/mlc-chat-config.json
[23:35:06] /workspace/tvm/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/code/mlc-llm/mlc_llm/build.py", line 13, in <module>
main()
File "/code/mlc-llm/mlc_llm/build.py", line 10, in main
core.build_model_from_args(parsed_args)
File "/code/mlc-llm/mlc_llm/core.py", line 575, in build_model_from_args
mod = mod_transform_before_build(mod, param_manager, args, config)
File "/code/mlc-llm/mlc_llm/core.py", line 364, in mod_transform_before_build
mod = fuse_split_rotary_embedding(mod, config["num_attention_heads"], config["hidden_size"])
File "/code/mlc-llm/mlc_llm/transform/fuse_split_rotary_embedding.py", line 177, in fuse_split_rotary_embedding
mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"])
File "/apps/mlc-llm/lib/python3.10/site-packages/tvm/relax/dpl/rewrite.py", line 118, in rewrite_bindings
return ffi.rewrite_bindings(ctx, rewriter, func)
File "tvm/_ffi/_cython/./packed_func.pxi", line 331, in tvm._ffi._cy3.core.PackedFuncBase.__call__
File "tvm/_ffi/_cython/./packed_func.pxi", line 262, in tvm._ffi._cy3.core.FuncCall
File "tvm/_ffi/_cython/./packed_func.pxi", line 251, in tvm._ffi._cy3.core.FuncCall3
File "tvm/_ffi/_cython/./base.pxi", line 181, in tvm._ffi._cy3.core.CHECK_CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
30: TVMFuncCall
29: _ZN3tvm7runtime13PackedFun
28: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::PatternContext const&, tvm::runtime::PackedFunc, tvm::relax::Function)>::AssignTypedLambda<tvm::relax::Function (*)(tvm::relax::PatternContext const&, tvm::runtime::PackedFunc, tvm::relax::Function)>(tvm::relax::Function (*)(tvm::relax::PatternContext const&, tvm::runtime::PackedFunc, tvm::relax::Function), std::__cxx11::basic_string<char, std::char_traits, std::allocator >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
27: tvm::relax::RewriteBindings(tvm::relax::PatternContext const&, tvm::runtime::PackedFunc, tvm::relax::Function)
26: tvm::relax::Function tvm::relax::PatternRewriter::Run<tvm::relax::PatternContext>(tvm::relax::PatternContext, tvm::runtime::PackedFunc, tvm::relax::Function)
25: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
24: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
23: _ZZN3tvm5relax11ExprFuncto
22: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
21: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
20: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
19: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
18: _ZZN3tvm5relax11ExprFuncto
17: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
16: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
15: tvm::relax::PatternRewriter::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
14: tvm::relax::PatternRewriter::RewriteDataflowBlockFixedPoint(tvm::relax::BindingBlock)
13: tvm::relax::PatternRewriter::VisitBinding_(tvm::relax::VarBindingNode const*)
12: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::DataTypeImmNode const*)
11: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
10: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
9: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
8: _ZZN3tvm5relax11ExprFuncto
7: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
6: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
5: _ZN3tvm7runtime13PackedFun
4: tvm::runtime::TypedPackedFunc<tvm::relax::StructInfo (tvm::relax::Call const&, tvm::relax::BlockBuilder const&)>::AssignTypedLambda<tvm::relax::StructInfo (*)(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)>(tvm::relax::StructInfo (*)(tvm::relax::Call const&, tvm::relax::BlockBuilder const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
3: tvm::relax::InferStructInfoMatmul(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
2: tvm::relax::InferBinaryArithOpOutDtype(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::TensorStructInfo const&, tvm::relax::TensorStructInfo const&)
1: _ZN3tvm5relax16BlockBuilderImpl11ReportFatalERKNS_1
0: _ZN3tvm7runtime6deta
File "/workspace/tvm/src/relax/ir/block_builder.cc", line 138
TVMError: Data types float16 and float32 must be equal for binary operators

Metadata

Labels: bug (Confirmed bugs)