diff --git a/examples/eagle/convert_checkpoint.py b/examples/eagle/convert_checkpoint.py
index fc7a7b78342..1632e1e2187 100644
--- a/examples/eagle/convert_checkpoint.py
+++ b/examples/eagle/convert_checkpoint.py
@@ -295,6 +295,14 @@ def copy(tensors):
         args.n_positions = hf_config.max_position_embeddings
         args.dtype = str(
             hf_config.torch_dtype)[6:] if args.dtype == 'auto' else args.dtype
+        if 'head_dim' in hf_config:
+            args.head_dim = hf_config.head_dim
+        else:
+            args.head_dim = args.n_embd // args.n_head
+        if 'head_size' in hf_config:
+            args.head_size = hf_config.head_size
+        else:
+            args.head_size = args.head_dim

         if args.eagle_model_dir is None:
             hf_config_eagle = hf_config.eagle
@@ -305,6 +313,14 @@ def copy(tensors):
             args.n_kv_head_eagle = hf_config_eagle['num_key_value_heads']
             args.rms_norm_eps_eagle = hf_config_eagle['rms_norm_eps']
             args.n_positions_eagle = hf_config_eagle['max_position_embeddings']
+            if 'head_dim' in hf_config_eagle:
+                args.head_dim_eagle = hf_config_eagle['head_dim']
+            else:
+                args.head_dim_eagle = args.n_embd_eagle // args.n_head_eagle
+            if 'head_size' in hf_config_eagle:
+                args.head_size_eagle = hf_config_eagle['head_size']
+            else:
+                args.head_size_eagle = args.head_dim_eagle
         else:
             hf_config_eagle = LlamaConfig.from_pretrained(args.eagle_model_dir)
             args.n_head_eagle = hf_config_eagle.num_attention_heads
@@ -314,6 +330,14 @@ def copy(tensors):
             args.n_kv_head_eagle = hf_config_eagle.num_key_value_heads
             args.rms_norm_eps_eagle = hf_config_eagle.rms_norm_eps
             args.n_positions_eagle = hf_config_eagle.max_position_embeddings
+            if 'head_dim' in hf_config_eagle:
+                args.head_dim_eagle = hf_config_eagle.head_dim
+            else:
+                args.head_dim_eagle = args.n_embd_eagle // args.n_head_eagle
+            if 'head_size' in hf_config_eagle:
+                args.head_size_eagle = hf_config_eagle.head_size
+            else:
+                args.head_size_eagle = args.head_dim_eagle

     elif args.meta_ckpt_dir is not None:
         assert False, "meta ckpt is not supported yet"
@@ -370,6 +394,8 @@ def copy(tensors):
         },
         'use_parallel_embedding': args.use_parallel_embedding,
         'embedding_sharding_dim': args.embedding_sharding_dim,
+        'head_dim': args.head_dim_eagle,
+        'head_size': args.head_size_eagle
     }

     config = {
@@ -402,7 +428,9 @@ def copy(tensors):
         'max_draft_len': args.max_draft_len,
         'num_eagle_layers': args.num_eagle_layers,
         'max_non_leaves_per_layer': args.max_non_leaves_per_layer,
-        'eagle_net_config': eagle_net_config
+        'eagle_net_config': eagle_net_config,
+        'head_dim': args.head_dim,
+        'head_size': args.head_size
     }

     assert args.max_draft_len <= 256, "args.max_draft_len > 256 is not supported"
diff --git a/tensorrt_llm/models/eagle/config.py b/tensorrt_llm/models/eagle/config.py
index 6c434441c97..6124884f723 100644
--- a/tensorrt_llm/models/eagle/config.py
+++ b/tensorrt_llm/models/eagle/config.py
@@ -88,6 +88,14 @@ def from_hugging_face(
         n_positions = hf_config.max_position_embeddings
         hidden_act = hf_config.hidden_act
         dtype = str(hf_config.torch_dtype)[6:] if dtype == 'auto' else dtype
+        if hasattr(hf_config, 'head_dim'):
+            head_dim = hf_config.head_dim
+        else:
+            head_dim = hf_config.n_embd // hf_config.n_head
+        if hasattr(hf_config, 'head_size'):
+            head_size = hf_config.head_size
+        else:
+            head_size = head_dim

         if speculative_config_or_dir is None:
             hf_config_eagle = hf_config.eagle
@@ -143,6 +151,8 @@
             },
             'use_parallel_embedding': kwargs['use_parallel_embedding'],
             'embedding_sharding_dim': kwargs['embedding_sharding_dim'],
+            'head_dim': head_dim,
+            'head_size': head_size
         }

         config = {
diff --git a/tests/integration/defs/common.py b/tests/integration/defs/common.py
index cdeaf6eb8ac..5b66160b57b 100644
--- a/tests/integration/defs/common.py
+++ b/tests/integration/defs/common.py
@@ -943,11 +943,22 @@ def get_dummy_spec_decoding_heads(hf_model_dir,
     )

     quant_cfg = getattr(mtq, "FP8_DEFAULT_CFG")
+    # Following quantizers are needed for KV cache quantization.
     quant_cfg["quant_cfg"]["*output_quantizer"] = {
         "num_bits": (4, 3),
         "axis": None,
         "enable": True,
     }
+    quant_cfg["quant_cfg"]["*k_bmm_quantizer"] = {
+        "num_bits": (4, 3),
+        "axis": None,
+        "enable": True,
+    }
+    quant_cfg["quant_cfg"]["*v_bmm_quantizer"] = {
+        "num_bits": (4, 3),
+        "axis": None,
+        "enable": True,
+    }

     calibrate_loop = dataset_utils.create_forward_loop(
         calib_dataloader, dataloader=calib_dataloader)
diff --git a/tests/integration/defs/examples/test_eagle.py b/tests/integration/defs/examples/test_eagle.py
index f8bbd578048..6066c349ef8 100644
--- a/tests/integration/defs/examples/test_eagle.py
+++ b/tests/integration/defs/examples/test_eagle.py
@@ -270,17 +270,6 @@ def test_codellama_eagle_1gpu(code_llama_model_root,
                           llm_datasets_root=llm_datasets_root,
                           llm_rouge_root=llm_rouge_root)

-    test_with_dummy_eagle(hf_model_root=code_llama_model_root,
-                          eagle_example_root=eagle_example_root,
-                          llm_venv=llm_venv,
-                          cmodel_dir=cmodel_dir,
-                          engine_dir=engine_dir,
-                          batch_size=batch_size,
-                          data_type=data_type,
-                          use_dynamic_tree=use_dynamic_tree,
-                          llm_datasets_root=llm_datasets_root,
-                          llm_rouge_root=llm_rouge_root)
-

 @pytest.mark.parametrize("use_dynamic_tree", [False, True],
                          ids=['eagle1', 'eagle2'])
@@ -309,6 +298,33 @@ def test_mistral_eagle_1gpu(llm_mistral_model_root,
                           llm_rouge_root=llm_rouge_root)


+@pytest.mark.parametrize("use_dynamic_tree", [False, True],
+                         ids=['eagle1', 'eagle2'])
+@pytest.mark.parametrize("mistral_nemo_model_root", ['Mistral-Nemo-12b-Base'],
+                         indirect=True)
+def test_mistral_nemo_eagle_1gpu(mistral_nemo_model_root,
+                                 eagle_example_root,
+                                 llm_datasets_root,
+                                 llm_rouge_root,
+                                 llm_venv,
+                                 cmodel_dir,
+                                 engine_dir,
+                                 use_dynamic_tree,
+                                 batch_size=8,
+                                 data_type='bfloat16'):
+
+    test_with_dummy_eagle(hf_model_root=mistral_nemo_model_root,
+                          eagle_example_root=eagle_example_root,
+                          llm_venv=llm_venv,
+                          cmodel_dir=cmodel_dir,
+                          engine_dir=engine_dir,
+                          batch_size=batch_size,
+                          data_type=data_type,
+                          use_dynamic_tree=use_dynamic_tree,
+                          llm_datasets_root=llm_datasets_root,
+                          llm_rouge_root=llm_rouge_root)
+
+
 @pytest.mark.parametrize("use_dynamic_tree", [False, True],
                          ids=['eagle1', 'eagle2'])
 @pytest.mark.parametrize("llm_qwen_model_root", [
diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt
index f392317cccc..604161f69a0 100644
--- a/tests/integration/test_lists/qa/examples_test_list.txt
+++ b/tests/integration/test_lists/qa/examples_test_list.txt
@@ -500,6 +500,7 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-v2-7b-hf-eagle1]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.2-1b-eagle1]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle1]
 examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle1]
+examples/test_eagle.py::test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle1]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle1]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle1]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle1]
@@ -514,6 +515,7 @@ examples/test_eagle.py::test_llama_eagle_1gpu[llama-v2-7b-hf-eagle2]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.2-1b-eagle2]
 examples/test_eagle.py::test_llama_eagle_1gpu[llama-3.1-8b-eagle2]
 examples/test_eagle.py::test_mistral_eagle_1gpu[mistral-7b-v0.1-eagle2]
+examples/test_eagle.py::test_mistral_nemo_eagle_1gpu[Mistral-Nemo-12b-Base-eagle2]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle2]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle2]
 examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle2]
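
Not part of the patch above: a minimal standalone sketch of the head_dim / head_size fallback these changes implement, for reference. `types.SimpleNamespace` stands in for a HuggingFace-style config object, the Mistral-Nemo numbers (hidden_size 5120, 32 heads, head_dim 128) are illustrative, and the attribute names follow the HF convention rather than the `args.n_embd` / `args.n_head` aliases used in convert_checkpoint.py.

```python
# Illustration only (not part of the patch): how head_dim / head_size are
# resolved when a config does or does not carry an explicit head_dim.
from types import SimpleNamespace


def resolve_head_dims(cfg):
    # Prefer an explicit head_dim (Mistral-Nemo sets one that differs from
    # hidden_size // num_attention_heads); otherwise derive it.
    if hasattr(cfg, 'head_dim') and cfg.head_dim is not None:
        head_dim = cfg.head_dim
    else:
        head_dim = cfg.hidden_size // cfg.num_attention_heads
    # head_size falls back to head_dim when the config does not define it.
    head_size = getattr(cfg, 'head_size', head_dim)
    return head_dim, head_size


# Mistral-Nemo-like config: 5120 // 32 == 160, but the declared head_dim is 128.
print(resolve_head_dims(
    SimpleNamespace(hidden_size=5120, num_attention_heads=32, head_dim=128)))
# Llama-like config without an explicit head_dim: falls back to 4096 // 32.
print(resolve_head_dims(
    SimpleNamespace(hidden_size=4096, num_attention_heads=32)))
```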