Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/transformers/models/jamba/configuration_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __init__(
self.attn_layer_period = attn_layer_period
self.attn_layer_offset = attn_layer_offset

self._check_supported_offset("attention", self.attn_layer_period, self.attn_layer_offset)
self._check_supported_offset("expert", self.expert_layer_period, self.expert_layer_offset)

self.use_mamba_kernels = use_mamba_kernels
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
Expand Down Expand Up @@ -222,3 +225,9 @@ def layers_num_experts(self):
self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
for i in range(self.num_hidden_layers)
]

def _check_supported_offset(self, property_: str, period: int, offset: int):
if offset >= period:
raise ValueError(
f"{property_} layer offset ({offset}) must be smaller than {property_} layer period ({period})"
)
44 changes: 43 additions & 1 deletion tests/models/jamba/test_modeling_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,48 @@
)


class JambaConfigTester(ConfigTester):
def _create_attn_config(self, attn_layer_offset: int, attn_layer_period: int):
_input_dict = self.inputs_dict.copy()
_input_dict["attn_layer_offset"] = attn_layer_offset
_input_dict["attn_layer_period"] = attn_layer_period
return self.config_class(**_input_dict)

def _create_expert_config(self, expert_layer_offset: int, expert_layer_period: int):
_input_dict = self.inputs_dict.copy()
_input_dict["expert_layer_offset"] = expert_layer_offset
_input_dict["expert_layer_period"] = expert_layer_period
return self.config_class(**_input_dict)

def test_attn_offsets(self):
self._create_attn_config(attn_layer_offset=0, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=1, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=2, attn_layer_period=4)
self._create_attn_config(attn_layer_offset=3, attn_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_attn_config(attn_layer_offset=4, attn_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_attn_config(attn_layer_offset=5, attn_layer_period=4)

def test_expert_offsets(self):
self._create_expert_config(expert_layer_offset=0, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=1, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=2, expert_layer_period=4)
self._create_expert_config(expert_layer_offset=3, expert_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_expert_config(expert_layer_offset=4, expert_layer_period=4)
with self.parent.assertRaises(ValueError):
self._create_expert_config(expert_layer_offset=5, expert_layer_period=4)

def test_jamba_offset_properties(self):
self.test_attn_offsets()
self.test_expert_offsets()

def run_common_tests(self):
self.test_jamba_offset_properties()
return super().run_common_tests()


class JambaModelTester:
def __init__(
self,
Expand Down Expand Up @@ -302,7 +344,7 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi

def setUp(self):
self.model_tester = JambaModelTester(self)
self.config_tester = ConfigTester(self, config_class=JambaConfig, hidden_size=37)
self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37)

def test_config(self):
self.config_tester.run_common_tests()
Expand Down