
ggml: backend-agnostic tensor parallelism#19378

Open
JohannesGaessler wants to merge 34 commits into ggml-org:master from JohannesGaessler:ggml-meta-backend-8

Conversation

@JohannesGaessler
Contributor

@JohannesGaessler JohannesGaessler commented Feb 5, 2026

This PR adds support for backend-agnostic tensor parallelism, enabled by specifying --split-mode tensor. This is done by adding a new "meta" backend that internally wraps multiple "simple" backends but can be used in the same way as a regular ggml backend.

ggml Backend Interface Changes

This PR extends the ggml backend interface with some new functions for tensor copying:

  • set_tensor_2d_async/get_tensor_2d_async, which are equivalent to cudaMemcpy2DAsync in CUDA. These are not needed for the computation of the meta backend itself but rather for setting/getting weights or the output. Currently not implemented; as a workaround, the one-dimensional version is used in a loop.
  • shfl_tensor_async to allow two ggml backends to exchange two tensors and to synchronize on the completion of the exchange. As a fallback, cpy_tensor_async can be used, but this has higher latency because the copy in one direction can only start once the one in the other direction has finished. Needed for a generic AllReduce between ggml backends. Implemented.
  • allreduce_tensor_async to allow ggml backends to specify a backend-specific way to do an AllReduce operation. Intended to be used for NCCL support in cooperation with @gaugarg-nv. Not yet implemented.

@slaren please provide feedback regarding whether you agree that these operations should be in the ggml backend interface. For context, all of them are optional and can use existing operations as a fallback.

Meta Backend Implementation Details

The meta backend implements an entire ggml backend stack starting from a meta device. The meta device is created from multiple simple backend devices as well as a function to determine how the data should be split across devices ("split states"). Backend buffer types, buffers, and backends are created as per usual. When calling ggml_backend_graph_compute the code infers the split states of the nodes in the compute graph based on the split states assigned for the weights/kv cache. The basic pattern is to make all tensors mirrored by default. For the weight matrices, do a split in dimension 1, then a split in dimension 0, then an AllReduce. For a transformer this means two AllReduce operations, one after the attention and one after the FFN. The attention is effectively split by dimension 0, which equates to a split by attention head.

A generic AllReduce operation is performed in the meta backend by splitting the graph into subgraphs. After a subgraph is executed, shfl_tensor_async is called to make the backends exchange partial results, and they then execute auxiliary graphs that contain only a GGML_ADD operation to combine the results.

The memory allocation for the compute graph is rather tricky: the way I solved it is to allocate the memory for the meta backend as per usual and to then transplant the calculated addresses, relative to the backend buffer base pointer, to the underlying simple backends. Because the simple tensors only require a fraction of the full memory this yields correct results, though it does result in overallocation for the compute graphs. For the weights/KV cache the memory allocation for the meta backend is done via a new function ggml_backend_meta_alloc_ctx_tensors_from_buft to prevent duplicated weights (which are much larger in size). I'm not yet sure what the best approach will be long-term; I think the graph allocation code in ggml-alloc.c will need to be adjusted.

Current Issues/Limitations

  • Only 1 or 2 GPUs are supported. Note: 1 GPU is not actually any faster, this is only useful for testing whether the code works correctly.
  • All GPUs must have an equal share of the data, --tensor-split has no effect.
  • Only dense models are supported. The LLaMA 3 models seem to be working correctly, I have not yet tested others.
  • Support for llama_params_fit is not implemented so the context size has to be set manually.
  • Without FlashAttention the code will probably crash because some transition between split states is not yet implemented.
  • In principle all backends should work. CUDA does in my testing, Vulkan however does not. I think there may be some issues with deadlock between the GPUs. @jeffbolznv @0cc4m if you could take a look it would be appreciated.
  • Memory for the ggml contexts is being overallocated.
  • Performance is (presumably) still suboptimal vs. NCCL.
  • I'm currently using tensor names to determine how to split individual tensors. I think it would be preferable to use some sort of enum instead (which we already seem to have for loading tensors). This should also be used for llama_params_fit.
  • I'm currently setting ggml_tensor::data to dummy values since that is what is checked in ggml-alloc.c to determine whether or not a tensor is considered allocated. This dummy value should never actually be used for any computations but I don't consider this a good solution.
  • I'm not putting meta devices into the ggml backend device registry (which I think is correct).

Performance

