Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 91 additions & 34 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -907,40 +907,70 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]]
else if (extra->split_dim == 0) {
int n_interleave = 1;
if (auto it = k_map.find(tensor->type); it != k_map.end()) n_interleave = it->second;
//if (tensor->type >= GGML_TYPE_Q4_0_R8) {
// GGML_ABORT("Dim 0 copy of row-interleaved quants is not supported yet");
//}
auto tt = ggml_internal_get_type_traits(tensor->type);
std::vector<char> host_buffer;
GGML_ASSERT(ggml_is_contiguous(tensor));
int nrows = ggml_nrows(tensor);
auto bs = tt.blck_size;
auto ts = tt.type_size;
auto row_size = ggml_row_size(tensor->type, tensor->ne[0]);
int ne = 0;
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) continue;
GGML_ASSERT(split->ne[1]%n_interleave == 0);
ggml_cuda_set_device(i);
GGML_ASSERT(split->type == tensor->type);
GGML_ASSERT((int)ggml_nrows(split) == nrows);
GGML_ASSERT(split->ne[0] % bs == 0);
auto source_offset = n_interleave*(tt.row_meta_size + (ne / bs) * ts);
auto split_row_size = ggml_row_size(split->type, split->ne[0]);
if (host_buffer.size() < nrows*split_row_size) host_buffer.resize(nrows*split_row_size);
for (int64_t i02 = 0; i02 < split->ne[2]; ++i02) {
void * extra_ptr;
memcpy(&extra_ptr, tensor->op_params, sizeof(extra_ptr));
if (extra_ptr) {
auto & ranges = *(const std::vector<std::vector<std::pair<int,int>>> *)extra_ptr;
GGML_ASSERT(extra->n_device == int(ranges.size()));
GGML_ASSERT(tensor->ne[2]*tensor->ne[3] == 1);
GGML_ASSERT(n_interleave == 1);
GGML_ASSERT(tt.row_meta_size == 0);
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) {
GGML_ASSERT(ranges[i].empty());
continue;
}
GGML_ASSERT(!ranges[i].empty());
GGML_ASSERT((int)ggml_nrows(split) == nrows);
auto split_row_size = ggml_row_size(split->type, split->ne[0]);
if (host_buffer.size() < nrows*split_row_size) host_buffer.resize(nrows*split_row_size);
auto dst = host_buffer.data();
for (int64_t i01 = 0; i01 < split->ne[1]; i01 += n_interleave) {
auto dst = host_buffer.data() + (i02*split->ne[1] + i01)*split_row_size;
auto src = (const char *)data + i02*tensor->nb[2] + i01*tensor->nb[1];
if (tt.row_meta_size > 0) {
memcpy(dst, src, tt.row_meta_size*n_interleave);
for (auto & p : ranges[i]) {
GGML_ASSERT(p.first % bs == 0);
GGML_ASSERT(p.second % bs == 0);
auto src = (const char *)data + i01*tensor->nb[1] + (p.first/bs)*ts;
auto size = (p.second/bs)*ts;
memcpy(dst, src, size);
dst += size;
}
}
ggml_cuda_set_device(i);
CUDA_CHECK(cudaMemcpyAsync(split->data, host_buffer.data(), nrows*split_row_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
}
} else {
int ne = 0;
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) continue;
GGML_ASSERT(split->ne[1]%n_interleave == 0);
ggml_cuda_set_device(i);
GGML_ASSERT(split->type == tensor->type);
GGML_ASSERT((int)ggml_nrows(split) == nrows);
GGML_ASSERT(split->ne[0] % bs == 0);
auto source_offset = n_interleave*(tt.row_meta_size + (ne / bs) * ts);
auto split_row_size = ggml_row_size(split->type, split->ne[0]);
if (host_buffer.size() < nrows*split_row_size) host_buffer.resize(nrows*split_row_size);
for (int64_t i02 = 0; i02 < split->ne[2]; ++i02) {
for (int64_t i01 = 0; i01 < split->ne[1]; i01 += n_interleave) {
auto dst = host_buffer.data() + (i02*split->ne[1] + i01)*split_row_size;
auto src = (const char *)data + i02*tensor->nb[2] + i01*tensor->nb[1];
if (tt.row_meta_size > 0) {
memcpy(dst, src, tt.row_meta_size*n_interleave);
}
memcpy(dst + tt.row_meta_size*n_interleave, src + source_offset, n_interleave*(split_row_size - tt.row_meta_size));
}
memcpy(dst + tt.row_meta_size*n_interleave, src + source_offset, n_interleave*(split_row_size - tt.row_meta_size));
}
CUDA_CHECK(cudaMemcpyAsync(split->data, host_buffer.data(), nrows*split_row_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
ne += split->ne[0];
}
CUDA_CHECK(cudaMemcpyAsync(split->data, host_buffer.data(), nrows*split_row_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
ne += split->ne[0];
}
}
else if (extra->split_dim == 1) {
Expand All @@ -965,16 +995,43 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_set_tensor([[maybe_unused]]
} else {
int n_interleave = 1;
if (auto it = k_map.find(tensor->type); it != k_map.end()) n_interleave = it->second;
size_t cur_offset = 0;
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) continue;
GGML_ASSERT(split->ne[1]%n_interleave == 0);
ggml_cuda_set_device(i);
auto size = ggml_nbytes(split);
const char * buf_host = (const char *)data + cur_offset;
CUDA_CHECK(cudaMemcpyAsync(split->data, buf_host, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
cur_offset += size;
void * extra_ptr;
memcpy(&extra_ptr, tensor->op_params, sizeof(extra_ptr));
if (extra_ptr) {
auto & ranges = *(const std::vector<std::vector<std::pair<int,int>>> *)extra_ptr;
GGML_ASSERT(extra->n_device == int(ranges.size()));
GGML_ASSERT(tensor->ne[2]*tensor->ne[3] == 1);
GGML_ASSERT(n_interleave == 1);
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) {
GGML_ASSERT(ranges[i].empty());
continue;
}
GGML_ASSERT(!ranges[i].empty());
ggml_cuda_set_device(i);
auto dst = (char *)split->data;
for (auto & p : ranges[i]) {
GGML_ASSERT(p.first >= 0 && p.first < tensor->ne[1]);
GGML_ASSERT(p.second >= 0 && p.first + p.second <= tensor->ne[1]);
auto src = (const char *)data + p.first*tensor->nb[1];
auto size = p.second*tensor->nb[1];
CUDA_CHECK(cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
dst += size;
}
}
} else {
size_t cur_offset = 0;
for (int i = 0; i < extra->n_device; ++i) {
auto split = extra->splits[i];
if (!split) continue;
GGML_ASSERT(split->ne[1]%n_interleave == 0);
ggml_cuda_set_device(i);
auto size = ggml_nbytes(split);
const char * buf_host = (const char *)data + cur_offset;
CUDA_CHECK(cudaMemcpyAsync(split->data, buf_host, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
cur_offset += size;
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/delta-net.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ __global__ void delta_net_recurrent_f32(
}

}
__syncthreads();
// Copy the final state to its destination
for (int i = 0; i < HEAD_DIM/num_warps; ++i) {
int col = num_warps*i + col_idx_0;
Expand Down
69 changes: 3 additions & 66 deletions src/llama-build-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4437,7 +4437,6 @@ ggml_cgraph * llm_build_context::build_qwen3next() {
ggml_tensor * cur = nullptr;

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

GGML_ASSERT(model.layers[il].attn_norm != nullptr);
GGML_ASSERT(model.layers[il].attn_post_norm != nullptr);
Expand All @@ -4455,27 +4454,7 @@ ggml_cgraph * llm_build_context::build_qwen3next() {


if (hparams.is_recurrent(il)) {
int idx = model.default_layer_device[il];
if (inpL->op == GGML_OP_REDUCE) {
if (kv_self.s_l[il]) {
// This shouldn't be necessary, but just in case.
int idx_s_l = ggml_backend_sched_get_backend_idx(lctx.sched, kv_self.s_l[il]->buffer);
if (idx_s_l >= 0) idx = idx_s_l;
}
if (inpL->src[idx]) {
inpL->view_src = inpL->src[idx];
}
}
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[idx] : model.layers[il].attn_norm;
cur = llm_build_norm(ctx0, inpL, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
cur = delta.build_layer_attn_linear(ctx0, gf, cur, il, cb);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
cur = delta.build_layer_attn_linear(ctx0, gf, inpL, il == n_layer - 1 ? inp_out_ids : nullptr, il, cb);
} else {
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, false);
Expand Down Expand Up @@ -4545,28 +4524,7 @@ ggml_cgraph * llm_build_context::build_qwen35moe() {
for (int il = 0; il < n_layer; ++il) {

if (hparams.is_recurrent(il)) {
ggml_tensor * inpSA = inpL;
int idx = model.default_layer_device[il];
if (inpL->op == GGML_OP_REDUCE) {
if (kv_self.s_l[il]) {
// This shouldn't be necessary, but just in case.
int idx_s_l = ggml_backend_sched_get_backend_idx(lctx.sched, kv_self.s_l[il]->buffer);
if (idx_s_l >= 0) idx = idx_s_l;
}
if (inpL->src[idx]) {
inpL->view_src = inpL->src[idx];
}
}
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[idx] : model.layers[il].attn_norm;
cur = llm_build_norm(ctx0, inpL, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
cur = delta.build_layer_attn_linear(ctx0, gf, cur, il, cb);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
cur = delta.build_layer_attn_linear(ctx0, gf, inpL, il == n_layer - 1 ? inp_out_ids : nullptr, il, cb);
} else {
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, true);
Expand Down Expand Up @@ -4625,28 +4583,7 @@ ggml_cgraph * llm_build_context::build_qwen35() {
for (int il = 0; il < n_layer; ++il) {

if (hparams.is_recurrent(il)) {
ggml_tensor * inpSA = inpL;
int idx = model.default_layer_device[il];
if (inpL->op == GGML_OP_REDUCE) {
if (kv_self.s_l[il]) {
// This shouldn't be necessary, but just in case.
int idx_s_l = ggml_backend_sched_get_backend_idx(lctx.sched, kv_self.s_l[il]->buffer);
if (idx_s_l >= 0) idx = idx_s_l;
}
if (inpL->src[idx]) {
inpL->view_src = inpL->src[idx];
}
}
auto norm = model.layers[il].attn_norm->extra ? ((ggml_split_tensor_t *)model.layers[il].attn_norm->extra)->splits[idx] : model.layers[il].attn_norm;
cur = llm_build_norm(ctx0, inpL, hparams, norm, nullptr, LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
cur = delta.build_layer_attn_linear(ctx0, gf, cur, il, cb);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
cur = ggml_add(ctx0, cur, inpSA);
cb(cur, "attn_residual", il);
cur = delta.build_layer_attn_linear(ctx0, gf, inpL, il == n_layer - 1 ? inp_out_ids : nullptr, il, cb);
} else {
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, il == n_layer - 1 ? inp_out_ids : nullptr, nullptr,
KQ_mask, nullptr, nullptr, KQ_scale, 0.0f, 0, il, true, false, true, false, true);
Expand Down
1 change: 1 addition & 0 deletions src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct llama_kv_cache {

std::vector<llama_split_tensor> split_k_l;
std::vector<llama_split_tensor> split_v_l;
std::vector<llama_split_tensor> split_s_l;

std::vector<struct ggml_context *> ctxs;
std::vector<ggml_backend_buffer_t> bufs;
Expand Down
Loading