Commit ac92f84

delete unused class definition
Signed-off-by: Liu, Kaixuan <[email protected]>
1 parent: 0ccb86e

1 file changed (+0, -53 lines)

backends/python/server/text_embeddings_server/models/flash_mistral.py

Lines changed: 0 additions & 53 deletions
@@ -99,59 +99,6 @@ def forward(self, hidden_states):
             variance + self.variance_epsilon
         )
         return self.weight * hidden_states.to(input_dtype)
-
-class GaudiLlamaRotaryEmbedding(torch.nn.Module):
-    def __init__(self, config: LlamaConfig, device=None):
-        super().__init__()
-
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        if self.rope_type == "linear":
-            self.scaling_factor = config.rope_scaling["factor"]
-        elif self.rope_type == "dynamic":
-            self.scaling_factor = config.rope_scaling["factor"]
-            self.base = config.rope_theta
-        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
-        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.dim = int(head_dim * partial_rotary_factor)
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
-        )
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-
-        if "dynamic" in self.rope_type:
-            self._dynamic_frequency_update(seq_len, device=x.device)
-
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        if self.attention_scaling == 1.0:
-            return (
-                self._cos_cached[:seq_len].to(dtype=x.dtype),
-                self._sin_cached[:seq_len].to(dtype=x.dtype),
-            )
-        else:
-            return (
-                self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-                self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
-            )
 
 
 class MistralRotaryEmbedding(nn.Module):
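
For context on what the removed, unused class computed: rotary-embedding modules of this kind precompute cosine/sine caches from inverse frequencies derived from the head dimension. The sketch below is a minimal, self-contained illustration of that cache construction; it is not code from this repository, and the helper name and defaults are assumptions.

# Minimal sketch (not from this repo): how a rotary-embedding cos/sin cache
# is typically built. The helper name and defaults are illustrative only.
import torch


def build_rope_cache(head_dim: int, max_seq_len: int, base: float = 10000.0):
    # Inverse frequencies for each even channel index in the head dimension.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Rotation angle for every (position, frequency) pair.
    angles = torch.outer(torch.arange(max_seq_len).float(), inv_freq)
    # Duplicate to span the full head dimension, then cache cos/sin tables.
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos(), emb.sin()


# Example: a 128-dim head cached up to 4096 positions.
cos_cached, sin_cached = build_rope_cache(head_dim=128, max_seq_len=4096)
print(cos_cached.shape, sin_cached.shape)  # torch.Size([4096, 128]) for each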

0 commit comments