LLaMA 3 on 2x RTX 4090
| model | test | t/s -sm layer | t/s -sm row | t/s -sm tensor |
| --- | --- | ---: | ---: | ---: |
| llama 8B Q4_0 | pp512 | 12550.97 | 2997.41 | 6305.68 |
| llama 8B Q4_0 | pp2048 | 18788.11 | 2970.83 | 6300.12 |
| llama 8B Q4_0 | tg128 | 175.43 | 67.00 | 101.98 |
| llama 8B Q4_0 | pp512 @ d32768 | 5099.65 | 2200.50 | 4137.73 |
| llama 8B Q4_0 | pp2048 @ d32768 | 7925.21 | 2242.84 | 4337.54 |
| llama 8B Q4_0 | tg128 @ d32768 | 96.78 | 49.76 | 102.81 |
| llama 8B Q4_0 | pp512 @ d65536 | 3154.69 | 1748.05 | 3139.40 |
| llama 8B Q4_0 | pp2048 @ d65536 | 4996.19 | 1806.82 | 3404.70 |
| llama 8B Q4_0 | tg128 @ d65536 | 67.11 | 40.27 | 83.62 |
| llama 8B Q4_0 | pp512 @ d131072 | 1800.72 | 1243.57 | 2152.01 |
| llama 8B Q4_0 | pp2048 @ d131072 | 2867.83 | 1294.61 | 2238.01 |
| llama 8B Q4_0 | tg128 @ d131072 | 41.66 | 29.19 | 59.53 |
| llama 8B F16 | pp512 | 10578.54 | 1591.55 | 5950.44 |
| llama 8B F16 | pp2048 | 15890.45 | 1581.04 | 6005.96 |
| llama 8B F16 | tg128 | 60.07 | 46.32 | 70.64 |
| llama 8B F16 | pp512 @ d32768 | 4745.19 | 1330.02 | 4032.19 |
| llama 8B F16 | pp2048 @ d32768 | 7329.28 | 1347.02 | 4279.02 |
| llama 8B F16 | tg128 @ d32768 | 47.03 | 38.02 | 64.63 |
| llama 8B F16 | pp512 @ d65536 | 3033.10 | 1154.54 | 3100.10 |
| llama 8B F16 | pp2048 @ d65536 | 4782.02 | 1176.88 | 3341.03 |
| llama 8B F16 | tg128 @ d65536 | 38.72 | 32.19 | 56.48 |
| llama 8B F16 | pp512 @ d131072 | 1735.63 | 905.39 | 2090.03 |
| llama 8B F16 | pp2048 @ d131072 | 2782.42 | 936.55 | 2288.65 |
| llama 8B F16 | tg128 @ d131072 | 28.60 | 24.79 | 43.32 |
| llama 70B Q3_K - Small | pp512 | 1287.50 | 582.81 | 1072.80 |
| llama 70B Q3_K - Small | pp2048 | 1954.83 | 590.45 | 1069.28 |
| llama 70B Q3_K - Small | tg128 | 27.97 | 20.36 | 29.65 |
| llama 70B Q3_K - Small | pp512 @ d32768 | 776.80 | 458.17 | 812.77 |
| llama 70B Q3_K - Small | pp2048 @ d32768 | 1185.82 | 459.43 | 824.99 |
| llama 70B Q3_K - Small | tg128 @ d32768 | 20.85 | 16.19 | 29.39 |

Generally speaking it can be observed that parallelizing larger models has better performance than parallelizing smaller models. Similarly, parallelizing the model becomes more worthwhile as the context depth increases. This makes sense as both of these result in a larger workload per GPU vs. the overhead from parallelization. Token generation benefits more from parallelization than prompt processing because the amount of data that needs to be transferred between GPUs is proportional to batch size; long-term it may make sense to implement support for FP16/BF16 compute types, which would halve the I/O vs. FP32. For pp512 pipeline parallelism is effectively disabled while for pp2048 it's enabled. With pipeline parallelism -sm layer is still faster than -sm tensor even at high context depths.

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs Vulkan Issues specific to the Vulkan backend examples ggml changes relating to the ggml tensor library for machine learning SYCL https://en.wikipedia.org/wiki/SYCL - GPU programming language Apple Metal https://en.wikipedia.org/wiki/Metal_(API) Ascend NPU issues specific to Ascend NPUs OpenCL Issues specific to the OpenCL backend IBM zDNN issues specific to IBM zDNN Accelerator labels Feb 5, 2026
@jeffbolznv
Contributor

