Skip to content

Conversation

@ORippler
Copy link
Contributor

Gemma3n uses Matrix-Matrix addition as part of project_per_layer_input, erroneously triggering CUDA_GRAPH disablement on NVGPUs even when a batch-size of 1 is used. This PR fixes this issue, while still detecting batched execution for graphs with > 1 GGML_OP_ADD node.

Perf before:

| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |    pp1000+tg200 |         47.86 ± 1.27 |

Perf after:

| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |    pp1000+tg200 |        133.08 ± 0.23 |

In the long run, I feel we should either fully support batched inference with CUDA Graphs or refactor the way batch sizes are detected (maybe moving ownership elsewhere?), but I'm still too unfamiliar with the code base to mage suggestions here.

Thoughts?

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Jul 17, 2025
@slaren
Copy link
Member

slaren commented Jul 17, 2025

In the long run the solution will be to move the implementation to the graph plan API, then the heuristics to determine if the graph should be captured or not will be removed. I cannot tell if this workaround will break something else.

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jul 17, 2025
ORippler added 2 commits July 18, 2025 02:03
Gemma3n uses Matrix-Matrix addition as part of their input processing,
wrongly triggering CUDA_GRAPH disablement on NVGPUs even when batch-size
of 1 is used.
This ensures that all other graphs which don't exhibit this pattern do
not have their behavior changed.
@ORippler ORippler force-pushed the hotfix_perf_issues_of_gemma3n branch from 8e35380 to a52592a Compare July 18, 2025 10:11
@ORippler
Copy link
Contributor Author

ORippler commented Jul 18, 2025

In the long run the solution will be to move the implementation to the graph plan API, then the heuristics to determine if the graph should be captured or not will be removed.

Happy to contribute here, and open to pointers on where to get started!

I cannot tell if this workaround will break something else.

In the mean time, we would nevertheless like for gemma3n to be executable as CUDA GRAPHs on NVPGPUs. I therefore refactored the workaround to now rely on pattern matching of node names. Specifically, it matches for the names of gemma3n's input projection nodes, excluding this one specific occurrence from batch-size determination. As neither the individual node names nor their combined pattern exist in a model graph besides the one of gemma3n (verified this manually), the workaround will not break existing models. While it may potentially break new models that end up using the exact same node pattern, this will come up during the implementation of each new model and can be fixed then.

I moreover verified that gemma3n's output is matching for both CUDA GRAPH and plain CUDA kernel execution for batch-size 1, i.e. that the input projection operation can be "mini-batched" within a CUDA GRAPH when overall batch size is 1.

Last, I generated some e2e performance numbers to high-light the speedups we can gain for gemma3n and to verify that the check does not degrade performance of existing models.

