@@ -99,59 +99,6 @@ def forward(self, hidden_states):
             variance + self.variance_epsilon
         )
         return self.weight * hidden_states.to(input_dtype)
-
-class GaudiLlamaRotaryEmbedding(torch.nn.Module):
-    def __init__(self, config: LlamaConfig, device=None):
-        super().__init__()
-
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        if self.rope_type == "linear":
-            self.scaling_factor = config.rope_scaling["factor"]
-        elif self.rope_type == "dynamic":
-            self.scaling_factor = config.rope_scaling["factor"]
-            self.base = config.rope_theta
-        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.dim = int(head_dim * partial_rotary_factor)
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(seq_len, device=x.device)
-
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        if self.attention_scaling == 1.0:
-            return (
-                self._cos_cached[:seq_len].to(dtype=x.dtype),
-                self._sin_cached[:seq_len].to(dtype=x.dtype),
-            )
-        else:
-            return (
-                self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-                self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-            )
 
 
 class MistralRotaryEmbedding(nn.Module):