In principle all backends should work. CUDA does in my testing, Vulkan however does not. I think there may be some issues with deadlock between the GPUs. @jeffbolznv @0cc4m if you could take a look it would be appreciated.

I'm not seeing a deadlock, just a crash in the driver with an invalid descriptor. I ran llama-bench.exe -fa 1 -p 512 -n 0 -m c:\models\llama-2-7b.Q4_0.gguf -sm tensor

Validation Error: [ VUID-VkDescriptorBufferInfo-offset-00340 ] | MessageID = 0xc23dafe5
vkUpdateDescriptorSets(): pDescriptorWrites[0].pBufferInfo[2].offset (18362368) is greater than or equal to buffer size (8388608).
The Vulkan spec states: offset must be less than the size of buffer (https://docs.vulkan.org/spec/latest/chapters/descriptorsets.html#VUID-VkDescriptorBufferInfo-offset-00340)

-		dst_buf	{buffer=shared_ptr {buffer={m_buffer=0x0000be00000000be {...} } device_memory={m_deviceMemory=0x0000bf00000000bf {...} } ...} [0x00000003 strong refs] [make_shared] ...}	vk_subbuffer
+		buffer	shared_ptr {buffer={m_buffer=0x0000be00000000be {...} } device_memory={m_deviceMemory=0x0000bf00000000bf {...} } ...} [0x00000003 strong refs] [make_shared]	std::shared_ptr<vk_buffer_struct>
		offset	0x0000000001183000	unsigned __int64
		size	0x0000000000800000	unsigned __int64
-		dst	0x0000018f98716fd0 {type=GGML_TYPE_F32 (0x00000000) buffer=0x0000018f98385c80 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...} ...}	ggml_tensor *
		type	GGML_TYPE_F32 (0x00000000)	ggml_type
+		buffer	0x0000018f98385c80 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...}	ggml_backend_buffer *
+		ne	0x0000018f98716fe0 {0x0000000000001000, 0x0000000000000200, 0x0000000000000001, 0x0000000000000001}	__int64[0x00000004]
+		nb	0x0000018f98717000 {0x0000000000000004, 0x0000000000004000, 0x0000000000800000, 0x0000000000800000}	unsigned __int64[0x00000004]
		op	GGML_OP_ADD (0x00000002)	ggml_op
+		op_params	0x0000018f98717024 {0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, ...}	int[0x00000010]
		flags	0x00000010	int
+		src	0x0000018f98717068 {0x00000191a1b1d1b0 {type=GGML_TYPE_F32 (0x00000000) buffer=0x0000018f97d4f130 {iface=...} ...}, ...}	ggml_tensor *[0x0000000a]
-		view_src	0x00000191a1b1d1b0 {type=GGML_TYPE_F32 (0x00000000) buffer=0x0000018f97d4f130 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...} ...}	ggml_tensor *
		type	GGML_TYPE_F32 (0x00000000)	ggml_type
+		buffer	0x0000018f97d4f130 {iface={free_buffer=0x00007ffbc6647c50 {ggml-vulkan.dll!ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer *)} ...} ...}	ggml_backend_buffer *
+		ne	0x00000191a1b1d1c0 {0x0000000000001000, 0x0000000000000200, 0x0000000000000001, 0x0000000000000001}	__int64[0x00000004]
+		nb	0x00000191a1b1d1e0 {0x0000000000000004, 0x0000000000004000, 0x0000000000800000, 0x0000000000800000}	unsigned __int64[0x00000004]
		op	GGML_OP_MUL_MAT (0x0000001d)	ggml_op
+		op_params	0x00000191a1b1d204 {0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, ...}	int[0x00000010]
		flags	0x00000010	int
+		src	0x00000191a1b1d248 {0x0000018f9f29d940 {type=GGML_TYPE_Q4_0 (0x00000002) buffer=0x0000018f97d4f670 {...} ...}, ...}	ggml_tensor *[0x0000000a]
+		view_src	0x0000000000000000 <NULL>	ggml_tensor *
		view_offs	0x0000000000000000	unsigned __int64
		data	0x0000000001184000	void *
+		name	0x00000191a1b1d2b0 "attn_out-0"	char[0x00000040]
		extra	0x0000000000000000	void *
+		padding	0x00000191a1b1d2f8 ""	char[0x00000008]
		view_offs	0x0000000000000000	unsigned __int64
		data	0x0000000001184000	void *
+		name	0x0000018f987170d0 "attn_out-0 (view)"	char[0x00000040]
		extra	0x0000000000000000	void *
