|
3 | 3 | import enum |
4 | 4 | import functools |
5 | 5 | import os |
| 6 | +import re |
6 | 7 | import typing |
7 | 8 |
|
8 | 9 | import pytest |
@@ -76,7 +77,7 @@ class ModelTestingConfig: |
76 | 77 | groups: dict[ModelTestingGroup, ModelTestingGroupAction] |
77 | 78 | # Scale the comparison thresholds for specific models. |
78 | 79 | compare_factor: float = 1.0 |
79 | | - # Option to skip specific distributed configuration with name containing any of the provided strings. |
| 80 | + # Option to skip specific distributed configurations whose name matches any of the provided regex patterns. |
80 | 81 | skip_tests: tuple[str] = () |
81 | 82 |
|
82 | 83 | @functools.cached_property |
@@ -125,7 +126,7 @@ def base_model_config_class(self): |
125 | 126 | return self.model_config_class.get_base_model_config_class() |
126 | 127 |
|
127 | 128 | def should_skip(self, distributed_config: DistributedTestingConfig) -> bool: |
128 | | - return any(key in distributed_config.name for key in self.skip_tests) |
| 129 | + return any(re.search(pattern, distributed_config.name) for pattern in self.skip_tests) |
129 | 130 |
|
130 | 131 |
|
131 | 132 | def _update_and_add_testing_config( |
@@ -470,7 +471,7 @@ def _update_and_add_testing_config( |
470 | 471 | }, |
471 | 472 | compare_factor=2.0, |
472 | 473 | # Arg update for cross-entropy splits doesn't work here. |
473 | | - skip_tests=("ce4", "ms"), |
| 474 | + skip_tests=(r"ce4", r"ms"), |
474 | 475 | ) |
475 | 476 |
|
476 | 477 | _update_and_add_testing_config( |
@@ -603,7 +604,7 @@ def _update_and_add_testing_config( |
603 | 604 | }, |
604 | 605 | compare_factor=2.0, |
605 | 606 | # Micro-sequence split not supported. |
606 | | - skip_tests=("sdp", "ms"), |
| 607 | + skip_tests=(r"sdp", r"ms"), |
607 | 608 | ) |
608 | 609 |
|
609 | 610 | _update_and_add_testing_config( |
@@ -645,8 +646,8 @@ def _update_and_add_testing_config( |
645 | 646 | compare_factor=2.0, |
646 | 647 | # Micro-sequence split not supported. |
647 | 648 | skip_tests=( |
648 | | - "sdp", |
649 | | - "ms", |
| 649 | + r"sdp", |
| 650 | + r"ms", |
650 | 651 | ), # "pp","dp", "ce","16", "bf", "df", "stp"), |
651 | 652 | ) |
652 | 653 |
|
@@ -690,7 +691,7 @@ def _update_and_add_testing_config( |
690 | 691 | }, |
691 | 692 | compare_factor=2.0, |
692 | 693 | # Micro-sequence split and sequence-first not supported. |
693 | | - skip_tests=("sdp", "ms"), |
| 694 | + skip_tests=(r"sdp", r"ms"), |
694 | 695 | ) |
695 | 696 |
|
696 | 697 |
|
@@ -728,8 +729,10 @@ def _update_and_add_testing_config( |
728 | 729 | ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, |
729 | 730 | ModelTestingGroup.distributed: ModelTestingGroupAction.normal, |
730 | 731 | }, |
731 | | - compare_factor=16, |
732 | | - skip_tests=("sdp", "ms", "stp"), |
| 732 | + compare_factor=10.0, # with compare_factor 2, the fp16 and bf16 tests fail in the normalization layer when using rms_norm_gated from fla (they pass with the local non-fla norm) |
| 733 | + # note: tp is excluded because there are currently no gradient reductions implemented for tp norm in gdn.py (STP works though). |
| 734 | + # we should be using STP with this model! |
| 735 | + skip_tests=(r"sdp", r"ms", r"^tp2$"), |
733 | 736 | ) |
734 | 737 |
|
735 | 738 |
|
|
0 commit comments