DeepSeek_v3 support#1735
Conversation
|
@srajabos FYI there is an open PR to add Deepseek V3 to Transformers: huggingface/transformers#35926 We won't be able to rely on the Transformers implementation before Transformers v4.49 is released, but I thought this might be interesting to you. |
|
@regiss, I'll keep this as draft until verified with Transformers v4.49. |
Deepseek V3 (and hence R1) requriements.txt says the minimum version of transformer required is 4.46.3 |
|
@anishagartia, currently we are adding the model files and optimizing for Gaudi. Once we have performant data the plan is to get it in. Thanks for the link. |
17f62bd to
b2b1715
Compare
|
@yao-matrix @gyou2021 @IT-Forrest - kindly review the code. |
ssarkar2
left a comment
There was a problem hiding this comment.
[explanatory] are just comments to help follow the hpu code. no changes required for those comments. sorry for spamming comments in this category, thought it might be useful for future readers going thru the change and for others looking ot port similar models
[clarifications] some question from my end. Sometimes these are marked with [minor] if they are minor nitpicks
| from habana_frameworks.torch.hpex.kernels import FusedSDPA | ||
| except ImportError: | ||
| print("Not using HPU fused scaled dot-product attention kernel.") | ||
| FusedSDPA = None |
There was a problem hiding this comment.
[explanatory] Import hpu fused ops
|
|
||
| def forward(self, hidden_states): | ||
| if hidden_states.device.type == "hpu" and FusedRMSNorm: | ||
| # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype |
There was a problem hiding this comment.
[explanatory] use fused ops
| self.register_buffer("inv_freq", inv_freq, persistent=False) | ||
|
|
||
| # Build here to make `torch.jit.trace` work. | ||
| self.max_seq_len_cached = max_position_embeddings |
There was a problem hiding this comment.
[explanatory] make it static (max_position_embeddings ) instead of updating depending on longest eq_len seen till now: "seq_len > self.max_seq_len_cached"
|
|
||
| def apply_customized_rope(q, k, cos, sin, position_ids): | ||
| if q.device.type == "hpu" and FusedRoPE: | ||
| return FusedRoPE.apply( |
There was a problem hiding this comment.
[explanatory] fused hpu op
There was a problem hiding this comment.
[clarification][minor] Could we call apply_customized_rope here?
| def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): | ||
| return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() | ||
|
|
||
| def split_kv_b_proj(self): |
There was a problem hiding this comment.
[clarification] this is present only in deepseek attention (v2/v3). Can we add some comment about this?
| self.q_absorb = kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1) | ||
| self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0) | ||
|
|
||
| def compress_kv( |
There was a problem hiding this comment.
[clarification] this is present only in deepseek attention (v2/v3). Can we add some comment about this? In the original deepseek code this is not a function, any particular reason of functionify-ing this? just want to clarify if making this a function is a stylistic choice or there is some reason
| key_states, value_states, self.layer_idx, cache_kwargs | ||
| ) | ||
| # optimization | ||
| if use_flash_attention and FusedSDPA is not None: |
There was a problem hiding this comment.
[explanatory] hpu specific, similar to other modelling files in OH
|
|
||
| past_key_values_length = 0 | ||
| if past_key_values is not None: | ||
| past_key_values_length = past_key_values[0][0].shape[2] |
There was a problem hiding this comment.
[explanatory] hpu kv cache management, similar to other OH models
| # Maximum number of experts supported by dynamic MoE op (mixture_of_experts) | ||
| SLICE_MAX_EXPERT = 80 | ||
|
|
||
| # import hpu fused ops |
| # Build here to make `torch.jit.trace` work. | ||
|
|
||
| # make it static (max_position_embeddings) instead of updating depending on | ||
| # longest eq_len seen till now: seq_len > self.max_seq_len_cached |
There was a problem hiding this comment.
[minor] eq_len -> seq_len
e22cd28 to
1ded31d
Compare
Copied from transformers v4.48.2 for DeepSeek-R1 support. Delete after upgrade transformers v4.45.2 to v4.48
1ded31d to
e02c6de
Compare
regisss
left a comment
There was a problem hiding this comment.
Nice! I left a couple of comments
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
What does this PR do?
DeepSeek v3 support on OH
Before submitting