model size params backend ngl n_batch test t/s code
gemma3n E2B Q8_0 4.45 GiB 4.46 B CUDA 99 1 pp100 44.93 ± 0.63 eacdeb5 (master)
gemma3n E2B Q8_0 4.45 GiB 4.46 B CUDA 99 1 pp100 139.68 ± 0.54 a52592a (this PR)
gemma3n E2B Q8_0 4.45 GiB 4.46 B CUDA 99 1 tg100 43.97 ± 0.31 eacdeb5 (master)
gemma3n E2B Q8_0 4.45 GiB 4.46 B CUDA 99 1 tg100 123.94 ± 1.03 a52592a (this PR)
gemma3n E2B Q8_0 4.45 GiB 4.46 B CUDA 99 2 pp100 92.39 ± 0.80 eacdeb5 (master)
gemma3n E2B Q8_0 4.45 GiB 4.46 B CUDA 99 2 pp100 93.14 ± 5.71 a52592a (this PR)
llama 3B Q4_K - Medium 1.87 GiB 3.21 B CUDA 99 1 pp100 324.02 ± 1.41 eacdeb5 (master)
llama 3B Q4_K - Medium 1.87 GiB 3.21 B CUDA 99 1 pp100 323.77 ± 0.82 a52592a (this PR)
llama 3B Q4_K - Medium 1.87 GiB 3.21 B CUDA 99 1 tg100 288.85 ± 3.13 eacdeb5 (master)
llama 3B Q4_K - Medium 1.87 GiB 3.21 B CUDA 99 1 tg100 289.96 ± 1.75 a52592a (this PR)
llama 3B Q4_K - Medium 1.87 GiB 3.21 B CUDA 99 2 pp100 270.51 ± 15.25 eacdeb5 (master)
llama 3B Q4_K - Medium 1.87 GiB 3.21 B CUDA 99 2 pp100 263.20 ± 18.06 a52592a (this PR)
qwen3 4B Q4_K - Medium 2.44 GiB 4.02 B CUDA 99 1 pp100 239.89 ± 1.97 eacdeb5 (master)
qwen3 4B Q4_K - Medium 2.44 GiB 4.02 B CUDA 99 1 pp100 240.61 ± 0.89 a52592a (this PR)
qwen3 4B Q4_K - Medium 2.44 GiB 4.02 B CUDA 99 1 tg100 214.54 ± 1.11 eacdeb5 (master)
qwen3 4B Q4_K - Medium 2.44 GiB 4.02 B CUDA 99 1 tg100 215.57 ± 1.91 a52592a (this PR)
qwen3 4B Q4_K - Medium 2.44 GiB 4.02 B CUDA 99 2 pp100 184.03 ± 3.30 eacdeb5 (master)
qwen3 4B Q4_K - Medium 2.44 GiB 4.02 B CUDA 99 2 pp100 186.10 ± 3.88 a52592a (this PR)

Let me know should you still see major issues, and I'll give my best to address them.

@slaren
Copy link
Member

slaren commented Jul 18, 2025

Happy to contribute here, and open to pointers on where to get started!

There are some changes that we need to make to ggml-backend and llama.cpp before this can be implemented in the CUDA backend. I will ping you and other NVIDIA developers who also expressed interest when this is done.

