
StepFun 3.5 MTP #23274

Merged: pwilkin merged 8 commits into ggml-org:master from pwilkin:step35mtp on Jun 2, 2026

@pwilkin

@pwilkin pwilkin commented May 18, 2026

Member

Overview

MTP implementation for StepFun 3.5.

Additional information

Required a few changes to the core logic because StepFun uses a slightly different MTP architecture - it has 3 MTP layers which are used in a round-robin manner for tokens n+1, n+2 and n+3 respectively.
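
As a rough illustration of the round-robin scheme described above, the sketch below maps a draft step to an MTP layer index. The function name, the 0-based indexing, and the wrap-around past the third step are assumptions made for illustration; none of this is code from the PR.

```python
N_MTP_LAYERS = 3  # StepFun 3.5 has 3 MTP layers (per the description above)


def mtp_layer_for_step(step: int, n_layers: int = N_MTP_LAYERS) -> int:
    """Return the 0-based MTP layer used for draft step `step`.

    step 0 drafts token n+1, step 1 drafts n+2, step 2 drafts n+3.
    Wrapping around for step >= n_layers is an assumption of this sketch.
    """
    if step < 0:
        raise ValueError("step must be non-negative")
    return step % n_layers


for step in range(3):
    print(f"token n+{step + 1} -> MTP layer {mtp_layer_for_step(step)}")
```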

I'm running a suboptimal setup for testing this, but FWIW testing this on a --cpu-moe StepFun3.5 increased token generation from 15 to 18 t/s.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: Yes, asked CC to mix this per vLLM implementation and Qwen3.5 reference, tested on live Step3.5 model.

@pwilkin pwilkin requested review from a team, CISC and ggerganov as code owners May 18, 2026 13:21
@github-actions github-actions Bot added the model, script, and python labels May 18, 2026
@pwilkin pwilkin marked this pull request as draft May 18, 2026 21:47
@pwilkin

pwilkin commented May 18, 2026

Member Author

Converting to draft while I try something out.

@pwilkin pwilkin marked this pull request as ready for review May 19, 2026 09:14
@pwilkin

pwilkin commented May 19, 2026

Member Author

@ggerganov just FYI because it's related to the cleanup - I modified the TODO-annotated line with top-k - for multi-layer MTP it's a real blocker, the acceptance rate goes up from 0.6 to 0.9 on n = 3 when switching from top-k = 1 to top-k = 10.
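
For intuition on why that jump matters, here is a back-of-the-envelope model. It assumes (hypothetically) that every drafted token is accepted independently with the same probability `a` and that drafting stops at the first rejection. Real acceptance is neither independent nor position-invariant, and the rate reported in this thread is accepted/drafted rather than a per-position probability, so this is a sketch for orientation only.

```python
def expected_tokens_per_step(a: float, n: int) -> float:
    """Expected tokens produced per target verification pass.

    The target always contributes one token; the i-th drafted token
    survives only if all i drafts up to it were accepted (probability a**i).
    Assumes i.i.d. acceptance, which is a simplification.
    """
    if not 0.0 <= a <= 1.0:
        raise ValueError("a must be in [0, 1]")
    return 1.0 + sum(a ** i for i in range(1, n + 1))


for a in (0.6, 0.9):
    print(f"a={a}: {expected_tokens_per_step(a, 3):.3f} tokens per pass")
```

Under this simplified model, n = 3 yields about 2.18 tokens per target pass at a = 0.6 and about 3.44 at a = 0.9.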

@forforever73

Contributor

Hey, thanks a lot for adding Step 3.5 support, great work. One thing I wanted to flag: process() only calls llama_decode once, with mtp_step=0, so only mtp1's KV cache gets the prefill/verify context; mtp2 and mtp3 never see it and end up attending over empty KV during drafting. The correct flow should be:
[image]

@pwilkin pwilkin marked this pull request as draft May 19, 2026 15:31
@pwilkin

pwilkin commented May 19, 2026

Member Author

Sigh. Back to draft, let's see if I can do anything about that.

@forforever73

forforever73 commented May 20, 2026

Contributor

@pwilkin The fill in llm_graph_input_mtp_chain_tokens::set_input writes contaminated KV entries, and the pad share is especially bad on the all-accept verify path. Nothing cleans them up afterwards, so they pile up across rounds. Also, block 0's last-position input should be the token the target just sampled (T_new), not a reuse of ubatch->token[n_tokens-1].
[image]

I think this needs two changes together. process() currently runs before the target sampler, so T_new isn't visible to the chain graph. If we move the sampler call before process(), we can thread T_new into llm_graph_input_mtp_chain_tokens. I haven't traced why the ordering is the way it is, so tell me if there's a constraint I'm missing.

And even with T_new threaded in, block k+1's tail still has k slots that no real token can fill. Those have to come from the previous block's chain sample, same as what the DRAFT branch already does. So for block k+1 in PREFILL: chain-sampled tokens from block k at the tail (last shift = k slots), ubatch-shifted tokens for the rest.

If the analysis holds up, I can put a PR on your step35mtp branch. Or take it yourself if you'd rather — no preference on my end, just let me know.
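
One possible reading of the PREFILL input layout proposed above, written as a small sketch. The function name, the argument names, and the convention that the first MTP block has shift 1 are all assumptions made here for illustration; this is not the llama.cpp implementation and the actual indexing in the graph code may differ.

```python
from typing import List, Optional


def mtp_block_inputs(
    tokens: List[int],
    t_new: int,
    shift: int,
    prev_samples: Optional[List[int]] = None,
) -> List[int]:
    """Build the input tokens for one MTP block over a batch of n positions.

    tokens       -- the n real tokens in the batch (positions 0..n-1)
    t_new        -- the token the target just sampled (logical position n)
    shift        -- 1 for the first MTP block, 2 for the second, ...
    prev_samples -- per-position samples from the previous block; only the
                    last shift-1 entries are read (the tail no real token
                    can fill)
    """
    n = len(tokens)
    real = list(tokens) + [t_new]  # real tokens at positions 0..n
    inputs = []
    for i in range(n):
        j = i + shift
        if j <= n:
            inputs.append(real[j])  # ubatch-shifted real token
        else:
            if prev_samples is None:
                raise ValueError("tail slots need the previous block's samples")
            inputs.append(prev_samples[i])  # chain-sampled token
    return inputs


batch = [10, 11, 12, 13]
print(mtp_block_inputs(batch, 14, shift=1))
print(mtp_block_inputs(batch, 14, shift=2, prev_samples=[0, 0, 0, 15]))
```

With shift 1 every slot is a real token once T_new is threaded in; with shift 2 the last slot has to come from the previous block's sample, matching the "k slots at the tail" observation.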

@pwilkin pwilkin force-pushed the step35mtp branch 2 times, most recently from 3043a4b to c0fad87 Compare May 20, 2026 13:01
@pwilkin

pwilkin commented May 20, 2026

Member Author

@forforever73 yeah, I see the problem. I've revised the goals for this PR - since doing proper multi-step MTP will require more significant changes, I'll stick to the simple version, which is to only use the first layer similarly to the Qwen3.5 MTP code and save the proper architecture for a future PR (so you're free to propose one).

@pwilkin pwilkin marked this pull request as ready for review May 20, 2026 14:04
@pwilkin

pwilkin commented May 20, 2026

Member Author

All right, here are the benchmark stats for the single-layer version (only small changes to the iSWA cache code left from core stuff):

Non-MTP (llama-server -m stepfun/IQ4_XS/Step-3.5-Flash-IQ4_XS-00001-of-00004.gguf -c 140000 -ctk q8_0 -ctv q8_0 --cpu-moe):

  code_python        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.2
  code_cpp           pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.4
  explain_concept    pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.4
  summarize          pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.5
  qa_factual         pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.3
  translation        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.6
  creative_short     pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.1
  stepwise_math      pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.1
  long_code_review   pred= 192 draft=   0 acc=   0 rate=n/a tok/s=14.2

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1728,
  "total_draft": 0,
  "total_draft_accepted": 0,
  "aggregate_accept_rate": null,
  "wall_s_total": 144.35
}

MTP (llama-server -m stepfun/IQ4_XS/Step-3.5-Flash-IQ4_XS-00001-of-00004.gguf -c 140000 -ctk q8_0 -ctv q8_0 --spec-type draft-mtp --spec-draft-n-max 3 -md stepfun-mtp.gguf --cpu-moe --spec-draft-p-min 0.65)

  code_python        pred= 192 draft= 123 acc= 105 rate=0.854 tok/s=18.2
  code_cpp           pred= 192 draft= 108 acc=  92 rate=0.852 tok/s=17.1
  explain_concept    pred= 192 draft= 102 acc=  84 rate=0.824 tok/s=16.1
  summarize          pred= 186 draft= 117 acc=  97 rate=0.829 tok/s=16.7
  qa_factual         pred= 192 draft= 131 acc= 111 rate=0.847 tok/s=18.3
  translation        pred= 192 draft= 116 acc=  93 rate=0.802 tok/s=16.9
  creative_short     pred= 192 draft=  95 acc=  85 rate=0.895 tok/s=16.6
  stepwise_math      pred= 192 draft= 121 acc= 112 rate=0.926 tok/s=19.0
  long_code_review   pred= 192 draft= 116 acc= 100 rate=0.862 tok/s=17.0

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1722,
  "total_draft": 1029,
  "total_draft_accepted": 879,
  "aggregate_accept_rate": 0.8542,
  "wall_s_total": 123.49
}

I'd argue it's good enough.
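
The script that produced these tables is not shown in the thread, so the following is only a minimal sketch of how the aggregate block can be recomputed from the per-prompt lines. The regular expression and function name are assumptions based on the line format above.

```python
import re

# Matches lines such as:
#   code_python        pred= 192 draft= 123 acc= 105 rate=0.854 tok/s=18.2
LINE_RE = re.compile(
    r"^\s*(?P<name>\S+)\s+pred=\s*(?P<pred>\d+)\s+draft=\s*(?P<draft>\d+)"
    r"\s+acc=\s*(?P<acc>\d+)\s+rate=(?P<rate>\S+)\s+tok/s=(?P<tps>[\d.]+)"
)


def aggregate(lines):
    """Sum pred/draft/acc over all matching lines and derive the accept rate."""
    n = pred = draft = acc = 0
    for line in lines:
        m = LINE_RE.match(line)
        if m is None:
            continue  # skip anything that is not a result line
        n += 1
        pred += int(m["pred"])
        draft += int(m["draft"])
        acc += int(m["acc"])
    return {
        "n_requests": n,
        "total_predicted": pred,
        "total_draft": draft,
        "total_draft_accepted": acc,
        # No drafts (non-MTP run) means the rate is undefined.
        "aggregate_accept_rate": round(acc / draft, 4) if draft else None,
    }


sample = [
    "  code_python        pred= 192 draft= 123 acc= 105 rate=0.854 tok/s=18.2",
    "  code_cpp           pred= 192 draft= 108 acc=  92 rate=0.852 tok/s=17.1",
]
print(aggregate(sample))
```

The aggregate accept rate is total accepted over total drafted, not the mean of the per-prompt rates, which is why prompts with more drafted tokens weigh more.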

@toastytorque


4xMI50 32GB, 2xMI50 16GB

apohelios/step3p5_flash_Q4_1-00001-of-00005.gguf
draft size 1 (higher crushes performance)

Non-MTP

  code_python        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.4
  code_cpp           pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.4
  explain_concept    pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.2
  summarize          pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.2
  qa_factual         pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.2
  translation        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.2
  creative_short     pred= 192 draft=   0 acc=   0 rate=n/a tok/s=40.6
  stepwise_math      pred= 192 draft=   0 acc=   0 rate=n/a tok/s=41.0
  long_code_review   pred= 192 draft=   0 acc=   0 rate=n/a tok/s=39.4

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1728,
  "total_draft": 0,
  "total_draft_accepted": 0,
  "aggregate_accept_rate": null,
  "wall_s_total": 47.19
}

MTP

  code_python        pred= 192 draft=  89 acc=  84 rate=0.944 tok/s=52.3
  code_cpp           pred= 192 draft=  83 acc=  77 rate=0.928 tok/s=50.3
  explain_concept    pred= 192 draft=  80 acc=  69 rate=0.863 tok/s=47.8
  summarize          pred= 192 draft=  88 acc=  78 rate=0.886 tok/s=50.8
  qa_factual         pred= 192 draft=  86 acc=  83 rate=0.965 tok/s=53.1
  translation        pred= 182 draft=  87 acc=  81 rate=0.931 tok/s=53.8
  creative_short     pred= 192 draft=  76 acc=  63 rate=0.829 tok/s=46.1
  stepwise_math      pred= 192 draft=  86 acc=  79 rate=0.919 tok/s=51.3
  long_code_review   pred= 192 draft=  79 acc=  74 rate=0.937 tok/s=47.7

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1718,
  "total_draft": 754,
  "total_draft_accepted": 688,
  "aggregate_accept_rate": 0.9125,
  "wall_s_total": 39.54
}

MTP from #20981

  code_python        pred= 192 draft= 104 acc=  87 rate=0.837 tok/s=54.2
  code_cpp           pred= 192 draft= 107 acc=  83 rate=0.776 tok/s=52.4
  explain_concept    pred= 192 draft= 115 acc=  76 rate=0.661 tok/s=49.3
  summarize          pred= 192 draft= 106 acc=  84 rate=0.792 tok/s=52.9
  qa_factual         pred= 192 draft= 100 acc=  90 rate=0.900 tok/s=56.3
  translation        pred= 192 draft= 105 acc=  85 rate=0.809 tok/s=54.2
  creative_short     pred= 183 draft= 111 acc=  71 rate=0.640 tok/s=49.4
  stepwise_math      pred= 192 draft= 104 acc=  87 rate=0.837 tok/s=55.3
  long_code_review   pred= 192 draft= 102 acc=  88 rate=0.863 tok/s=51.9

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1719,
  "total_draft": 954,
  "total_draft_accepted": 751,
  "aggregate_accept_rate": 0.7872,
  "wall_s_total": 39.17
}

@pwilkin Is the stepfun-mtp.gguf from your cmd publicly available?

@pwilkin

pwilkin commented May 20, 2026

Member Author

@toastytorque yeah, convert_hf_to_gguf.py --remote --outfile stepfun-mtp.gguf --outtype q8_0 --mtp stepfun-ai/Step-3.5-Flash :)

@pwilkin

pwilkin commented May 20, 2026

Member Author

Also, try --spec-draft-p-min 0.65 with n=3. Without p-min, the performance with more than 1 draft token was indeed atrocious, but that helped a lot.
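
A minimal sketch of why a p-min threshold can help with longer drafts, assuming the threshold works as a confidence cutoff: drafting stops as soon as the draft head's top-token probability falls below it. The `propose` callback and the toy probabilities are hypothetical; the actual llama.cpp logic may differ in detail.

```python
from typing import Callable, List, Tuple


def draft_with_p_min(
    propose: Callable[[List[int]], Tuple[int, float]],
    n_max: int,
    p_min: float,
) -> List[int]:
    """Draft up to n_max tokens, stopping early on a low-confidence token.

    `propose(prefix)` returns (token, probability) for the next draft token
    given the tokens drafted so far. It stands in for the MTP head here.
    """
    drafted: List[int] = []
    for _ in range(n_max):
        token, prob = propose(drafted)
        if prob < p_min:
            break  # not confident enough: do not spend verify work on it
        drafted.append(token)
    return drafted


def toy_propose(prefix: List[int]) -> Tuple[int, float]:
    # Hypothetical draft head whose confidence decays with draft depth.
    probs = [0.9, 0.7, 0.5]
    i = len(prefix)
    return 100 + i, probs[i]


print(draft_with_p_min(toy_propose, n_max=3, p_min=0.65))
print(draft_with_p_min(toy_propose, n_max=3, p_min=0.0))
```

With the cutoff, low-confidence deeper drafts are never sent for verification, which trades a few potentially accepted tokens for fewer wasted ones.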

@toastytorque


Tests above were indeed with --spec-draft-p-min 0.65.

With n=3 the performance is still lacking, might be a MI50 thing:

  code_python        pred= 192 draft= 128 acc= 102 rate=0.797 tok/s=43.1
  code_cpp           pred= 192 draft= 110 acc=  92 rate=0.836 tok/s=43.8
  explain_concept    pred= 192 draft=  89 acc=  73 rate=0.820 tok/s=42.6
  summarize          pred= 187 draft= 103 acc=  88 rate=0.854 tok/s=43.5
  qa_factual         pred= 192 draft= 127 acc= 104 rate=0.819 tok/s=43.4
  translation        pred= 192 draft= 117 acc=  99 rate=0.846 tok/s=44.3
  creative_short     pred= 155 draft=  77 acc=  54 rate=0.701 tok/s=39.2
  stepwise_math      pred= 192 draft= 127 acc= 105 rate=0.827 tok/s=43.9
  long_code_review   pred= 192 draft= 108 acc=  87 rate=0.806 tok/s=40.8

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1686,
  "total_draft": 986,
  "total_draft_accepted": 804,
  "aggregate_accept_rate": 0.8154,
  "wall_s_total": 44.65
}

It seems I get best results with p-min default (0.0) and n=1:

  code_python        pred= 192 draft= 100 acc=  90 rate=0.900 tok/s=56.0
  code_cpp           pred= 192 draft= 103 acc=  87 rate=0.845 tok/s=54.4
  explain_concept    pred= 192 draft= 111 acc=  80 rate=0.721 tok/s=50.6
  summarize          pred= 192 draft= 106 acc=  84 rate=0.792 tok/s=52.6
  qa_factual         pred= 192 draft= 102 acc=  88 rate=0.863 tok/s=54.7
  translation        pred= 187 draft= 102 acc=  85 rate=0.833 tok/s=53.6
  creative_short     pred= 157 draft=  95 acc=  61 rate=0.642 tok/s=48.3
  stepwise_math      pred= 192 draft= 104 acc=  87 rate=0.837 tok/s=53.7
  long_code_review   pred= 192 draft= 104 acc=  86 rate=0.827 tok/s=51.1

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1688,
  "total_draft": 927,
  "total_draft_accepted": 748,
  "aggregate_accept_rate": 0.8069,
  "wall_s_total": 37.24
}

@pwilkin

pwilkin commented May 21, 2026

Member Author

Yeah, probably depends on the quant as well.

@slavap

slavap commented May 29, 2026


@pwilkin
3.7 was released today; it would be great to get MTP working for it. For me it works significantly better than 3.5.

@pwilkin

pwilkin commented May 29, 2026

Member Author

@am17an any chance we could fast track it? This is the simple version and with fixes on main, no core changes are needed. @forforever73 already said he'd help with the proper full 3-layer support in a followup.

@am17an

am17an commented May 29, 2026

Contributor

Are these extra scripts required? Seems a bit hacky to have them in master

@pwilkin

pwilkin commented May 29, 2026

Member Author

Well, they're useful, but I'm willing to throw them out to some helper repo if you think that's better.

@pwilkin

pwilkin commented Jun 1, 2026

Member Author

@CISC should be GTG?

@ggerganov

Member

AFAIU this (might) only work with draft size of 1 and we don't have meaningful performance numbers yet to make a conclusion. Probably better to see if the fully functional MTP that uses all MTP layers can be implemented instead of merging this partial work?

@pwilkin

pwilkin commented Jun 1, 2026

Member Author

@ggerganov all the performance numbers in this thread (from me and from @toastytorque) show a consistent speedup of about 20-25%, I think that's strong enough to warrant inclusion on its own (and this is a very simple patch, the complex solution is bound to be much more complicated and require more turnaround due to the changes to core).

Comment threads on conversion/step3.py (4, outdated)
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
@CISC

CISC commented Jun 1, 2026

Member

Yay, \r\n... :)

@pwilkin

pwilkin commented Jun 1, 2026

Member Author

> Yay, \r\n... :)

Fortunately every Linux distribution still has this wonderful tool named "dos2unix" :)

@toastytorque


For me the speedup is quite significant and consistent in its current form (see numbers above), so of course I'd like to see it merged. Just my 2ct though.

@pwilkin

pwilkin commented Jun 2, 2026

Member Author

@CISC bump?

@CISC

CISC commented Jun 2, 2026

Member

> @CISC bump?

Up to @ggerganov

@pwilkin pwilkin merged commit 2187e00 into ggml-org:master Jun 2, 2026
27 of 28 checks passed
@coder543

coder543 commented Jun 2, 2026


Do we know if this PR will work for step-3.7-flash as well? I tried to test it, but none of the step-3.7-flash GGUFs that I have came with any MTP layers.

@tarruda

tarruda commented Jun 2, 2026


@coder543 I don't really know what I'm doing, but I gave this a shot:

  • Extracted MTP model from 3.7 and it appeared to have succeeded
  • Tried to load on llama-server using the CLI args suggested here for 3.5 and it failed:
0.01.964.027 E llama_model_load: error loading model: missing tensor 'blk.0.attn_norm.weight'
0.01.964.424 E llama_model_load_from_file_impl: failed to load model
0.01.965.230 E srv    load_model: [spec] failed to measure draft model memory: failed to load model
0.01.965.236 I common_init_result: fitting params to device memory ...
0.01.965.236 I common_init_result: (for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on)
0.30.986.638 W llama_context: n_ctx_seq (163840) < n_ctx_train (262144) -- the full capacity of the model will not be utilized
0.31.596.105 I srv    load_model: loading draft model '/Users/thiago/step-3.7-flash/Step-3.7-Flash-MTP.gguf'
0.31.717.002 E llama_model_load: error loading model: missing tensor 'blk.0.attn_norm.weight'
0.31.717.010 E llama_model_load_from_file_impl: failed to load model
0.31.717.012 E srv    load_model: failed to load draft model, '/Users/thiago/step-3.7-flash/Step-3.7-Flash-MTP.gguf'
0.31.717.015 I srv    operator(): operator(): cleaning up before exit...
0.31.717.396 E srv  llama_server: exiting due to model loading error

@AesSedai

AesSedai commented Jun 2, 2026

Contributor

@coder543 It should work for 3.7-Flash. I'm working on converting and updating my quants with it.

arichiardi pushed a commit to arichiardi/llama.cpp that referenced this pull request Jun 2, 2026
* StepFun 3.5 MTP

* Simplify to single layer

* Rollback core changes

* fix flake8 errors

* Remove scripts

* modify to convention

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* dos2unix

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
@slavap

slavap commented Jun 3, 2026


The MTP draft model can only be taken from https://huggingface.co/notSnix/Step-3.7-Flash-Q4_K_M-MTP-GGUF
It works with the official StepFun Q4_K_S and IQ4_XS quants.
The best performance in my tests (on StrixHalo) was with: --spec-draft-n-max 2 --spec-draft-p-min 0.0
Unfortunately, -ctkd q8_0 -ctvd q8_0 do not work for the draft model :-( Bug: #24040

@forforever73

forforever73 commented Jun 3, 2026

Contributor

I uploaded an MTP model converted with the latest code here: https://huggingface.co/stepfun-ai/Step-3.7-Flash-GGUF/tree/main It works with --spec-draft-model.

But I didn't observe any speedup on an Apple M4 Max.

./llama-server \
    -m ~/Downloads/Step-3.7-IQ4_XS.gguf \
    --spec-type draft-mtp \
    --spec-draft-model ~/Downloads/Step3.7-flash-mtp-Q8_0.gguf \
    -ngl all \
    --spec-draft-ngl all \
    --spec-draft-device MTL0 \
    -c 35000 \
    -np 1 \
    -b 2048 \
    -ub 1024 \
    --temp 0 \
    --spec-draft-n-max 1 \
    --spec-draft-p-min 0.6 \
    --host 127.0.0.1 \
    --port 8080

no mtp:

  code_python        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=50.3
  code_cpp           pred= 192 draft=   0 acc=   0 rate=n/a tok/s=49.9
  explain_concept    pred= 192 draft=   0 acc=   0 rate=n/a tok/s=50.3
  summarize          pred= 192 draft=   0 acc=   0 rate=n/a tok/s=50.1
  qa_factual         pred= 161 draft=   0 acc=   0 rate=n/a tok/s=51.2
  translation        pred= 192 draft=   0 acc=   0 rate=n/a tok/s=50.9
  creative_short     pred= 192 draft=   0 acc=   0 rate=n/a tok/s=51.0
  stepwise_math      pred= 192 draft=   0 acc=   0 rate=n/a tok/s=50.7
  long_code_review   pred= 192 draft=   0 acc=   0 rate=n/a tok/s=48.9

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1697,
  "total_draft": 0,
  "total_draft_accepted": 0,
  "aggregate_accept_rate": null,
  "wall_s_total": 38.69
}

mtp:

  code_python        pred= 192 draft=  91 acc=  84 rate=0.923 tok/s=49.4
  code_cpp           pred= 192 draft=  85 acc=  73 rate=0.859 tok/s=46.8
  explain_concept    pred= 192 draft=  92 acc=  81 rate=0.880 tok/s=48.1
  summarize          pred= 192 draft=  88 acc=  82 rate=0.932 tok/s=49.0
  qa_factual         pred= 161 draft=  78 acc=  76 rate=0.974 tok/s=50.9
  translation        pred= 192 draft=  86 acc=  81 rate=0.942 tok/s=49.2
  creative_short     pred= 192 draft=  77 acc=  67 rate=0.870 tok/s=46.5
  stepwise_math      pred= 192 draft=  88 acc=  86 rate=0.977 tok/s=50.5
  long_code_review   pred= 192 draft=  90 acc=  87 rate=0.967 tok/s=49.5

Aggregate: {
  "n_requests": 9,
  "total_predicted": 1697,
  "total_draft": 775,
  "total_draft_accepted": 717,
  "aggregate_accept_rate": 0.9252,
  "wall_s_total": 40.02
}

Are there any parameter recommendations?

@slavap

slavap commented Jun 3, 2026


@forforever73
For me "--spec-draft-n-max 2 --spec-draft-p-min 0.0" works well; it adds ~20% to TG on StrixHalo.

@forforever73

Contributor

@slavap Interesting. On my device, those parameters make performance worse (35.93 t/s). Looks like I'll need to spend some time investigating it further.

@pwilkin

pwilkin commented Jun 3, 2026

Member Author

@forforever73 you might try different combinations of draft-n-max and draft-p-min. Some people report success with n_max = 1; I've had success with n_max = 3 and p-min = 0.6. Or you can whip up that implementation with the full 3-layer MTP model and we can work on that one ;)
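
A small sketch for enumerating such combinations as llama-server command lines. The flags are the ones used earlier in this thread; the helper name, the file paths, and the chosen value grids are placeholders, and the commands are only printed, not executed.

```python
import itertools
import shlex


def sweep_commands(model, draft_model, n_max_values, p_min_values):
    """Return one llama-server command line per (n_max, p_min) combination."""
    cmds = []
    for n_max, p_min in itertools.product(n_max_values, p_min_values):
        args = [
            "llama-server",
            "-m", model,
            "--spec-type", "draft-mtp",
            "--spec-draft-model", draft_model,
            "--spec-draft-n-max", str(n_max),
            "--spec-draft-p-min", str(p_min),
        ]
        cmds.append(shlex.join(args))  # quote paths safely for a shell
    return cmds


# Placeholder paths and grids, chosen only for illustration.
for cmd in sweep_commands("model.gguf", "mtp.gguf", [1, 2, 3], [0.0, 0.6]):
    print(cmd)
```

Each printed command can then be benchmarked with the same prompt set so that tok/s and accept rate are comparable across settings.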

@forforever73

Contributor

@pwilkin I tried several parameter combinations, but all seem to make things slower. A bit strange.
Anyway, I'll be working on mtp-3 support in the near future and may open a draft PR once I have something workable. Happy to discuss ideas if you have any!

@coder543

coder543 commented Jun 3, 2026


I tried several Step-3.7-Flash MTP models, and most of them used an immense amount of memory (compared to not using MTP), so I couldn't really do anything with them.

@AesSedai's IQ4 worked well enough for me to fit more than 128k tokens of context with MTP enabled on a DGX Spark.

Some performance results on the DGX Spark using draft-p-min of 0.6:

| Mode | No MTP | MTP 1 | MTP 2 | MTP 3 |
|---|---|---|---|---|
| A tok/s | 23.68 | 31.42 | 32.37 | 27.87 |
| A uplift | - | +32.7% | +36.7% | +17.7% |
| B tok/s | 23.74 | 31.97 | 36.56 | 32.46 |
| B uplift | - | +34.6% | +54.0% | +36.7% |
| Avg uplift | - | +33.7% | +45.3% | +27.2% |
| A acceptance | - | 84.9% | 72.0% | 71.4% |
| B acceptance | - | 79.0% | 79.5% | 73.7% |
| Avg acceptance | - | 82.0% | 75.8% | 72.6% |

A prompt is "What is the LHC?"

B prompt is "Write a TypeScript React example."
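
The uplift rows in the results above are consistent with a plain relative change in tok/s against the no-MTP baseline. A minimal sketch of that arithmetic follows; the function name is an assumption, and the inputs are the tok/s figures quoted above.

```python
def uplift_pct(baseline_tps: float, mtp_tps: float) -> float:
    """Relative throughput gain over the baseline, in percent (1 decimal)."""
    if baseline_tps <= 0:
        raise ValueError("baseline_tps must be positive")
    return round((mtp_tps / baseline_tps - 1.0) * 100.0, 1)


# Prompt A and prompt B with draft size 1 and 2, from the figures above.
print(uplift_pct(23.68, 31.42))
print(uplift_pct(23.74, 36.56))
```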

@slavap

slavap commented Jun 4, 2026


@coder543 what about -ctkd and -ctvd being broken? On StrixHalo it is a major pain because VRAM is very limited; shouldn't it be the same on Spark? #24040

@coder543

coder543 commented Jun 4, 2026


This test was using f16 KV cache


Labels

model, python, script


10 participants