Commit 59db0d6

gdn tests
1 parent cff4a5c commit 59db0d6

3 files changed: +22 -20 lines changed

fast_llm/layers/common/normalization/normalization.py

Lines changed: 6 additions & 6 deletions
@@ -311,14 +311,14 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float |
         super().__init__(config, hidden_dim, lr_scale)
 
         if rms_norm_gated is not None:
-            self._forward = self._forward_fused
+            self._forward_gated = self._forward_fla
         else:
-            self._forward = self._forward
+            self._forward_gated = self._forward_local
 
     def forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
-        return self._forward(input_.view(-1, *self._normalized_shape), gate).view_as(input_)
+        return self._forward_gated(input_.view(-1, *self._normalized_shape), gate).view_as(input_)
 
-    def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+    def _forward_fla(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
         return rms_norm_gated(
             input_,
             gate,
@@ -331,6 +331,6 @@ def _forward_fused(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
             residual_in_fp32=False,
         )
 
-    def _forward(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
-        normalized = self.rmsnorm(input_)
+    def _forward_local(self, input_: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
+        normalized = self._forward(input_)
         return normalized * F.silu(gate)
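For reference, the local (non-fla) path is an RMS normalization followed by a SiLU gate on a separate tensor. A minimal standalone sketch of that behaviour, with the weight and eps handling assumed rather than copied from the Fast-LLM class:

import torch
import torch.nn.functional as F

def gated_rms_norm(input_: torch.Tensor, gate: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # RMS-normalize over the last dimension and scale by the learned weight.
    variance = input_.float().pow(2).mean(dim=-1, keepdim=True)
    normalized = (input_.float() * torch.rsqrt(variance + eps)).to(input_.dtype) * weight
    # Gate with SiLU, matching `normalized * F.silu(gate)` in the diff above.
    return normalized * F.silu(gate)

The fused `_forward_fla` path computes the same thing with the `rms_norm_gated` kernel from fla; the dispatch in `__init__` simply falls back to the local implementation when `rms_norm_gated` is unavailable (None).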

fast_llm/layers/ssm/gdn.py

Lines changed: 4 additions & 5 deletions
@@ -210,9 +210,9 @@ def __init__(
             lr_scale=self._lr_scale,
             peft=self._peft,
         )
-        # self.norm = self._config.normalization.get_layer(
-        #     self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft
-        # )
+        self.norm = self._config.normalization.get_layer(
+            self._value_head_dim, lr_scale=self._lr_scale, peft=self._peft
+        )
 
         self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
 
@@ -259,7 +259,6 @@ def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
         Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.
         """
 
-        # Split contiguous q/k/v/z blocks and only then project them into per-head shapes.
         local_qkv_sizes = (
             self._local_key_heads * self._config.key_head_dim,
             self._local_key_heads * self._config.key_head_dim,
@@ -370,7 +369,7 @@ def _forward(
 
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
-        # core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = self.norm(core_attn_out, z)
         core_attn_out = core_attn_out.reshape(z_shape_og)
         core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
         if sequence_first:
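Re-enabling `self.norm` means the core attention output is now gated-normalized against `z` over the value head dimension before being reshaped back. A rough, self-contained shape walkthrough of that step (dimensions are illustrative, and `F.rms_norm` followed by a SiLU gate merely stands in for the actual `self.norm` layer; requires a recent PyTorch):

import torch
import torch.nn.functional as F

tokens, value_heads, value_head_dim = 16, 8, 128  # illustrative sizes only
core_attn_out = torch.randn(tokens, value_heads, value_head_dim)
z = torch.randn(tokens, value_heads, value_head_dim)
z_shape_og = z.shape

# Flatten everything except the normalized (value_head_dim) axis so the norm layer sees a 2D tensor.
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
# Stand-in for `self.norm(core_attn_out, z)`: RMS-normalize, then gate with SiLU(z).
core_attn_out = F.rms_norm(core_attn_out, (value_head_dim,)) * F.silu(z)
# Restore the original shape and merge the head dimensions, as in the diff above.
core_attn_out = core_attn_out.reshape(z_shape_og)
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)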

tests/utils/model_configs.py

Lines changed: 12 additions & 9 deletions
@@ -3,6 +3,7 @@
 import enum
 import functools
 import os
+import re
 import typing
 
 import pytest
@@ -76,7 +77,7 @@ class ModelTestingConfig:
     groups: dict[ModelTestingGroup, ModelTestingGroupAction]
     # Scale the comparison thresholds for specific models.
     compare_factor: float = 1.0
-    # Option to skip specific distributed configuration with name containing any of the provided strings.
+    # Option to skip specific distributed configurations with names matching any of the provided regex patterns.
     skip_tests: tuple[str] = ()
 
     @functools.cached_property
@@ -125,7 +126,7 @@ def base_model_config_class(self):
         return self.model_config_class.get_base_model_config_class()
 
     def should_skip(self, distributed_config: DistributedTestingConfig) -> bool:
-        return any(key in distributed_config.name for key in self.skip_tests)
+        return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests)
 
 
 def _update_and_add_testing_config(
@@ -470,7 +471,7 @@ def _update_and_add_testing_config(
     },
     compare_factor=2.0,
     # Arg update for cross-entropy splits doesn't work here.
-    skip_tests=("ce4", "ms"),
+    skip_tests=(r"ce4", r"ms"),
 )
 
 _update_and_add_testing_config(
@@ -603,7 +604,7 @@ def _update_and_add_testing_config(
     },
     compare_factor=2.0,
     # Micro-sequence split not supported.
-    skip_tests=("sdp", "ms"),
+    skip_tests=(r"sdp", r"ms"),
 )
 
 _update_and_add_testing_config(
@@ -645,8 +646,8 @@ def _update_and_add_testing_config(
     compare_factor=2.0,
     # Micro-sequence split not supported.
     skip_tests=(
-        "sdp",
-        "ms",
+        r"sdp",
+        r"ms",
     ),  # "pp","dp", "ce","16", "bf", "df", "stp"),
 )
 
@@ -690,7 +691,7 @@ def _update_and_add_testing_config(
     },
     compare_factor=2.0,
     # Micro-sequence split and sequence-first not supported.
-    skip_tests=("sdp", "ms"),
+    skip_tests=(r"sdp", r"ms"),
 )
 
 
@@ -728,8 +729,10 @@ def _update_and_add_testing_config(
         ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented,
         ModelTestingGroup.distributed: ModelTestingGroupAction.normal,
     },
-    compare_factor=16,
-    skip_tests=("sdp", "ms", "stp"),
+    compare_factor=10.0,  # With compare_factor=2, the fp16 and bf16 tests fail in the normalization layer when using rms_norm_gated from fla (they pass with the local non-fla norm).
+    # Note: tp is excluded because there is currently no gradient reduction implemented for the TP norm in gdn.py (STP works, though).
+    # We should be using STP with this model!
+    skip_tests=(r"sdp", r"ms", r"^tp2$"),
 )
 
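With `should_skip` now using `re.search`, the entries in `skip_tests` are regular expressions rather than plain substrings, so a pattern like `r"^tp2$"` can target a single configuration name without also catching names that merely contain it. A small illustration with hypothetical configuration names:

import re

skip_tests = (r"sdp", r"ms", r"^tp2$")

def should_skip(name: str) -> bool:
    # Mirrors ModelTestingConfig.should_skip, with the config name passed directly for brevity.
    return any(re.search(pattern, name) for pattern in skip_tests)

print(should_skip("tp2"))     # True: the anchored pattern matches the whole name
print(should_skip("stp2"))    # False: ^tp2$ rejects it, whereas substring matching on "tp2" would not
print(should_skip("ms_dp2"))  # True: an unanchored pattern still behaves like a substring check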