+		padding	0x0000018f98717118 ""	char[0x00000008]

It seems like tensor->data (and tensor->view_src->data) are too large. I haven't debugged further.

@jacekpoplawski
Contributor

works for me on 2x3090 for llama 3 8B and Mistral Nemo 12B

on Devstral I have OOM (expected because model size?)
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 30720.00 MiB on device 0: cudaMalloc failed: out of memory
alloc_tensor_range: failed to allocate CUDA0 buffer of size 32212254720

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7719ec5 in ggml_abort (file=0x7ffff77bdaa8 "/home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp", line=119, fmt=0x7ffff77bd88f "GGML_ASSERT(%s) failed")
    at /home/jacek/git/llama.cpp/ggml/src/ggml.c:256
#6  0x00007ffff773565b in ggml_backend_buffer_get_size (buffer=0x0) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp:119
#7  0x00007ffff7743b0e in ggml_backend_meta_alloc_ctx_tensors_from_buft (ctx=0x555557dc41f0, buft=0x55555b032b18) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:638
#8  0x00007ffff7734a4a in ggml_backend_alloc_ctx_tensors_from_buft (ctx=0x555557dc41f0, buft=0x55555b032b18) at /home/jacek/git/llama.cpp/ggml/src/ggml-alloc.c:1245
#9  0x00007ffff6f3d6c2 in llama_kv_cache::llama_kv_cache (this=0x555556363770, model=..., type_k=GGML_TYPE_F16, type_v=GGML_TYPE_F16, v_trans=false, offload=true, unified=true, kv_size=393216,
    n_seq_max=4, n_pad=1, n_swa=0, swa_type=LLAMA_SWA_TYPE_NONE, filter=..., reuse=...) at /home/jacek/git/llama.cpp/src/llama-kv-cache.cpp:190
#10 0x00007ffff700faec in llama_model::create_memory (this=0x555556357ea0, params=..., cparams=...) at /home/jacek/git/llama.cpp/src/llama-model.cpp:7617
#11 0x00007ffff6ed008f in llama_context::llama_context (this=0x555557a65260, model=..., params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:274
#12 0x00007ffff6edd53c in llama_init_from_model (model=0x555556357ea0, params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:3046
#13 0x00005555558f4e24 in common_init_result::common_init_result (this=0x55555624dc70, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1183
#14 0x00005555558f50d0 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#15 0x0000555555710999 in server_context_impl::load_model (this=0x55555636eba0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#16 0x00005555556e9ac8 in server_context::load_model (this=0x7fffffff79b0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#17 0x00005555556177bd in main (argc=7, argv=0x7fffffffe0c8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)
but on Ministral 14B too
ggml_backend_cuda_buffer_type_alloc_buffer: allocating 20480.00 MiB on device 0: cudaMalloc failed: out of memory
alloc_tensor_range: failed to allocate CUDA0 buffer of size 21474836480

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7719ec5 in ggml_abort (file=0x7ffff77bdaa8 "/home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp", line=119, fmt=0x7ffff77bd88f "GGML_ASSERT(%s) failed")
    at /home/jacek/git/llama.cpp/ggml/src/ggml.c:256
#6  0x00007ffff773565b in ggml_backend_buffer_get_size (buffer=0x0) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend.cpp:119
#7  0x00007ffff7743b0e in ggml_backend_meta_alloc_ctx_tensors_from_buft (ctx=0x555557e339e0, buft=0x55555abb40b8) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:638
#8  0x00007ffff7734a4a in ggml_backend_alloc_ctx_tensors_from_buft (ctx=0x555557e339e0, buft=0x55555abb40b8) at /home/jacek/git/llama.cpp/ggml/src/ggml-alloc.c:1245
#9  0x00007ffff6f3d6c2 in llama_kv_cache::llama_kv_cache (this=0x55555abb9470, model=..., type_k=GGML_TYPE_F16, type_v=GGML_TYPE_F16, v_trans=false, offload=true, unified=true, kv_size=262144,
    n_seq_max=4, n_pad=1, n_swa=0, swa_type=LLAMA_SWA_TYPE_NONE, filter=..., reuse=...) at /home/jacek/git/llama.cpp/src/llama-kv-cache.cpp:190
