Commit 0fe64d3

tests
1 parent 59db0d6 commit 0fe64d3

2 files changed: 3 additions & 3 deletions

fast_llm/layers/common/normalization/normalization.py

Lines changed: 1 addition & 1 deletion
@@ -311,7 +311,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
         super().__init__(config, hidden_dim, lr_scale)

         if rms_norm_gated is not None:
-            self._forward_gated = self._forward_local
+            self._forward_gated = self._forward_fla
         else:
             self._forward_gated = self._forward_local

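The change above fixes the dispatch so that the fused gated RMS-norm kernel from FLA (flash-linear-attention) is actually used when its import succeeds; previously both branches pointed at the local implementation. Below is a minimal sketch of the pattern, using a hypothetical stand-in class and an assumed FLA import path and call signature (the real Fast-LLM layer and the exact rms_norm_gated signature are not shown in this diff):

import torch

try:
    # Optional dependency: fused gated RMS norm from flash-linear-attention.
    # The import path is an assumption; only the name rms_norm_gated appears in the diff.
    from fla.modules.fused_norm_gate import rms_norm_gated
except ImportError:
    rms_norm_gated = None


class GatedRMSNormSketch(torch.nn.Module):
    """Hypothetical stand-in for the Fast-LLM gated normalization layer."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self._eps = eps
        # The fix in this commit: dispatch to the FLA kernel when it is available,
        # instead of always falling back to the local implementation.
        if rms_norm_gated is not None:
            self._forward_gated = self._forward_fla
        else:
            self._forward_gated = self._forward_local

    def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Assumed call signature; the real FLA API may differ.
        return rms_norm_gated(input_, gate, self.weight, None, eps=self._eps)

    def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        # Plain PyTorch reference: RMS-normalize, scale, then gate with SiLU(gate).
        normed = input_ * torch.rsqrt(input_.pow(2).mean(-1, keepdim=True) + self._eps)
        return normed * self.weight * torch.nn.functional.silu(gate)

    def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        return self._forward_gated(input_, gate)

This also explains the looser compare_factor in the test config below: the fused fp16/bf16 kernel can drift numerically from the local reference implementation.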

tests/utils/model_configs.py

Lines changed: 2 additions & 2 deletions
@@ -729,9 +729,9 @@ def _update_and_add_testing_config(
             ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
             ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
         },
-        compare_factor=10.0,  # with compare_factor 2, the fp16 and bf16 tests fail in the normalization layer when using rms_norm_gated from fla (passes with local non-fla norm)
+        compare_factor=10.0,  # with compare_factor 2, the fp16 and bf16 tests fail in the normalization layer when using rms_norm_gated from fla
         # note: tp is excluded because there are currently no gradient reductions implemented for tp norm in gdn.py (STP works though).
-        # we should be using STP with this model!
+        # we should be using STP with this model, not TP!
         skip_tests=(r"sdp", r"ms", r"^tp2$"),
     )

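For context, compare_factor presumably scales the numerical tolerance used when comparing runs, so a factor of 10.0 accepts more drift from the fused FLA kernel in fp16/bf16 than the previous factor of 2. A minimal sketch of that idea, with a hypothetical helper and thresholds (Fast-LLM's actual comparison logic is not shown in this diff):

import torch


def tensors_match(a: torch.Tensor, b: torch.Tensor,
                  base_rtol: float = 1e-3, base_atol: float = 1e-5,
                  compare_factor: float = 1.0) -> bool:
    # Hypothetical helper: compare_factor loosens both tolerances, so
    # compare_factor=10.0 tolerates roughly 10x more numerical drift.
    return torch.allclose(a, b,
                          rtol=base_rtol * compare_factor,
                          atol=base_atol * compare_factor)


# A constant drift of 4e-3 fails at factor 2 (tolerance ~2e-3)
# but passes at factor 10 (tolerance ~1e-2).
x = torch.ones(4, 8)
y = x + 4e-3
print(tensors_match(x, y, compare_factor=2.0))   # False
print(tensors_match(x, y, compare_factor=10.0))  # True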
