@@ -3,15 +3,16 @@
 from keras import layers
 from keras import ops

-from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
-from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
-from keras_hub.src.models.smollm3.smollm3_utils import rope_init
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
-    merge_padding_and_attention_mask,
+    compute_causal_mask,
 )
 from keras_hub.src.layers.modeling.transformer_layer_utils import (
-    compute_causal_mask,
+    merge_padding_and_attention_mask,
 )
+from keras_hub.src.models.smollm3.smollm3_utils import apply_rotary_pos_emb
+from keras_hub.src.models.smollm3.smollm3_utils import eager_attention_forward
+from keras_hub.src.models.smollm3.smollm3_utils import rope_init
+

 class SmolLM3Attention(layers.Layer):
     """
@@ -372,7 +373,6 @@ def __init__( |

         self.attention_type = layer_types[layer_idx]

-
     def _compute_self_attention_mask(
         self,
         decoder_sequence,
@@ -460,7 +460,9 @@ def call( |
             training: Whether the layer is in training mode.
         """
         self_attention_cache = kwargs.get("self_attention_cache", None)
-        self_attention_cache_update_index = kwargs.get("self_attention_cache_update_index", None)
+        self_attention_cache_update_index = kwargs.get(
+            "self_attention_cache_update_index", None
+        )

         self_attention_mask = self._compute_self_attention_mask(
             decoder_sequence=hidden_states,
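The changes above are mostly mechanical (import sorting and wrapping a line past the length limit), but the two `transformer_layer_utils` helpers being reordered are the ones `_compute_self_attention_mask` builds the decoder mask from. As a rough sketch of how `compute_causal_mask` and `merge_padding_and_attention_mask` are typically combined, here is a standalone version modeled on the pattern used by keras_hub's generic `TransformerDecoder`; the function name, argument names, and the assumption that the cached key/value length sits at axis 2 of the cache tensor are illustrative, not taken from the SmolLM3 code in this diff.

from keras import ops

from keras_hub.src.layers.modeling.transformer_layer_utils import (
    compute_causal_mask,
)
from keras_hub.src.layers.modeling.transformer_layer_utils import (
    merge_padding_and_attention_mask,
)


def compute_self_attention_mask(
    decoder_sequence,
    decoder_padding_mask=None,
    decoder_attention_mask=None,
    self_attention_cache=None,
    self_attention_cache_update_index=None,
):
    # Merge the user-supplied padding and attention masks into a single
    # mask; this may return None when no mask is available.
    decoder_mask = merge_padding_and_attention_mask(
        decoder_sequence, decoder_padding_mask, decoder_attention_mask
    )
    batch_size = ops.shape(decoder_sequence)[0]
    input_length = output_length = ops.shape(decoder_sequence)[1]
    # With a KV cache, keys/values span the cached length (assumed to be
    # axis 2 of the cache), while the query length stays at the current
    # (often length-1) decode step.
    if self_attention_cache is not None:
        input_length = ops.shape(self_attention_cache)[2]
    cache_index = (
        0
        if self_attention_cache_update_index is None
        else self_attention_cache_update_index
    )
    causal_mask = compute_causal_mask(
        batch_size, input_length, output_length, cache_index
    )
    # Intersect the causal mask with the padding/attention mask if present.
    return (
        ops.minimum(decoder_mask, causal_mask)
        if decoder_mask is not None
        else causal_mask
    )

This is also why the third hunk pulls `self_attention_cache_update_index` out of `kwargs`: during cached decoding the causal mask has to be offset by the current cache position so a single-token query still attends to all previously cached keys.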