#10 0x00007ffff700faec in llama_model::create_memory (this=0x555556357e50, params=..., cparams=...) at /home/jacek/git/llama.cpp/src/llama-model.cpp:7617
#11 0x00007ffff6ed008f in llama_context::llama_context (this=0x555557a65140, model=..., params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:274
#12 0x00007ffff6edd53c in llama_init_from_model (model=0x555556357e50, params=...) at /home/jacek/git/llama.cpp/src/llama-context.cpp:3046
#13 0x00005555558f4e24 in common_init_result::common_init_result (this=0x55555624dc70, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1183
#14 0x00005555558f50d0 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#15 0x0000555555710999 in server_context_impl::load_model (this=0x55555636eb30, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#16 0x00005555556e9ac8 in server_context::load_model (this=0x7fffffff79b0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#17 0x00005555556177bd in main (argc=7, argv=0x7fffffffe0c8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)
qwen 4B has a different issue
/home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:386: GGML_ASSERT(ne[split_dim] % n_simple_bufs == 0) failed

Thread 1 "llama-server" received signal SIGABRT, Aborted.
Download failed: Invalid argument.  Continuing without source file ./nptl/./nptl/pthread_kill.c.
__pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
warning: 44     ./nptl/pthread_kill.c: No such file or directory
(gdb) bt
#0  __pthread_kill_implementation (threadid=<optimized out>, signo=6, no_tid=0) at ./nptl/pthread_kill.c:44
#1  __pthread_kill_internal (threadid=<optimized out>, signo=6) at ./nptl/pthread_kill.c:89
#2  __GI___pthread_kill (threadid=<optimized out>, signo=signo@entry=6) at ./nptl/pthread_kill.c:100
#3  0x00007fffedc4579e in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4  0x00007fffedc288cd in __GI_abort () at ./stdlib/abort.c:73
#5  0x00007ffff7719ec5 in ggml_abort (file=0x7ffff77bedd8 "/home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp", line=386, fmt=0x7ffff77beb9f "GGML_ASSERT(%s) failed")
    at /home/jacek/git/llama.cpp/ggml/src/ggml.c:256
#6  0x00007ffff7741e84 in ggml_backend_meta_buffer_init_tensor (buffer=0x55555a4421d0, tensor=0x55555a46df70) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:386
#7  0x00007ffff7743a52 in ggml_backend_meta_alloc_ctx_tensors_from_buft (ctx=0x55555a4422f0, buft=0x55555a426188) at /home/jacek/git/llama.cpp/ggml/src/ggml-backend-meta.cpp:632
#8  0x00007ffff7734a4a in ggml_backend_alloc_ctx_tensors_from_buft (ctx=0x55555a4422f0, buft=0x55555a426188) at /home/jacek/git/llama.cpp/ggml/src/ggml-alloc.c:1245
#9  0x00007ffff6fff452 in llama_model::load_tensors (this=0x555556357e50, ml=...) at /home/jacek/git/llama.cpp/src/llama-model.cpp:7055
#10 0x00007ffff6e557fc in llama_model_load (fname="/mnt/models2/Qwen/Qwen_Qwen3-4B-Q8_0.gguf", splits=std::vector of length 0, capacity 0, model=..., params=...)
    at /home/jacek/git/llama.cpp/src/llama.cpp:876
#11 0x00007ffff6e56ce3 in llama_model_load_from_file_impl (path_model="/mnt/models2/Qwen/Qwen_Qwen3-4B-Q8_0.gguf", splits=std::vector of length 0, capacity 0, params=...)
    at /home/jacek/git/llama.cpp/src/llama.cpp:1069
#12 0x00007ffff6e56fe3 in llama_model_load_from_file (path_model=0x555556367610 "/mnt/models2/Qwen/Qwen_Qwen3-4B-Q8_0.gguf", params=...) at /home/jacek/git/llama.cpp/src/llama.cpp:1096
#13 0x00005555558f46c9 in common_init_result::common_init_result (this=0x55555624dc70, params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1107
#14 0x00005555558f50d0 in common_init_from_params (params=...) at /home/jacek/git/llama.cpp/common/common.cpp:1215
#15 0x0000555555710999 in server_context_impl::load_model (this=0x55555636ead0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:625
#16 0x00005555556e9ac8 in server_context::load_model (this=0x7fffffff79d0, params=...) at /home/jacek/git/llama.cpp/tools/server/server-context.cpp:2856
#17 0x00005555556177bd in main (argc=7, argv=0x7fffffffe0e8) at /home/jacek/git/llama.cpp/tools/server/server.cpp:248
(gdb)

@DocShotgun
Contributor

