This refactoring centralizes model-specific configuration within the provider_bridge method of each model bridge.

Changes:
- Add MoE-related field mappings to the base class CONFIG_MAPPING:
  - num_experts -> num_moe_experts
  - num_experts_per_tok -> moe_router_topk
  - moe_intermediate_size -> moe_ffn_hidden_size
- Refactor LlamaBridge:
  - Use MEGATRON_DEFAULTS and HF_DEFAULTS class attributes
  - Override provider_bridge only for RoPE scaling (Llama 3.1/3.2)
- Refactor Qwen2Bridge:
  - Use MEGATRON_DEFAULTS (add_qkv_bias=True) and HF_DEFAULTS
  - No provider_bridge override needed
- Refactor Qwen3Bridge:
  - Use MEGATRON_DEFAULTS (qk_layernorm=True) and HF_DEFAULTS
  - No provider_bridge override needed
- Refactor Qwen3MoEBridge:
  - Use MEGATRON_DEFAULTS with MoE settings and HF_DEFAULTS
  - No provider_bridge override needed
- Update tests to expect GPTModelProvider instead of model-specific providers
- Add verification scripts for both Llama and Qwen bridges

Verified on remote server:
- Qwen/Qwen2-0.5B: PASS
- Qwen/Qwen2-7B: PASS
- Qwen/Qwen3-0.6B: PASS
- Qwen/Qwen3-1.7B: PASS
- Qwen/Qwen3-30B-A3B: PASS
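The mapping-plus-defaults pattern described in the commit can be sketched roughly as follows. This is an illustrative reduction, not the actual bridge code: the function name, the use of a plain dict for the HF config, and the sample field values are assumptions; only the mapping tuples and the `add_qkv_bias=True` default come from the commit message.

```python
# Hypothetical sketch of the mapping-driven translation; field names mirror
# the commit message, not the real MegatronModelBridge implementation.
BASE_CONFIG_MAPPING = [
    ("num_hidden_layers", "num_layers"),
    # MoE-related mappings added in this change:
    ("num_experts", "num_moe_experts"),
    ("num_experts_per_tok", "moe_router_topk"),
    ("moe_intermediate_size", "moe_ffn_hidden_size"),
]

def hf_to_provider_kwargs(hf_config: dict, mapping=BASE_CONFIG_MAPPING, defaults=None):
    """Translate HF config fields to Megatron provider kwargs, then apply defaults."""
    kwargs = {mk: hf_config[hk] for hk, mk in mapping if hk in hf_config}
    kwargs.update(defaults or {})  # e.g. a bridge's MEGATRON_DEFAULTS
    return kwargs

qwen2_kwargs = hf_to_provider_kwargs(
    {"num_hidden_layers": 24, "num_experts": 8},
    defaults={"add_qkv_bias": True},  # Qwen2's MEGATRON_DEFAULTS per the commit
)
```

With this shape, a bridge that needs no special handling (Qwen2/Qwen3 above) carries only data, and only bridges with real logic (Llama RoPE scaling) override `provider_bridge`.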
…dels

- Add MLAModelProvider as unified base for Multi-Latent Attention models
- Refactor DeepSeek V2/V3 bridges to use MLAModelProvider
- Refactor Kimi K2 bridge to use MLAModelProvider
- Move model-specific defaults from providers to MEGATRON_DEFAULTS in bridges
- Add model_type parameter to @register_bridge decorator for auto HF config
- Simplify provider files to deprecated backward-compatible aliases

Verified: DeepSeek-V2-Lite, DeepSeek-V2, DeepSeek-V3, Moonlight-16B, Kimi-K2
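The `model_type`-keyed registration mentioned above could look roughly like this. Everything here is an illustrative assumption — the registry dict, the decorator signature, and the class name are invented for the sketch; only the decorator's name and its new `model_type` parameter come from the commit message.

```python
# Hypothetical sketch of a model_type-aware bridge registry; the real
# @register_bridge decorator in the codebase may differ in signature.
BRIDGE_REGISTRY: dict[str, type] = {}

def register_bridge(model_type: str):
    """Register a bridge class under an HF model_type key (assumed behavior)."""
    def decorator(cls):
        BRIDGE_REGISTRY[model_type] = cls
        return cls
    return decorator

@register_bridge(model_type="deepseek_v3")
class DeepSeekV3BridgeSketch:
    """Stand-in for a real bridge class."""

# Lookup by the model_type field an HF config would carry:
bridge_cls = BRIDGE_REGISTRY["deepseek_v3"]
```

Keying on `model_type` lets the HF config select the right bridge automatically, which is presumably what "auto HF config" refers to.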
- Register GemmaModelProvider, Gemma2ModelProvider, Gemma3ModelProvider via decorator
- Add MEGATRON_DEFAULTS to Gemma/Gemma2 bridges for explicit config defaults
- Add gelu_pytorch_tanh -> fast_gelu to ACTIVATION_MAPPING in model_bridge.py
- Add verification script for Gemma provider refactoring
Verified: gemma-2b, gemma-7b, gemma-2-2b, gemma-2-9b, gemma-2-27b,
gemma-3-4b-it, gemma-3-12b-it, gemma-3-27b-it
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
OLMoE's HF config doesn't have a head_dim attribute, so kv_channels was left as None. This fix computes it as hidden_size // num_attention_heads (2048 // 16 = 128 for OLMoE-1B-7B), following the pattern used by MistralBridge and NemotronHBridge.
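The fallback described above can be sketched as below. The helper name and the `SimpleNamespace` stand-in for the transformers config object are illustrative; the arithmetic and the attribute names are from the fix description.

```python
from types import SimpleNamespace

def resolve_kv_channels(hf_config):
    """Use head_dim when the HF config provides it, else derive it (sketch)."""
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is None:
        # Fallback used by MistralBridge/NemotronHBridge per the fix description
        head_dim = hf_config.hidden_size // hf_config.num_attention_heads
    return head_dim

olmoe = SimpleNamespace(hidden_size=2048, num_attention_heads=16)  # OLMoE-1B-7B
kv_channels = resolve_kv_channels(olmoe)  # 2048 // 16 == 128
```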
📝 Walkthrough

This PR introduces bidirectional HuggingFace↔Megatron configuration translation support in MegatronModelBridge, refactors model bridges to use unified provider classes (GPTModelProvider/MLAModelProvider), removes generation_config from providers, and removes provider configuration equivalence tests.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant HF as HuggingFace Config
    participant Bridge as MegatronModelBridge
    participant Translator as Config Translators
    participant Provider as Provider (GPT/MLA)
    HF->>Bridge: hf_config_to_provider_kwargs(hf_config)
    activate Bridge
    Bridge->>Translator: Map HF fields via CONFIG_MAPPING
    Translator-->>Bridge: Basic config dict
    Bridge->>Translator: Convert activations via hf_to_megatron_activation
    Translator-->>Bridge: Mapped activation functions
    Bridge->>Translator: Handle rope scaling (YARN/MLA)
    Translator-->>Bridge: Rope scaling params
    Bridge-->>Provider: provider_kwargs dict
    deactivate Bridge
    Provider->>Provider: Instantiate with translated kwargs
    Provider->>Bridge: megatron_to_hf_config(provider)
    activate Bridge
    Bridge->>Translator: Extract rope scaling from provider
    Translator-->>Bridge: Rope scaling config
    Bridge->>Translator: Convert activation via megatron_to_hf_activation
    Translator-->>Bridge: HF activation name
    Bridge->>Translator: Reverse map config fields
    Translator-->>Bridge: HF config dict
    Bridge-->>HF: Complete HF config
    deactivate Bridge
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Actionable comments posted: 6
🤖 Fix all issues with AI agents
In `@src/megatron/bridge/models/conversion/model_bridge.py`:
- Around line 388-395: The yarn_params dict is populated with entries from
YARN_ROPE_SCALING_MAPPING even when rope_scaling.get(hf_key) is None, so filter
out None values: when iterating over self.YARN_ROPE_SCALING_MAPPING only assign
megatron_key -> value into yarn_params if value is not None (leave
position_embedding_type and the "truncate" -> yarn_correction_range_round_to_int
mapping as-is), then attach provider_kwargs["_yarn_params"] only when
yarn_params contains at least the position_embedding_type or other non-None
entries.
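The None-filtering the prompt asks for can be sketched as follows. The mapping contents and the Megatron-side key names here are illustrative assumptions; only the filtering behavior and the `position_embedding_type` handling follow the review comment.

```python
# Hypothetical sketch of filtering unset fields out of yarn_params, per the
# review comment; mapping keys are illustrative, not the real module's.
YARN_ROPE_SCALING_MAPPING = {
    "factor": "yarn_scaling_factor",
    "original_max_position_embeddings": "yarn_original_max_position_embeddings",
}

def build_yarn_params(rope_scaling: dict) -> dict:
    """Collect only the YARN fields the HF config actually sets."""
    yarn_params = {
        megatron_key: rope_scaling.get(hf_key)
        for hf_key, megatron_key in YARN_ROPE_SCALING_MAPPING.items()
        if rope_scaling.get(hf_key) is not None  # skip unset fields
    }
    if yarn_params:  # attach the type marker only when something was set
        yarn_params["position_embedding_type"] = "yarn"
    return yarn_params

params = build_yarn_params({"factor": 2.0})
```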
In `@src/megatron/bridge/models/gemma/gemma3_bridge.py`:
- Around line 63-67: The code assigns provider.rope_scaling_factor using
hf_config.rope_scaling["factor"] which will raise a KeyError if rope_scaling is
a dict missing "factor"; update the assignment in gemma3_bridge.py to
defensively access the key (e.g., use hf_config.rope_scaling.get("factor", 1.0)
or check for the key/None first) so provider.rope_scaling_factor defaults to 1.0
when hf_config.rope_scaling is falsy or lacks "factor".
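A defensive version of that access could look like the sketch below; `hf_config` is a `SimpleNamespace` stand-in for the real transformers config object, and the helper name is invented for illustration.

```python
from types import SimpleNamespace

def rope_scaling_factor(hf_config) -> float:
    """Default to 1.0 when rope_scaling is absent, falsy, or lacks 'factor'."""
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if not rope_scaling:
        return 1.0
    return rope_scaling.get("factor", 1.0)  # .get avoids the KeyError

cfg_missing_key = SimpleNamespace(rope_scaling={"rope_type": "linear"})
factor = rope_scaling_factor(cfg_missing_key)  # 1.0 instead of a KeyError
```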
In `@src/megatron/bridge/models/gpt_provider.py`:
- Around line 139-144: The config class currently defines an unused field
rotary_scaling_factor while only rope_scaling_factor is passed into
MCoreGPTModel; either remove rotary_scaling_factor from the class to avoid
dead/unused configuration, or retain it but add a clear comment above the field
explaining its intended purpose and when subclasses (e.g., DeepSeek/Kimi) should
override it for MLA-specific behavior; update references in the class docstring
and ensure MCoreGPTModel instantiation (where rope_scaling_factor is passed)
remains consistent with the chosen approach.
In `@tests/unit_tests/models/llama/test_llama_bridge.py`:
- Around line 386-390: Rename the misnamed test method to include the missing
underscore so pytest discovers it: change the function name
testhf_config_to_provider_kwargs_from_base_class to
test_hf_config_to_provider_kwargs_from_base_class; ensure the assertions still
reference LlamaBridge and MegatronModelBridge and the
hf_config_to_provider_kwargs attribute to verify the inherited instance method.
- Around line 398-430: Rename the test function
testhf_config_to_provider_kwargs_returns_correct_mappings to
test_hf_config_to_provider_kwargs_returns_correct_mappings so pytest will
discover it; locate the method in
tests/unit_tests/models/llama/test_llama_bridge.py (the current def
testhf_config_to_provider_kwargs_returns_correct_mappings) and simply insert the
missing underscore in the function name while leaving the body and the call to
LlamaBridge.hf_config_to_provider_kwargs unchanged.
🧹 Nitpick comments (20)
tests/unit_tests/models/olmoe/test_olmoe_bridge.py (2)
484-485: Simplify boolean comparison. Per ruff E712, avoid equality comparisons to `True`; use the boolean value directly.

Proposed fix:

```diff
- assert AutoBridge.supports(config) == True
+ assert AutoBridge.supports(config)
```

488-489: Same simplification for the `False` comparison.

Proposed fix:

```diff
- assert AutoBridge.supports(non_causal_config) == False
+ assert not AutoBridge.supports(non_causal_config)
```

tests/unit_tests/models/qwen/test_qwen3_bridge.py (2)
455-455: Simplify boolean comparison. Same pattern as the other test files; avoid `== True`.

Proposed fix:

```diff
- assert AutoBridge.supports(config) == True
+ assert AutoBridge.supports(config)
```

460-460: Simplify the `False` comparison.

Proposed fix:

```diff
- assert AutoBridge.supports(non_causal_config) == False
+ assert not AutoBridge.supports(non_causal_config)
```

src/megatron/bridge/models/qwen/qwen3_moe_bridge.py (1)
43-45: Add explicit type hints to `provider_bridge`. The method lost its argument/return type annotations.

🛠️ Suggested fix:

```diff
 from megatron.bridge.models.conversion.param_mapping import (
     AutoMapping,
     GatedMLPMapping,
     QKVMapping,
 )
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
@@
-    def provider_bridge(self, hf_pretrained):
+    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
```

As per coding guidelines: Use type hints for function arguments and return types.
tests/unit_tests/models/qwen/test_qwen3_moe_bridge.py (1)
94-99: Docstring mismatch: the spec is `[]`, not `list`. Align the comment with the actual Mock spec.

✏️ Suggested fix:

```diff
-    Uses spec=list to make getattr return None for undefined attributes
+    Uses spec=[] to make getattr return None for undefined attributes
     instead of Mock objects, which would incorrectly be passed to GPTModelProvider.
```

src/megatron/bridge/models/qwen/qwen3_bridge.py (1)
43-45: Restore type hints on `provider_bridge`. The method should keep explicit argument/return types.

🛠️ Suggested fix:

```diff
 from megatron.bridge.models.conversion.param_mapping import (
     AutoMapping,
     GatedMLPMapping,
     QKVMapping,
 )
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
@@
-    def provider_bridge(self, hf_pretrained):
+    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider:
```

As per coding guidelines: Use type hints for function arguments and return types.
src/megatron/bridge/models/deepseek/deepseek_provider.py (1)
16-16: Use `T | None` instead of `Optional[T]`. Convert Optional annotations to the Python 3.10+ union syntax.

🛠️ Suggested fix:

```diff
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Union
@@
-    q_lora_rank: Optional[int] = 1536
+    q_lora_rank: int | None = 1536
@@
-    q_lora_rank: Optional[int] = None
+    q_lora_rank: int | None = None
```

As per coding guidelines: Use 'T | None' for nullable types instead of 'Optional[T]'.

Also applies to: 57-58, 104-105
src/megatron/bridge/models/gemma/gemma2_bridge.py (1)
46-48: Use a Google-style docstring here. Please add Args/Returns so the method docstring is Sphinx-parseable.

As per coding guidelines: Use Google style docstrings (parseable by Sphinx) for classes and functions.

Proposed docstring shape:

```diff
-        """Convert HuggingFace config to Gemma2ModelProvider."""
+        """Convert HuggingFace config to Gemma2ModelProvider.
+
+        Args:
+            hf_pretrained: HuggingFace pretrained model wrapper.
+
+        Returns:
+            Gemma2ModelProvider: Configured provider for Megatron model.
+        """
```

src/megatron/bridge/models/llama/llama_bridge.py (1)
31-31: Rename the global logger to follow the `G_` convention. Global module variables should use the `G_` prefix.

As per coding guidelines: Use upper snake_case and prefix 'G' for global variables (e.g., G_MY_GLOBAL).

Suggested rename:

```diff
-logger = logging.getLogger(__name__)
+G_LOGGER = logging.getLogger(__name__)
```

src/megatron/bridge/models/nemotron/nemotron_bridge.py (2)
29-31: Add type hints and a Google-style docstring to `squared_relu`. This keeps helper utilities consistent with the repo's typing and docstring conventions.

As per coding guidelines: Use type hints for function arguments and return types; use Google style docstrings (parseable by Sphinx) for classes and functions.

Proposed update:

```diff
-def squared_relu(x):
-    """Squared ReLU activation function."""
-    return torch.pow(torch.nn.functional.relu(x), 2)
+def squared_relu(x: torch.Tensor) -> torch.Tensor:
+    """Squared ReLU activation function.
+
+    Args:
+        x: Input tensor.
+
+    Returns:
+        Output tensor after applying squared ReLU.
+    """
+    return torch.pow(torch.nn.functional.relu(x), 2)
```
52-55: Optional: use iterable unpacking for `CONFIG_MAPPING`. This avoids list concatenation and aligns with ruff's style suggestion.

Example change:

```diff
-    CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
-        # Nemotron uses norm_eps instead of rms_norm_eps
-        ("norm_eps", "layernorm_epsilon"),
-    ]
+    CONFIG_MAPPING = [
+        *MegatronModelBridge.CONFIG_MAPPING,
+        # Nemotron uses norm_eps instead of rms_norm_eps
+        ("norm_eps", "layernorm_epsilon"),
+    ]
```

src/megatron/bridge/models/qwen/qwen2_bridge.py (1)
46-58: Consider adding type hints for consistency. Other bridges like `GemmaBridge` include type hints on `provider_bridge`; adding them here would improve IDE support and documentation.

```diff
-    def provider_bridge(self, hf_pretrained):
-        """Convert HuggingFace Qwen2 config to GPTModelProvider."""
+    def provider_bridge(self, hf_pretrained: "PreTrainedCausalLM") -> "GPTModelProvider":
+        """Convert HuggingFace Qwen2 config to GPTModelProvider."""
```

src/megatron/bridge/models/glm/glm45_bridge.py (1)
34-39: Remove the unused `noqa` directive. Per static analysis, the `# noqa: F401` directive is unnecessary since the F401 rule is not enabled in the ruff configuration.

Suggested fix:

```diff
 try:
-    import transformer_engine  # noqa: F401
+    import transformer_engine

     HAVE_TE = True
 except (ImportError, ModuleNotFoundError):
     HAVE_TE = False
```

src/megatron/bridge/models/deepseek/deepseek_v2_bridge.py (1)
27-32: Remove the unused `noqa` directive. Same issue as in `glm45_bridge.py`; the `# noqa: F401` directive is unnecessary.

Suggested fix:

```diff
 try:
-    import transformer_engine  # noqa: F401
+    import transformer_engine

     HAVE_TE = True
 except (ImportError, ModuleNotFoundError):
     HAVE_TE = False
```

src/megatron/bridge/models/nemotronh/nemotron_h_bridge.py (1)
57-70: Consider using iterable unpacking for `CONFIG_MAPPING`. Per static analysis (RUF005), unpacking is preferred over concatenation for better readability and a slight performance improvement.

Suggested refactor:

```diff
-    CONFIG_MAPPING = MegatronModelBridge.CONFIG_MAPPING + [
+    CONFIG_MAPPING = [
+        *MegatronModelBridge.CONFIG_MAPPING,
         # Mamba-specific fields
         ("mamba_head_dim", "mamba_head_dim"),
         ("mamba_num_heads", "mamba_num_heads"),
         ("n_groups", "mamba_num_groups"),
         ("ssm_state_size", "mamba_state_dim"),
         ("hybrid_override_pattern", "hybrid_override_pattern"),
         ("residual_in_fp32", "fp32_residual_connection"),
         ("use_bias", "add_bias_linear"),
         ("layer_norm_epsilon", "layernorm_epsilon"),
         # MoE-specific fields (already in base but with different HF names)
         ("moe_shared_expert_intermediate_size", "moe_shared_expert_intermediate_size"),
     ]
```

src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py (2)
31-31: Remove the unused `noqa` directive. Static analysis indicates the `# noqa: F401` is unnecessary since no F401 rule violation exists here. The import is used for the side-effect check.

Proposed fix:

```diff
-    import transformer_engine  # noqa: F401
+    import transformer_engine
```
47-94: Significant code duplication with the DeepSeek V2 bridge. Comparing with `deepseek_v2_bridge.py` (from relevant snippets), approximately 40+ lines in `provider_bridge` are identical between V2 and V3 (lines 51-84 mirror the V2 implementation). The only differences appear to be:

- V3 adds `moe_shared_expert_overlap = True` (line 64)
- V3 adds `moe_router_enable_expert_bias = True` (line 65)

Consider extracting common DeepSeek provider configuration to a shared helper (e.g., in `common.py`) to reduce duplication and ensure consistency.

src/megatron/bridge/models/conversion/model_bridge.py (2)
244-285: Annotate mutable class attributes with `ClassVar`. Per static analysis (RUF012), mutable class attributes like `CONFIG_MAPPING` should be annotated with `typing.ClassVar` to indicate they are class-level rather than instance-level attributes.

Proposed fix:

```diff
+from typing import ClassVar
+
 class MegatronModelBridge(...):
     ...
-    CONFIG_MAPPING = [
+    CONFIG_MAPPING: ClassVar[list[tuple[str, str]]] = [
         # Core architecture
         ("num_hidden_layers", "num_layers"),
         ...
     ]
```

Apply similar annotations to `YARN_ROPE_SCALING_MAPPING`, `MLA_ROPE_SCALING_MAPPING`, and `ACTIVATION_MAPPING`.
368-368: Local import to avoid a circular dependency. The import of `MLAModelProvider` inside `hf_config_to_provider_kwargs` is likely there to prevent circular imports. Consider adding a brief comment explaining this.

Proposed documentation:

```diff
+        # Import here to avoid circular dependency (MLAModelProvider imports from this module)
         from megatron.bridge.models.mla_provider import MLAModelProvider
```
```python
# Gemma3-specific features not in CONFIG_MAPPING
provider.window_size = hf_config.sliding_window
provider.rotary_base = (hf_config.rope_local_base_freq, hf_config.rope_theta)
provider.softmax_scale = 1.0 / math.sqrt(hf_config.query_pre_attn_scalar)
provider.rope_scaling_factor = hf_config.rope_scaling["factor"] if hf_config.rope_scaling else 1.0
```
Consider defensive access for rope_scaling["factor"].

If hf_config.rope_scaling is a dict but doesn't contain the "factor" key, this will raise a KeyError. Consider using .get() with a default value.

Suggested fix:

```diff
-        provider.rope_scaling_factor = hf_config.rope_scaling["factor"] if hf_config.rope_scaling else 1.0
+        provider.rope_scaling_factor = hf_config.rope_scaling.get("factor", 1.0) if hf_config.rope_scaling else 1.0
```
🤖 Prompt for AI Agents
In `@src/megatron/bridge/models/gemma/gemma3_bridge.py` around lines 63 - 67, The
code assigns provider.rope_scaling_factor using hf_config.rope_scaling["factor"]
which will raise a KeyError if rope_scaling is a dict missing "factor"; update
the assignment in gemma3_bridge.py to defensively access the key (e.g., use
hf_config.rope_scaling.get("factor", 1.0) or check for the key/None first) so
provider.rope_scaling_factor defaults to 1.0 when hf_config.rope_scaling is
falsy or lacks "factor".
In tests/unit_tests/models/llama/test_llama_bridge.py:

```python
def test_provider_bridge_backward_compatibility(self, mock_pretrained_llama):
    """Test that provider_bridge still works as an alias for provider_bridge."""
    bridge = LlamaBridge()

    # Pass model only
    # Both methods should return equivalent results
    result_config = bridge.provider_bridge(mock_pretrained_llama)
    result_provider = bridge.provider_bridge(mock_pretrained_llama)

    # They should have the same architecture
    assert result_config.num_layers == result_provider.num_layers
    assert result_config.hidden_size == result_provider.hidden_size
    assert result_config.normalization == result_provider.normalization
```
Test is a tautology - calls same method twice.
This test calls bridge.provider_bridge(mock_pretrained_llama) twice and compares the results. Since both calls create new provider instances, the comparison is always true for equivalent configurations. The docstring says it tests "provider_bridge still works as an alias for provider_bridge" which doesn't make sense.
Consider either removing this test or clarifying its intent (e.g., testing idempotency or determinism).
Suggested clarification:

```diff
 def test_provider_bridge_backward_compatibility(self, mock_pretrained_llama):
-    """Test that provider_bridge still works as an alias for provider_bridge."""
+    """Test that provider_bridge produces consistent results across multiple calls."""
     bridge = LlamaBridge()
-    # Both methods should return equivalent results
+    # Multiple calls should return equivalent results (determinism check)
     result_config = bridge.provider_bridge(mock_pretrained_llama)
     result_provider = bridge.provider_bridge(mock_pretrained_llama)
-    # They should have the same architecture
+    # They should have the same architecture (consistency check)
     assert result_config.num_layers == result_provider.num_layers
```

In tests/unit_tests/models/llama/test_llama_bridge.py:

```python
def testhf_config_to_provider_kwargs_from_base_class(self):
    """Test that hf_config_to_provider_kwargs is inherited from MegatronModelBridge."""
    assert hasattr(LlamaBridge, "hf_config_to_provider_kwargs")
    # It's an instance method, check it exists on the class
    assert hasattr(MegatronModelBridge, "hf_config_to_provider_kwargs")
```
Fix test method naming - missing underscore.
The method name testhf_config_to_provider_kwargs_from_base_class is missing an underscore after test, which means pytest won't discover it as a test.
Proposed fix:

```diff
- def testhf_config_to_provider_kwargs_from_base_class(self):
+ def test_hf_config_to_provider_kwargs_from_base_class(self):
```

🤖 Prompt for AI Agents
In `@tests/unit_tests/models/llama/test_llama_bridge.py` around lines 386 - 390,
Rename the misnamed test method to include the missing underscore so pytest
discovers it: change the function name
testhf_config_to_provider_kwargs_from_base_class to
test_hf_config_to_provider_kwargs_from_base_class; ensure the assertions still
reference LlamaBridge and MegatronModelBridge and the
hf_config_to_provider_kwargs attribute to verify the inherited instance method.
In tests/unit_tests/models/llama/test_llama_bridge.py:

```python
def testhf_config_to_provider_kwargs_returns_correct_mappings(self):
    """Test that hf_config_to_provider_kwargs correctly maps HF config to provider kwargs."""
    # Create a mock HF config
    mock_hf_config = Mock()
    mock_hf_config.num_hidden_layers = 32
    mock_hf_config.hidden_size = 4096
    mock_hf_config.intermediate_size = 14336
    mock_hf_config.num_attention_heads = 32
    mock_hf_config.num_key_value_heads = 8
    mock_hf_config.vocab_size = 128256
    mock_hf_config.max_position_embeddings = 8192
    mock_hf_config.rope_theta = 500000.0
    mock_hf_config.rms_norm_eps = 1e-05
    mock_hf_config.initializer_range = 0.02
    mock_hf_config.hidden_act = "silu"
    mock_hf_config.torch_dtype = "bfloat16"
    mock_hf_config.attention_dropout = 0.0
    mock_hf_config.tie_word_embeddings = False
    mock_hf_config.attention_bias = False
    mock_hf_config.mlp_bias = False

    bridge = LlamaBridge()
    kwargs = bridge.hf_config_to_provider_kwargs(mock_hf_config)

    assert kwargs["num_layers"] == 32
    assert kwargs["hidden_size"] == 4096
    assert kwargs["ffn_hidden_size"] == 14336
    assert kwargs["num_attention_heads"] == 32
    assert kwargs["num_query_groups"] == 8
    assert kwargs["vocab_size"] == 128256
    assert kwargs["seq_length"] == 8192
    assert kwargs["rotary_base"] == 500000.0
    assert kwargs["activation_func"] == F.silu
```
Fix test method naming - missing underscore.
Same issue: testhf_config_to_provider_kwargs_returns_correct_mappings should be test_hf_config_to_provider_kwargs_returns_correct_mappings.
Proposed fix:

```diff
- def testhf_config_to_provider_kwargs_returns_correct_mappings(self):
+ def test_hf_config_to_provider_kwargs_returns_correct_mappings(self):
```

🤖 Prompt for AI Agents
In `@tests/unit_tests/models/llama/test_llama_bridge.py` around lines 398 - 430,
Rename the test function
testhf_config_to_provider_kwargs_returns_correct_mappings to
test_hf_config_to_provider_kwargs_returns_correct_mappings so pytest will
discover it; locate the method in
tests/unit_tests/models/llama/test_llama_bridge.py (the current def
testhf_config_to_provider_kwargs_returns_correct_mappings) and simply insert the
missing underscore in the function name while leaving the body and the call to
LlamaBridge.hf_config_to_provider_kwargs unchanged.
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information
Summary by CodeRabbit
New Features
Refactor
Chores