enable GPT-OSS #2214
fused RoPE not enabled yet
The code quality check failed, please run
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
For prefill, keep the original as much as possible.
For decode (+use_kv_cache), find the token idx from attention_mask and mask tokens before (token_idx - sliding_window) as -inf.
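The decode-step rule above can be illustrated with a minimal plain-Python sketch. It is not the PR's implementation: `decode_window_mask`, `kv_len`, and the exact boundary convention (positions strictly before `token_idx - sliding_window` are masked) are assumptions made for illustration, and plain lists stand in for tensors so the logic is easy to trace.

```python
import math


def decode_window_mask(kv_len, token_idx, sliding_window):
    """Additive mask for one decode step: 0.0 = visible, -inf = masked.

    Hypothetical helper; the boundary convention is an assumption.
    """
    # Every cached position strictly before this cutoff falls outside the window.
    cutoff = token_idx - sliding_window
    return [-math.inf if pos < cutoff else 0.0 for pos in range(kv_len)]


# With 8 cached positions, the current token at index 6 and a window of 4,
# the cutoff is 2, so only positions 0 and 1 are masked.
mask = decode_window_mask(kv_len=8, token_idx=6, sliding_window=4)
print(mask)
```

Positions after `token_idx` (for example, padding in a fixed-size KV cache) are not handled here; in this sketch they are assumed to be covered by the regular attention mask.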
Force-pushed from b4d368e to d611aef.
Force-pushed from d611aef to 3e88b29.
Force-pushed from 3e88b29 to 97725a2.
# When sliding_window is not None, find the token_idx by checking the last idx of 1 in attention_mask_2d
if input_shape[-1] == 1:
    cumsum = attention_mask_2d.cumsum(dim=1)
    token_idx = cumsum.argmax(dim=1, keepdim=True)[0]
Suggested change:
- token_idx = cumsum.argmax(dim=1, keepdim=True)[0]
+ token_idx = cumsum.argmax(dim=1, keepdim=True)[0].item()
Extract the token index as an integer from the cumulative attention mask for later use in _make_causal_mask.
This change causes a significant perf drop, so instead I updated the type hint from int to torch.Tensor:
token_idx: Optional[torch.Tensor] = None,
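To see why taking the argmax of the cumulative sum lands on the last 1 in the attention mask, here is a small plain-Python sketch of the same idea. It is an illustration under assumptions, not the PR's code: `last_one_index` is a hypothetical helper, and `list.index(max(...))` stands in for `argmax`, which, like `torch.argmax`, returns the first position of the maximum.

```python
from itertools import accumulate


def last_one_index(mask_row):
    """Index of the last 1 in a 0/1 attention-mask row (hypothetical helper)."""
    # The running sum never decreases and stops growing after the last 1,
    # so the first position where it reaches its maximum is that last 1.
    running = list(accumulate(mask_row))
    return running.index(max(running))


right_padded = [1, 1, 1, 0, 0]  # running sum: [1, 2, 3, 3, 3]
left_padded = [0, 0, 1, 1, 1]   # running sum: [0, 0, 1, 2, 3]
print(last_one_index(right_padded), last_one_index(left_padded))
```

The sketch returns a Python integer. In the thread above, the tensor result is kept as a tensor rather than converted with `.item()`, since that conversion was reported to cause a significant perf drop.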
Force-pushed from cb82e5c to bd4a77d.
@regisss I added a basic functional test. Please let me know if you want to add more test cases!
Force-pushed from e799e23 to bd4a77d.
regisss left a comment:
Let's add GPT-OSS to the table in the README and in the docs, please.
@schoi-habana What is the status of this PR? Is everything on your side done or do we still need some changes?
@pbielak It's ready. Please review.
No additional comments from my side - please go ahead with the merge @regisss
I just pushed one more commit to fix the test name in
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Co-authored-by: Sun Choi <schoi@habana.ai>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Co-authored-by: Adam Stachowicz <105052242+astachowiczhabana@users.noreply.github.com>
Hi @regisss, does this PR only support the 20B model? I am able to run inference with the 20B model without any issues, but when I tried running the 120B model with gaudi_spawn.py I got the issue below, which looks like an OOM. Command used: Output log:
@jaideepsai-narayan We don't test the 120B checkpoint in CI so I'm not sure it's supposed to work. Maybe @schoi-habana knows more about that?
Yes @regisss we are trying to run on Gaudi3 |
@jaideepsai-narayan I just tried to run it on a Gaudi3 server with SynapseAI v1.22 and got the same error. I think the issue here is that quantized checkpoints (which actually are the original checkpoints) rely on the mxfp4 data type which is not supported by Gaudi3. And there doesn't seem to be other quantized versions available.
Thank you so much @regisss. Do you have any timeline for when MXFP4 support for Gaudi3 (or compatible quantized checkpoints) will be implemented?
Unfortunately I don't have any visibility on Gaudi's roadmap, maybe folks from Intel have more information. But I guess this is a hardware-related constraint so I don't think Gaudi3 will ever be compatible with mxfp4...
Depends on #2209, as gpt-oss was added in Hugging Face Transformers 4.55.0
accuracy comparison to the baseline
What does this PR do?