Interesting. If the CPU backend is able to be virtualized into multiple devices as described here, would it be possible to allow multiple NUMA nodes to be parallelized?

@gopinath87607

will this pr solve if we use multiple rpc connected to gpu with cpu and when we use cpu moe flag ?

@JohannesGaessler
Contributor Author

@jacekpoplawski the combination of --split-mode tensor and llama_params_fit is not implemented so you'll have to set the context size manually if you didn't already.

@DocShotgun longer-term I intend to also enable this code for better NUMA support though I'm not yet sure what to do in terms of hardware for development. Originally I had intended to buy 1.5 TiB of DDR5 RAM and 2 EPYC CPUs but at the current prices that would be financially irresponsible of me to do.

@gopinath87607 I don't understand what you mean.

@FullstackSensei

@DocShotgun longer-term I intend to also enable this code for better NUMA support though I'm not yet sure what to do in terms of hardware for development. Originally I had intended to buy 1.5 TiB of DDR5 RAM and 2 EPYC CPUs but at the current prices that would be financially irresponsible of me to do.

If you're looking for a DDR4 platform, I might be able to help with that, but unfortunately, not so much with RAM for the system. I'm also in Germany.

Wouldn't mind giving you access to my systems if you want. Have two dual Xeon systems one with P40s and the other with Mi50s.

Would be very happy to help either way.

@ggerganov
Member

ggerganov commented Feb 6, 2026

Started doing some initial tests to get familiar with the changes. I'm using virtual Metal devices and things appear to be mostly working - e.g. seeing graph execution on both devices, llama-perplexity produces the same result with 2 devices.

However, the following command does not produce identical results on each run:

GGML_METAL_DEVICES=2 ./bin/llama-completion -m ~/models/llama-3.1-8b/ggml-model-f16.gguf -no-cnv -p "I believe the meaning of life is" -n 32 --top-k 1 -sm tensor

I think this either means that there could be a problem with the fallback for the missing backend API, or that I could have made an error in the implementation of the events and the async copy in Metal. Will investigate more.

For context, all of them are optional and can use existing operations as a fallback.

In the meantime, @JohannesGaessler do you confirm that the command above has deterministic results on your end? Also, is it deterministic with the fallback calls?

@JohannesGaessler
Contributor Author

JohannesGaessler commented Feb 6, 2026

Generally speaking, you cannot expect bit-for-bit identical results if you split the computation across multiple virtual devices. The order in which floats are summed will be different, which will in turn change the rounding error. If I run llama-perplexity using LLaMA 3 8B F16 I get 6.2560 for -sm layer and 6.2556 for -sm tensor. It's of course still possible that there are bugs on top of that; I would just say changes to the results are expected. Long-term I think we should test this new code in the same way we test ggml op fusion in test-backend-ops.

If you use -sm tensor with only a single GPU the executed ops should be the exact same and the result should be bit-for-bit identical to -sm layer.

@JohannesGaessler
Contributor Author

Sorry, I think I misread your post. If you are saying that the results are not deterministic with 2 virtual GPUs but they are with 1 GPU then that I think is indicative of a bug w.r.t. the synchronization.

@ggerganov
Member

Yes, I understand that 1 GPU vs. 2 GPUs will not be bit-for-bit identical. What I mean is that in my test, running the command with 2 GPUs several times produces non-deterministic results from one run to the other:

GGML_METAL_DEVICES=2 ./bin/llama-completion -m ~/models/llama-3.1-8b/ggml-model-f16.gguf -no-cnv -p "I believe the meaning of life is" -n 32 --top-k 1 -sm tensor

# run 1
I believe the meaning of life is to find your gift. The purpose of life is to give it away.
I believe that the meaning of life is to find your gift. The purpose of life

# run 2
I believe the meaning of life is to find your gift. The purpose of life is to give it away. To give it away, you have to find it. [end of text]

# run 3
I believe the meaning of life is to be happy. I believe that happiness is the only thing that matters. I believe that happiness is the only thing that matters. I believe that happiness is the

Sorry, I think I misread your post. If you are saying that the results are not deterministic with 2 virtual GPUs but they are with 1 GPU then that I think is indicative of a bug w.r.t. the synchronization.

Yes, seems like a synchronization issue. Was wondering if you observe it on your end with and without the fallback backend API. This will give me indication where to look for the issue.

@JohannesGaessler
Contributor Author