@slaren slaren merged commit 021cc28 into ggml-org:master Jul 18, 2025
47 checks passed
@ORippler ORippler deleted the hotfix_perf_issues_of_gemma3n branch July 18, 2025 11:53
ORippler added a commit to ORippler/ollama that referenced this pull request Jul 18, 2025
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Jul 18, 2025
ORippler added a commit to ORippler/ollama that referenced this pull request Jul 22, 2025
Similar to
ggml-org/llama.cpp#14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
ORippler added a commit to ORippler/ollama that referenced this pull request Jul 22, 2025
Similar to
ggml-org/llama.cpp#14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
ORippler added a commit to ORippler/ollama that referenced this pull request Jul 25, 2025
Similar to
ggml-org/llama.cpp#14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
mxyng pushed a commit to ollama/ollama that referenced this pull request Jul 29, 2025
…aph execution (#11525)

* Enable CUDA Graphs for gemma3n.

Similar to
ggml-org/llama.cpp#14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.

* Remove residual check by reshaping differently in gemma3n model

This should make the heuristics more robust
gabe-l-hart added a commit to gabe-l-hart/ollama that referenced this pull request Jul 30, 2025
It was implemented upstream:
ggml-org/llama.cpp#14741

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
mxyng pushed a commit to ollama/ollama that referenced this pull request Aug 9, 2025
It was implemented upstream:
ggml-org/llama.cpp#14741

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
mxyng pushed a commit to ollama/ollama that referenced this pull request Aug 11, 2025
It was implemented upstream:
ggml-org/llama.cpp#14741

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
mxyng pushed a commit to ollama/ollama that referenced this pull request Aug 12, 2025
It was implemented upstream:
ggml-org/llama.cpp#14741

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
mxyng pushed a commit to ollama/ollama that referenced this pull request Aug 12, 2025
It was implemented upstream:
ggml-org/llama.cpp#14741

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
mxyng added a commit to ollama/ollama that referenced this pull request Aug 14, 2025
* TEMPORARY: Update the llama.cpp upstream to my fork's Granite Four branch

This will be redone once my branch is merged upstream in llama.cpp

* feat: Update all patches

There are a number that are no longer needed at all:

- 0003-embeddings: Embeddings entirely overhauled on master
- 0008-ensure-KV-cache-is-fully-defragmented: KV caching entirely
    overhauled on master
- 0019-metal-add-mean-kernel-14267: Merged upstream
- 0020-CUDA-add-mean-operation-14313: Merged upstream

* feat: Sync llama.cpp and ggml

* fix: Update rsync-filter for all moved/new/removed files

* fix: Add files missing from sync

* fix: Update ggml rsync-filter for new ggml-cpu/arch subdirs

* fix: Add ggml files missing from sync

* fix: Narrow llama.cpp rsync-filter to not include mtmd main tool cpp files

* fix: Remove mtmd main cpp files

* fix: Add missing include in sampling_ext.cpp

* fix: Update llama.go to use mtmd instead of clip/llava

* fix: Add patch for mtmd_input_text

* chore: Ignore *.patched in the patch directory

* fix: Fix support for arch-specific ggml-cpu source files with new arrangement

In ggml-org/llama.cpp#13892, all arch-specific
implementations were split out into a nested tree structure under
ggml-cpu/arch. This conflicts with standard CGO layout where all
arch-specific source files are expected to live in the same directory as
the parent go module and use suffixes based on GOOS and GOARCH. As such,
there were really two options for getting this to work:

1. Add a patch on top of the GGML sync to rearrange the files to match the
GO layout convention
2. Use CGO directives to conditionally include the nested source files in
the compilation units

This commit does (2) in order to minimize the set of changes needed on top
of the upstream file layout. To get this to work, there are two key things
needed:

1. In cpu.go, #cgo directives are added to explicitly set __${GOARCH}__ in
the preprocessor directives
2. In arch-impls.c|cpp, use an #ifdef | #elif defined | #endif chain to
explicitly include the .c|.cpp files for the given architecture from the
nested directory

* fix: Use mtmd_helper to correctly load the bitmap for the image

* fix: Apply patch for mtmd_text_input

* fix: Add missing stb to llama.cpp rsync-filter

* fix: Add sync'ed stb vendored header

* fix: Use c++17 and include vendor for go wrapper modules

* fix: Update patch 0015 for upstream implementation of uuid

* feat: Bump to the latest tip of the branch

* fix: Update patches for bump

* feat: Bump back to the cenral repo and point at the latest master

This includes granite 4 and a number of other model architectures!

* fix: Revert changes to ggml export GPU UUID patch

* fix: Add patch for GGML_VERSION and GGML_COMMIT constants

* feat: Sync all patched code

* build: Include cmake/common.cmake in ggml sync

* build: Add top-level include for GNUINstallDirs in CMakeLists.txt

This is used to populate CMAKE_INSTALL_BINDIR

* fix: Add a patch to avoid power throttling API on non-msvc windows builds

* fix: Sync patch changes for ggml-cpu.c

* feat: Bump llama.cpp to 4a4f42

This picks up support for Kimi K2 and PLaMO-2

* feat: Sync llama.cpp

* fix: Handle multi-chunk image encodings from mtmd

* fix: Re-number patches after merge with `main`

* feat: Bump to 41e78c in the makefile

* fix: Fix Solar and argsort/copy patches after bump

* fix: Remove Gemma3n CUDA Graphs patch

It was implemented upstream:
ggml-org/llama.cpp#14741

* feat: Sync llama.cpp / ggml after latest bump

* build: Remove unnecessary CFLAGS definitions in cpu.go

* fix: Remove unnecessary additions in the rsync-filter

* fix: Remove unused vendored code for chat template parsing

* Revert "fix: Remove Gemma3n CUDA Graphs patch"

This reverts commit d724cac.

* fix: Update 0020 CUDA Graphs for gemma3n to keep both llama.cpp and ollama fixes

#11195 (comment)

* fix: Sync ggml-cuda.cu after keeping both style cuda graph fixes for gemma3n

* unwind mxfp4 patch

Prepare to bump ggml with their impl for mxfp4

* bump

* fix windows build error

* Convert tensors at load time

Repack the mxfp4 tensors as ggmls kernels expect them to be.

* convert mlp bf16 to f32

* buffer the conversion better

* reshape earlier

* openai swiglu

* add ids

* split qkv, gate_up

* fix nested alt tags

* fast attention

* remove debug messages

* fix lint

* remove redundant test

* remap values only if source/target are different

* add back i32->i32 copy

* refactor cpu quants

* clean up vendor

* update patch instructions

* clean up patches

* remove webgpu

* update mem

* also handle gpt-oss

* revert convert changes

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
Co-authored-by: Daniel Hiltgen <[email protected]>
rick-github pushed a commit to rick-github/ollama that referenced this pull request Aug 20, 2025
…aph execution (ollama#11525)

* Enable CUDA Graphs for gemma3n.

Similar to
ggml-org/llama.cpp#14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.

* Remove residual check by reshaping differently in gemma3n model

This should make the heuristics more robust
rick-github pushed a commit to rick-github/ollama that referenced this pull request Aug 20, 2025
* TEMPORARY: Update the llama.cpp upstream to my fork's Granite Four branch

This will be redone once my branch is merged upstream in llama.cpp

* feat: Update all patches

There are a number that are no longer needed at all:

- 0003-embeddings: Embeddings entirely overhauled on master
- 0008-ensure-KV-cache-is-fully-defragmented: KV caching entirely
    overhauled on master
- 0019-metal-add-mean-kernel-14267: Merged upstream
- 0020-CUDA-add-mean-operation-14313: Merged upstream

* feat: Sync llama.cpp and ggml

* fix: Update rsync-filter for all moved/new/removed files

* fix: Add files missing from sync

* fix: Update ggml rsync-filter for new ggml-cpu/arch subdirs

* fix: Add ggml files missing from sync

* fix: Narrow llama.cpp rsync-filter to not include mtmd main tool cpp files

* fix: Remove mtmd main cpp files

* fix: Add missing include in sampling_ext.cpp

* fix: Update llama.go to use mtmd instead of clip/llava

* fix: Add patch for mtmd_input_text

* chore: Ignore *.patched in the patch directory

* fix: Fix support for arch-specific ggml-cpu source files with new arrangement

In ggml-org/llama.cpp#13892, all arch-specific
implementations were split out into a nested tree structure under
ggml-cpu/arch. This conflicts with standard CGO layout where all
arch-specific source files are expected to live in the same directory as
the parent go module and use suffixes based on GOOS and GOARCH. As such,
there were really two options for getting this to work:

1. Add a patch on top of the GGML sync to rearrange the files to match the
GO layout convention
2. Use CGO directives to conditionally include the nested source files in
the compilation units

This commit does (2) in order to minimize the set of changes needed on top
of the upstream file layout. To get this to work, there are two key things
needed:

1. In cpu.go, #cgo directives are added to explicitly set __${GOARCH}__ in
the preprocessor directives
2. In arch-impls.c|cpp, use an #ifdef | #elif defined | #endif chain to
explicitly include the .c|.cpp files for the given architecture from the
nested directory

* fix: Use mtmd_helper to correctly load the bitmap for the image

* fix: Apply patch for mtmd_text_input

* fix: Add missing stb to llama.cpp rsync-filter

* fix: Add sync'ed stb vendored header

* fix: Use c++17 and include vendor for go wrapper modules

* fix: Update patch 0015 for upstream implementation of uuid

* feat: Bump to the latest tip of the branch

* fix: Update patches for bump

* feat: Bump back to the cenral repo and point at the latest master

This includes granite 4 and a number of other model architectures!

* fix: Revert changes to ggml export GPU UUID patch

* fix: Add patch for GGML_VERSION and GGML_COMMIT constants

* feat: Sync all patched code

* build: Include cmake/common.cmake in ggml sync

* build: Add top-level include for GNUINstallDirs in CMakeLists.txt

This is used to populate CMAKE_INSTALL_BINDIR

* fix: Add a patch to avoid power throttling API on non-msvc windows builds

* fix: Sync patch changes for ggml-cpu.c

* feat: Bump llama.cpp to 4a4f42

This picks up support for Kimi K2 and PLaMO-2

* feat: Sync llama.cpp

* fix: Handle multi-chunk image encodings from mtmd

* fix: Re-number patches after merge with `main`

* feat: Bump to 41e78c in the makefile

* fix: Fix Solar and argsort/copy patches after bump

* fix: Remove Gemma3n CUDA Graphs patch

It was implemented upstream:
ggml-org/llama.cpp#14741

* feat: Sync llama.cpp / ggml after latest bump

* build: Remove unnecessary CFLAGS definitions in cpu.go

* fix: Remove unnecessary additions in the rsync-filter

* fix: Remove unused vendored code for chat template parsing

* Revert "fix: Remove Gemma3n CUDA Graphs patch"

This reverts commit d724cac.

* fix: Update 0020 CUDA Graphs for gemma3n to keep both llama.cpp and ollama fixes

ollama#11195 (comment)

* fix: Sync ggml-cuda.cu after keeping both style cuda graph fixes for gemma3n

* unwind mxfp4 patch

Prepare to bump ggml with their impl for mxfp4

* bump

* fix windows build error

* Convert tensors at load time

Repack the mxfp4 tensors as ggmls kernels expect them to be.

* convert mlp bf16 to f32

* buffer the conversion better

* reshape earlier

* openai swiglu

* add ids

* split qkv, gate_up

* fix nested alt tags

* fast attention

* remove debug messages

* fix lint

* remove redundant test

* remap values only if source/target are different

* add back i32->i32 copy

* refactor cpu quants

* clean up vendor

* update patch instructions

* clean up patches

* remove webgpu

* update mem

* also handle gpt-oss

* revert convert changes

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
Co-authored-by: Daniel Hiltgen <[email protected]>
rick-github pushed a commit to rick-github/ollama that referenced this pull request Aug 20, 2025
* TEMPORARY: Update the llama.cpp upstream to my fork's Granite Four branch

This will be redone once my branch is merged upstream in llama.cpp

* feat: Update all patches

There are a number that are no longer needed at all:

- 0003-embeddings: Embeddings entirely overhauled on master
- 0008-ensure-KV-cache-is-fully-defragmented: KV caching entirely
    overhauled on master
- 0019-metal-add-mean-kernel-14267: Merged upstream
- 0020-CUDA-add-mean-operation-14313: Merged upstream

* feat: Sync llama.cpp and ggml

* fix: Update rsync-filter for all moved/new/removed files

* fix: Add files missing from sync

* fix: Update ggml rsync-filter for new ggml-cpu/arch subdirs

* fix: Add ggml files missing from sync

* fix: Narrow llama.cpp rsync-filter to not include mtmd main tool cpp files

* fix: Remove mtmd main cpp files

* fix: Add missing include in sampling_ext.cpp

* fix: Update llama.go to use mtmd instead of clip/llava

* fix: Add patch for mtmd_input_text

* chore: Ignore *.patched in the patch directory

* fix: Fix support for arch-specific ggml-cpu source files with new arrangement

In ggml-org/llama.cpp#13892, all arch-specific
implementations were split out into a nested tree structure under
ggml-cpu/arch. This conflicts with standard CGO layout where all
arch-specific source files are expected to live in the same directory as
the parent go module and use suffixes based on GOOS and GOARCH. As such,
there were really two options for getting this to work:

1. Add a patch on top of the GGML sync to rearrange the files to match the
GO layout convention
2. Use CGO directives to conditionally include the nested source files in
the compilation units

This commit does (2) in order to minimize the set of changes needed on top
of the upstream file layout. To get this to work, there are two key things
needed:

1. In cpu.go, #cgo directives are added to explicitly set __${GOARCH}__ in
the preprocessor directives
2. In arch-impls.c|cpp, use an #ifdef | #elif defined | #endif chain to
explicitly include the .c|.cpp files for the given architecture from the
nested directory

* fix: Use mtmd_helper to correctly load the bitmap for the image

* fix: Apply patch for mtmd_text_input

* fix: Add missing stb to llama.cpp rsync-filter

* fix: Add sync'ed stb vendored header

* fix: Use c++17 and include vendor for go wrapper modules

* fix: Update patch 0015 for upstream implementation of uuid

* feat: Bump to the latest tip of the branch

* fix: Update patches for bump

* feat: Bump back to the cenral repo and point at the latest master

This includes granite 4 and a number of other model architectures!

* fix: Revert changes to ggml export GPU UUID patch

* fix: Add patch for GGML_VERSION and GGML_COMMIT constants

* feat: Sync all patched code

* build: Include cmake/common.cmake in ggml sync

* build: Add top-level include for GNUINstallDirs in CMakeLists.txt

This is used to populate CMAKE_INSTALL_BINDIR

* fix: Add a patch to avoid power throttling API on non-msvc windows builds

* fix: Sync patch changes for ggml-cpu.c

* feat: Bump llama.cpp to 4a4f42

This picks up support for Kimi K2 and PLaMO-2

* feat: Sync llama.cpp

* fix: Handle multi-chunk image encodings from mtmd

* fix: Re-number patches after merge with `main`

* feat: Bump to 41e78c in the makefile

* fix: Fix Solar and argsort/copy patches after bump

* fix: Remove Gemma3n CUDA Graphs patch

It was implemented upstream:
ggml-org/llama.cpp#14741

* feat: Sync llama.cpp / ggml after latest bump

* build: Remove unnecessary CFLAGS definitions in cpu.go

* fix: Remove unnecessary additions in the rsync-filter

* fix: Remove unused vendored code for chat template parsing

* Revert "fix: Remove Gemma3n CUDA Graphs patch"

This reverts commit d724cac.

* fix: Update 0020 CUDA Graphs for gemma3n to keep both llama.cpp and ollama fixes

ollama#11195 (comment)

* fix: Sync ggml-cuda.cu after keeping both style cuda graph fixes for gemma3n

* unwind mxfp4 patch

Prepare to bump ggml with their impl for mxfp4

* bump

* fix windows build error

* Convert tensors at load time

Repack the mxfp4 tensors as ggmls kernels expect them to be.

* convert mlp bf16 to f32

* buffer the conversion better

* reshape earlier

* openai swiglu

* add ids

* split qkv, gate_up

* fix nested alt tags

* fast attention

* remove debug messages

* fix lint

* remove redundant test

* remap values only if source/target are different

* add back i32->i32 copy

* refactor cpu quants

* clean up vendor

* update patch instructions

* clean up patches

* remove webgpu

* update mem

* also handle gpt-oss

* revert convert changes

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
Co-authored-by: Daniel Hiltgen <[email protected]>
sjsone pushed a commit to sjsone/ollama that referenced this pull request Aug 23, 2025
* TEMPORARY: Update the llama.cpp upstream to my fork's Granite Four branch

This will be redone once my branch is merged upstream in llama.cpp

* feat: Update all patches

There are a number that are no longer needed at all:

- 0003-embeddings: Embeddings entirely overhauled on master
- 0008-ensure-KV-cache-is-fully-defragmented: KV caching entirely
    overhauled on master
- 0019-metal-add-mean-kernel-14267: Merged upstream
- 0020-CUDA-add-mean-operation-14313: Merged upstream

* feat: Sync llama.cpp and ggml

* fix: Update rsync-filter for all moved/new/removed files

* fix: Add files missing from sync

* fix: Update ggml rsync-filter for new ggml-cpu/arch subdirs

* fix: Add ggml files missing from sync

* fix: Narrow llama.cpp rsync-filter to not include mtmd main tool cpp files

* fix: Remove mtmd main cpp files

* fix: Add missing include in sampling_ext.cpp

* fix: Update llama.go to use mtmd instead of clip/llava

* fix: Add patch for mtmd_input_text

* chore: Ignore *.patched in the patch directory

* fix: Fix support for arch-specific ggml-cpu source files with new arrangement

In ggml-org/llama.cpp#13892, all arch-specific
implementations were split out into a nested tree structure under
ggml-cpu/arch. This conflicts with standard CGO layout where all
arch-specific source files are expected to live in the same directory as
the parent go module and use suffixes based on GOOS and GOARCH. As such,
there were really two options for getting this to work:

1. Add a patch on top of the GGML sync to rearrange the files to match the
GO layout convention
2. Use CGO directives to conditionally include the nested source files in
the compilation units

This commit does (2) in order to minimize the set of changes needed on top
of the upstream file layout. To get this to work, there are two key things
needed:

1. In cpu.go, #cgo directives are added to explicitly set __${GOARCH}__ in
the preprocessor directives
2. In arch-impls.c|cpp, use an #ifdef | #elif defined | #endif chain to
explicitly include the .c|.cpp files for the given architecture from the
nested directory

* fix: Use mtmd_helper to correctly load the bitmap for the image

* fix: Apply patch for mtmd_text_input

* fix: Add missing stb to llama.cpp rsync-filter

* fix: Add sync'ed stb vendored header

* fix: Use c++17 and include vendor for go wrapper modules

* fix: Update patch 0015 for upstream implementation of uuid

* feat: Bump to the latest tip of the branch

* fix: Update patches for bump

* feat: Bump back to the cenral repo and point at the latest master

This includes granite 4 and a number of other model architectures!

* fix: Revert changes to ggml export GPU UUID patch

* fix: Add patch for GGML_VERSION and GGML_COMMIT constants

* feat: Sync all patched code

* build: Include cmake/common.cmake in ggml sync

* build: Add top-level include for GNUINstallDirs in CMakeLists.txt

This is used to populate CMAKE_INSTALL_BINDIR

* fix: Add a patch to avoid power throttling API on non-msvc windows builds

* fix: Sync patch changes for ggml-cpu.c

* feat: Bump llama.cpp to 4a4f42

This picks up support for Kimi K2 and PLaMO-2

* feat: Sync llama.cpp

* fix: Handle multi-chunk image encodings from mtmd

* fix: Re-number patches after merge with `main`

* feat: Bump to 41e78c in the makefile

* fix: Fix Solar and argsort/copy patches after bump

* fix: Remove Gemma3n CUDA Graphs patch

It was implemented upstream:
ggml-org/llama.cpp#14741

* feat: Sync llama.cpp / ggml after latest bump

* build: Remove unnecessary CFLAGS definitions in cpu.go

* fix: Remove unnecessary additions in the rsync-filter

* fix: Remove unused vendored code for chat template parsing

* Revert "fix: Remove Gemma3n CUDA Graphs patch"

This reverts commit d724cac.

* fix: Update 0020 CUDA Graphs for gemma3n to keep both llama.cpp and ollama fixes

ollama#11195 (comment)

* fix: Sync ggml-cuda.cu after keeping both style cuda graph fixes for gemma3n

* unwind mxfp4 patch

Prepare to bump ggml with their impl for mxfp4

* bump

* fix windows build error

* Convert tensors at load time

Repack the mxfp4 tensors as ggmls kernels expect them to be.

* convert mlp bf16 to f32

* buffer the conversion better

* reshape earlier

* openai swiglu

* add ids

* split qkv, gate_up

* fix nested alt tags

* fast attention

* remove debug messages

* fix lint

* remove redundant test

* remap values only if source/target are different

* add back i32->i32 copy

* refactor cpu quants

* clean up vendor

* update patch instructions

* clean up patches

* remove webgpu

* update mem

* also handle gpt-oss

* revert convert changes

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
Co-authored-by: Daniel Hiltgen <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants