Changes from all commits (34 commits)
245dac6  ggml: backend-agnostic tensor parallelism (JohannesGaessler, Jan 14, 2026)
8ca0d42  support for GPT-OSS, Qwen 3 MoE (JohannesGaessler, Feb 6, 2026)
e815e54  partial Vulkan fix (JohannesGaessler, Feb 6, 2026)
743151a  add support for 4/8 GPUs (JohannesGaessler, Feb 7, 2026)
1b70f14  unconditional peer access (JohannesGaessler, Feb 7, 2026)
e9f261a  re-use buffers + ggml contexts (JohannesGaessler, Feb 8, 2026)
e8a9d84  fix output pattern (JohannesGaessler, Feb 9, 2026)
c732203  NCCL support (JohannesGaessler, Feb 10, 2026)
60beaeb  GGML: HIP: add RCCL support (IMbackK, Feb 11, 2026)
755c3ff  Remove shfl and AllReduce from backend interface (JohannesGaessler, Feb 11, 2026)
5c777fb  move allocation workaround out of ggml-alloc.c (JohannesGaessler, Feb 11, 2026)
620d7d4  2d tensor set/get support (JohannesGaessler, Feb 11, 2026)
74bb93a  Fix the seg fault without NCCL (gaugarg-nv, Feb 12, 2026)
560cd4e  Apply suggestion from @JohannesGaessler (JohannesGaessler, Feb 12, 2026)
b04befb  support for tensor dims % n_devs != 0 (JohannesGaessler, Feb 11, 2026)
39052af  fix view_offs scaling (JohannesGaessler, Feb 13, 2026)
ed054c0  arbitrary num. of GPUs/tensor split (JohannesGaessler, Feb 13, 2026)
331a1f0  fix compilation (JohannesGaessler, Feb 13, 2026)
05fc6b3  better granularity estimate (JohannesGaessler, Feb 13, 2026)
578396d  Support device-specific host buffer types if all underlying backends … (gaugarg-nv, Feb 16, 2026)
fbe5450  partial Qwen 3 Next support (JohannesGaessler, Feb 19, 2026)
10f101a  Fix qwen3 30b (#8) (gaugarg-nv, Feb 25, 2026)
ca1c0fa  Fix crashes due to KV cache serialization (#9) (gaugarg-nv, Feb 28, 2026)
a03cab1  metal : fix build (#7) (ggerganov, Feb 28, 2026)
0e8eba8  static memory allocations, fix usage count (JohannesGaessler, Mar 3, 2026)
fa381da  fix tensor granularity (JohannesGaessler, Mar 6, 2026)
aae6584  more even memory distribution (JohannesGaessler, Mar 6, 2026)
28e4790  use BF16 for allreduce (JohannesGaessler, Mar 7, 2026)
25f5dd9  rebase fixup (JohannesGaessler, Mar 8, 2026)
e3f9879  fix tensor names (JohannesGaessler, Mar 8, 2026)
82786ff  better error message for unsupported architectures (JohannesGaessler, Mar 8, 2026)
3d9a86b  Fix device mismatch during scatter of allReduce. (#11) (gaugarg-nv, Mar 9, 2026)
0840004  Enable the previous allreduce implementation. It is better in both pe… (gaugarg-nv, Mar 14, 2026)
ae0334f  delay AllReduce for Moe for less I/O (JohannesGaessler, Mar 14, 2026)
common/arg.cpp: 16 changes (9 additions, 7 deletions)
@@ -2338,19 +2338,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_N_GPU_LAYERS"));
     add_opt(common_arg(
-        {"-sm", "--split-mode"}, "{none,layer,row}",
+        {"-sm", "--split-mode"}, "{none,layer,row,tensor}",
         "how to split the model across multiple GPUs, one of:\n"
         "- none: use one GPU only\n"
-        "- layer (default): split layers and KV across GPUs\n"
-        "- row: split rows across GPUs",
+        "- layer (default): split layers and KV across GPUs (pipelined)\n"
+        "- row: split weight across GPUs by rows (parallelized)\n"
+        "- tensor: split weights and KV across GPUs (parallelized)",
         [](common_params & params, const std::string & value) {
-            std::string arg_next = value;
-            if (arg_next == "none") {
+            if (value == "none") {
                 params.split_mode = LLAMA_SPLIT_MODE_NONE;
-            } else if (arg_next == "layer") {
+            } else if (value == "layer") {
                 params.split_mode = LLAMA_SPLIT_MODE_LAYER;
-            } else if (arg_next == "row") {
+            } else if (value == "row") {
                 params.split_mode = LLAMA_SPLIT_MODE_ROW;
+            } else if (value == "tensor") {
+                params.split_mode = LLAMA_SPLIT_MODE_TENSOR;
             } else {
                 throw std::invalid_argument("invalid value");
             }
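Note: the new mode is also reachable programmatically. A minimal sketch in C, assuming the LLAMA_SPLIT_MODE_TENSOR value added by this PR is exposed through llama_model_params the same way the existing modes are (the llama.h side is not shown in this hunk):

#include "llama.h"

int main(void) {
    struct llama_model_params mparams = llama_model_default_params();
    mparams.split_mode   = LLAMA_SPLIT_MODE_TENSOR; // equivalent of passing -sm tensor on the CLI
    mparams.n_gpu_layers = 999;                     // offload all layers so the split applies everywhere

    struct llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == NULL) {
        return 1;
    }
    llama_model_free(model);
    return 0;
}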
ggml/CMakeLists.txt: 4 changes (4 additions, 0 deletions)
@@ -7,6 +7,8 @@ set(GGML_VERSION_MINOR 9)
 set(GGML_VERSION_PATCH 7)
 set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")

+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
+
 find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
 if(GIT_EXE)
     # Get current git commit hash
@@ -203,12 +205,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
 option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON)
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
 option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
+option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON)
 set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING
     "ggml: cuda link binary compression mode; requires cuda 12.8+")
 set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size")

 option(GGML_HIP "ggml: use HIP" OFF)
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
+option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
 option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON)
ggml/cmake/FindNCCL.cmake: 36 changes (36 additions, 0 deletions, new file)
@@ -0,0 +1,36 @@
+# cmake/FindNCCL.cmake
+
+# NVIDIA does not distribute CMake files with NCCl, therefore use this file to find it instead.
+
+find_path(NCCL_INCLUDE_DIR
+    NAMES nccl.h
+    HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda
+    PATH_SUFFIXES include
+)
+
+find_library(NCCL_LIBRARY
+    NAMES nccl
+    HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda
+    PATH_SUFFIXES lib lib64
+)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL
+    DEFAULT_MSG
+    NCCL_LIBRARY NCCL_INCLUDE_DIR
+)
+
+if(NCCL_FOUND)
+    set(NCCL_LIBRARIES ${NCCL_LIBRARY})
+    set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR})
+
+    if(NOT TARGET NCCL::NCCL)
+        add_library(NCCL::NCCL UNKNOWN IMPORTED)
+        set_target_properties(NCCL::NCCL PROPERTIES
+            IMPORTED_LOCATION "${NCCL_LIBRARY}"
+            INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}"
+        )
+    endif()
+endif()
+
+mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
ggml/include/ggml-backend.h: 82 changes (73 additions, 9 deletions)
@@ -68,7 +68,7 @@ extern "C" {
     GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);

     // tensor copy between different backends
-    GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst);

     //
     // Backend (stream)
@@ -83,13 +83,17 @@ extern "C" {
     GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
     GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend);

-    GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+    GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);

     // "offset" refers to the offset in tensor->data for setting/getting data
-    GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
-    GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+    GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+    GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+    GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);

     GGML_API void ggml_backend_synchronize(ggml_backend_t backend);

@@ -109,7 +113,7 @@ extern "C" {
     // the copy is performed after all the currently queued operations in backend_src
     // backend_dst will wait for the copy to complete before performing other operations
     // automatic fallback to sync copy if async is not supported
-    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst);
+    GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);

     GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);

@@ -135,7 +139,9 @@ extern "C" {
         // integrated GPU device using host memory
         GGML_BACKEND_DEVICE_TYPE_IGPU,
         // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
-        GGML_BACKEND_DEVICE_TYPE_ACCEL
+        GGML_BACKEND_DEVICE_TYPE_ACCEL,
+        // "meta" device wrapping multiple other devices for tensor parallelism
+        GGML_BACKEND_DEVICE_TYPE_META,
     };

     // functionality supported by the device
@@ -196,7 +202,9 @@ extern "C" {

     // Common functions that may be obtained using ggml_backend_reg_get_proc_address

-    // Split buffer type for tensor parallelism
+    // AllReduce operation for tensor parallelism (meta backend)
+    typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
+    // Split buffer type for tensor parallelism (old)
     typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split);
     // Set the number of threads for the backend
     typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads);
@@ -211,6 +219,62 @@ extern "C" {
     };
     typedef struct ggml_backend_feature * (*ggml_backend_get_features_t)(ggml_backend_reg_t reg);

+    //
+    // Meta backend
+    //
+
+#define GGML_BACKEND_META_MAX_DEVICES 16
+
+    enum ggml_backend_meta_split_axis {
+        // tensor split by tensor dimensions:
+        GGML_BACKEND_SPLIT_AXIS_0 = 0,
+        GGML_BACKEND_SPLIT_AXIS_1 = 1,
+        GGML_BACKEND_SPLIT_AXIS_2 = 2,
+        GGML_BACKEND_SPLIT_AXIS_3 = 3,
+
+        GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends
+        GGML_BACKEND_SPLIT_AXIS_PARTIAL  = 11, // each backend has a partial sum
+
+        // for internal bookkeeping only:
+        GGML_BACKEND_SPLIT_AXIS_NONE    = 98,
+        GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99,
+    };
+    GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis);
+
+    struct ggml_backend_meta_split_state {
+        enum ggml_backend_meta_split_axis axis;
+        int64_t ne[GGML_BACKEND_META_MAX_DEVICES];
+    };
+
+    // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible:
+    typedef struct ggml_backend_meta_split_state (*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata);
+
+    GGML_API bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev);
+    GGML_API size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev);
+    GGML_API ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index);
+
+    // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this:
+    GGML_API ggml_backend_dev_t ggml_backend_meta_device(
+        ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud);
+
+    GGML_API bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft);
+    GGML_API size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft);
+    GGML_API ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index);
+
+    GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf);
+    GGML_API size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf);
+    GGML_API ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index);
+    GGML_API struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index);
+
+    GGML_API bool ggml_backend_is_meta(ggml_backend_t backend);
+    GGML_API size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend);
+    GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index);
+
+    GGML_API struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync);
+
+    // temporary workaround to statically allocate tensors from a context in a deduplicated way:
+    GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
+
     //
     // Backend registry
     //
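To make the new meta-backend entry points concrete, here is a minimal sketch in C of wrapping two registered devices, assuming a trivial split-state callback that simply mirrors every statically allocated tensor (a real integration would return per-tensor split axes; whether the per-device extents in ne are consulted for mirrored tensors is an assumption here):

#include <stdio.h>
#include <string.h>

#include "ggml-backend.h"

// Hypothetical callback: replicate every statically allocated tensor on all devices.
static struct ggml_backend_meta_split_state mirror_everything(const struct ggml_tensor * tensor, void * userdata) {
    (void) tensor; (void) userdata;
    struct ggml_backend_meta_split_state state;
    memset(&state, 0, sizeof(state));
    state.axis = GGML_BACKEND_SPLIT_AXIS_MIRRORED;
    // a column split would instead set GGML_BACKEND_SPLIT_AXIS_1 and fill state.ne[i]
    // with device i's share of ne[1]
    return state;
}

int main(void) {
    // wrap the first two registered devices (assumed to be GPUs here) in a meta device
    ggml_backend_dev_t devs[2] = { ggml_backend_dev_get(0), ggml_backend_dev_get(1) };

    ggml_backend_dev_t meta_dev = ggml_backend_meta_device(devs, /*n_devs =*/ 2, mirror_everything, /*userdata =*/ NULL);

    // the meta device is used like any other device: buffer type, buffers and a backend are
    // derived from it, and graph work is fanned out to the wrapped devices internally
    ggml_backend_t backend = ggml_backend_dev_init(meta_dev, /*params =*/ NULL);

    printf("meta backend wraps %zu simple backends\n", ggml_backend_meta_n_backends(backend));

    ggml_backend_free(backend);
    return 0;
}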
ggml/include/ggml-cuda.h: 3 changes (3 additions, 0 deletions)
@@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
 // device buffer
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);

+// conduct allreduce operation between devices
+GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends);
+
 // split tensor buffer that splits matrices by rows across multiple devices
 GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);

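ggml_backend_cuda_allreduce_tensor takes one backend and one tensor shard per device and leaves every device holding the reduced result. How it maps onto NCCL is not visible in this header; below is a rough sketch in C of the call pattern such an implementation might use when GGML_CUDA_NCCL is enabled (communicator/stream setup via ncclCommInitAll omitted, F32 buffers assumed for simplicity, while the commit history indicates the actual reduction uses BF16):

#include <cuda_runtime.h>
#include <nccl.h>

// Sum-reduce one device buffer per GPU so that every GPU ends up with the elementwise sum.
// comms[i] and streams[i] belong to device i and are assumed to exist already.
static ncclResult_t allreduce_f32(ncclComm_t * comms, cudaStream_t * streams,
                                  float ** bufs, size_t count, int n_devs) {
    ncclGroupStart(); // batch the per-device calls into a single NCCL operation
    for (int i = 0; i < n_devs; ++i) {
        // in-place allreduce: send buffer == receive buffer
        ncclAllReduce(bufs[i], bufs[i], count, ncclFloat, ncclSum, comms[i], streams[i]);
    }
    return ncclGroupEnd();
}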
ggml/src/CMakeLists.txt: 1 change (1 addition, 0 deletions)
@@ -200,6 +200,7 @@ add_library(ggml-base
             ggml.cpp
             ggml-alloc.c
             ggml-backend.cpp
+            ggml-backend-meta.cpp
             ggml-opt.cpp
             ggml-threading.cpp
             ggml-threading.h
ggml/src/ggml-backend-impl.h: 12 changes (10 additions, 2 deletions)
@@ -2,7 +2,9 @@

 // ggml-backend internal header

+#include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "ggml.h"

 #ifdef __cplusplus
 extern "C" {
@@ -49,6 +51,10 @@ extern "C" {
         void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
         void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
         void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        // (optional) 2d data copies
+        void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+        void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+
         // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported)
         bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
         // clear the entire buffer
@@ -90,8 +96,10 @@ extern "C" {
         void (*free)(ggml_backend_t backend);

         // (optional) asynchronous tensor data access
-        void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
-        void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+        void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
+        void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
+        void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data);
         bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);

         // (optional) complete all pending operations (required if the backend supports async operations)
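The 2d variants are optional hooks; the parameter names suggest the intended semantics are n_copies blocks of size bytes each, advancing by stride_tensor bytes on the tensor side and stride_data bytes on the data side per block (this reading is inferred from the names, not stated in the header). Assuming that, a caller-side sketch in C using the public ggml_backend_tensor_get_2d to gather the rows of a 2D tensor into a packed host buffer:

#include "ggml.h"
#include "ggml-backend.h"

// Copy all rows of a 2D tensor into a tightly packed host buffer, even if the
// rows are stored with padding or a larger stride inside the backend buffer.
static void gather_rows_packed(const struct ggml_tensor * t, void * dst) {
    const size_t row_bytes = ggml_row_size(t->type, t->ne[0]); // bytes per row
    ggml_backend_tensor_get_2d(
        t, dst,
        /*offset        =*/ 0,
        /*size          =*/ row_bytes,            // bytes copied per block
        /*n_copies      =*/ (size_t) t->ne[1],    // one block per row
        /*stride_tensor =*/ t->nb[1],             // row stride inside the backend buffer
        /*stride_data   =*/ row_bytes);           // rows packed back to back in dst
}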