With the command you posted (minus the GGML_METAL_DEVICES=2) I get deterministic results on 2x RTX 4090, both with and without the fallback for shfl_tensor_async. The output of llama-perplexity is bit-for-bit identical with vs. without the fallback. So presumably either there is a bug with Metal synchronization or I made assumptions about the behavior of ggml backends that are correct for CUDA but not universally.

JohannesGaessler and others added 15 commits March 8, 2026 18:00

  • …expose the same type. This allows using pinned memory instead of pageable memory for CUDA.
  • Fix compilation errors.
  • Fix crash with Qwen-30B-A3B Q4_0: Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation.
  • Decide block size based on tensor quantization type.
  • KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset.
@JohannesGaessler
Contributor Author

JohannesGaessler commented Mar 8, 2026

I rebased on top of the end-to-end tests added via test-llama-archs. It seems the following model architectures currently have functional support for -sm tensor (even if the performance is not good):

Model list
  • llama
  • llama4
  • deci
  • baichuan
  • refact
  • stablelm
  • qwen2
  • qwen2moe
  • qwen2vl
  • qwen3
  • qwen3moe
  • qwen3vl
  • qwen3vlmoe
  • plamo
  • orion
  • internlm2
  • minicpm
  • gemma
  • gemma2
  • gemma3
  • starcoder2
  • xverse
  • command-r
  • cohere2
  • olmo
  • arctic
  • deepseek
  • glm4moe
  • jais2
  • nemotron
  • exaone
  • exaone4
  • exaone-moe
  • granite
  • granitemoe
  • bailingmoe
  • dots1
  • arcee
  • ernie4_5
  • ernie4_5-moe
  • hunyuan-moe
  • hunyuan-dense
  • smollm3
  • gpt-oss
  • lfm2
  • lfm2moe
  • smallthinker
  • seed_oss
  • grovemoe
  • apertus
  • pangu-embedded
  • mistral3
  • paddleocr
  • mimo2
  • maincoder

@jacekpoplawski
Contributor

jacekpoplawski commented Mar 8, 2026

I originally assumed that without p2p it wouldn't make much sense, but it works

(also it works on 3 GPUs which if I am correct is not possible in vllm)

big speedup on tg, on Qwen3 32B the speed almost doubles

I wasn't yet able to run llama-bench with -sm tensor on Qwen3.5 35B A3B for some reason (but Qwen 3 30B A3B runs); same problem with GLM 4.7 Flash (these archs are not supported yet)

plots for LLaMA 8B, Qwen 8B/14B/32B/30B
ggml_cuda_init: found 3 CUDA devices:
  Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
  Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes

AMD Ryzen Threadripper 1920X
X399 Taichi

Llama 3.1 8B Q8 (plots not shown)

Qwen 3 8B Q8 (plots not shown)

Qwen 3 14B Q8 (plots not shown)

Qwen 3 32B Q8 (plots not shown)

Qwen 3 30B A3B Q8 (plots not shown)

@github-actions github-actions bot added the testing Everything test related label Mar 8, 2026
const size_t j_src = j_src_0 == 0 ? j_dst : (j_src_0 - (j_src_0 <= j_dst ? 1 : 0));
auto & bcj_src = backend_ctx->backend_configs[j_src];

ggml_backend_tensor_copy_async(bcj_dst.backend, bcj_src.backend, dst_views[j_dst*n_backends], src_views[j_dst*n_backends + j_src]);
Contributor


I think this has a bug due to which there is a device mismatch in bcj_src.backend and src_views[j_dst*n_backends + j_src]. This is because the tensors are pushed back in the src_views in a different order at line 1014. Either we should fix the index at which we write the value, or read from the right index.

PR JohannesGaessler#11 fixes this by modifying the read index.

Perf improvement: RTX 6000 Pro Blackwell, Qwen3-4b-Q4_K_M:

| | Before | After | Speed-up |
| --- | ---: | ---: | ---: |
| pp512 | 12573.87 | 14227.6 | 1.13 |
| tg128 | 123.82 | 164 | 1.32 |

Contributor Author


Thank you, I agree that this is logically correct. I am not observing a difference in performance though:

| GPU | Model | Test | t/s 82786ff | t/s 827b63d | Speedup |
| --- | --- | --- | ---: | ---: | ---: |
| 2x RTX 4090 | llama 8B Q4_0 | pp512 | 11634.83 | 11663.71 | 1.00 |
| 2x RTX 4090 | llama 8B Q4_0 | tg128 | 174.82 | 175.42 | 1.00 |

There is a mismatch between the dst buffer device and the backend device, causing the use of sync copies
@JohannesGaessler
Contributor Author

I agree with the assessment of @gaugarg-nv that the generic fallback for AllReduce should be reverted to the previous version. I had implemented it in a coupled way with other changes so I pushed it but there are correctness issues and the performance for those cases where it works correctly is not good.

@JohannesGaessler
Contributor Author

I pushed a MoE optimization that delays the AllReduce after the FFN until after the results per expert have been summed up on each GPU, effectively reducing I/O by a factor of n_expert_used. For models like GPT-OSS 20B with n_expert_used == 4 this squeezes out a few % but for Qwen 3 30B A3B with n_expert_used == 8 the difference is relatively large.

Performance
| GPU | Model | Microbatch size | Test | t/s 0840004 | t/s ae0334f | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 1 | pp512 | 306.14 | 320.77 | 1.05 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 2 | pp512 | 417.18 | 449.57 | 1.08 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 4 | pp512 | 650.72 | 622.83 | 0.96 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 8 | pp512 | 876.68 | 799.74 | 0.91 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 16 | pp512 | 1431.11 | 1357.02 | 0.95 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 32 | pp512 | 2259.36 | 1942.63 | 0.86 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 64 | pp512 | 3237.20 | 3326.37 | 1.03 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 128 | pp512 | 4105.24 | 4332.97 | 1.06 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 256 | pp512 | 5334.81 | 5723.92 | 1.07 |
| 2x RTX 4090 | gpt-oss 20B MXFP4 MoE | 512 | pp512 | 6279.34 | 6804.04 | 1.08 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 1 | pp512 | 156.56 | 198.42 | 1.27 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 2 | pp512 | 226.16 | 268.37 | 1.19 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 4 | pp512 | 318.08 | 405.73 | 1.28 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 8 | pp512 | 388.83 | 560.83 | 1.44 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 16 | pp512 | 790.16 | 921.51 | 1.17 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 32 | pp512 | 1143.91 | 1424.43 | 1.25 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 64 | pp512 | 1517.10 | 2011.26 | 1.33 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 128 | pp512 | 2103.22 | 3394.58 | 1.61 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 256 | pp512 | 2621.08 | 4885.02 | 1.86 |
| 2x RTX 4090 | qwen3moe 30B.A3B Q4_0 | 512 | pp512 | 3029.21 | 6561.17 | 2.17 |

@gaugarg-nv
Contributor

I pushed a MoE optimization that delays the AllReduce after the FFN until after the results per expert have been summed up on each GPU, effectively reducing I/O by n_expert_used. For models like GPT-OSS 20b with n_expert_used == 4 this squeezes out a few % but for Qwen 3 30b a3b with n_expert_used == 8 the difference is relatively large.

I can confirm that this improves perf for Qwen3-235B-A22B Q4_0 as well on 2x RTX 6000 Pro Blackwell:

| | Before | After | Speed-up |
| --- | ---: | ---: | ---: |
| pp512 | 1583.47 | 2157.62 | 1.36 |
| tg128 | 68.65 | 72.52 | 1.06 |

@gaugarg-nv
Contributor

@JohannesGaessler I added a change that improves scaling of FFN-down GEMVs: JohannesGaessler#13

Please see if there is anything you would like me to test. Also, this change is unrelated to TP itself, but improves kernel performance for small-K values. Let me know if you would like me to file this on master branch.

| GPU | Model | Test | t/s ae0334f | t/s PR | Speed-up |
| --- | --- | --- | ---: | ---: | ---: |
| 2xRTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | pp512 | 2165.37 | 2167.83 | 1.00 |
| 2xRTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | tg128 | 71.51 | 75.37 | 1.05 |
| 2xRTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | pp512 | 8357.29 | 8359.93 | 1.00 |
| 2xRTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | tg128 | 182.1 | 194.26 | 1.07 |
| 4xRTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | pp512 | 2367.91 | 2342.61 | 0.99 |
| 4xRTX 6000 Pro BW | Qwen3-235B-A22B-Q4_0 | tg128 | 66.05 | 71.5 | 1.08 |
| 4xRTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | pp512 | 8408.73 | 8415.25 | 1.00 |
| 4xRTX 6000 Pro BW | Qwen3-30B-A3B-Q4_0 | tg128 | 155.57 | 162.79 | 1.05 |

@JohannesGaessler
Contributor Author

I would suggest you make a PR to master since these changes are independent and this PR is already getting quite large.
