
Commit 4c67f3d

Authored by wqerrewetw, with contributions from danbev, odrling, allozaur, and am17an
up (#5)
* model-conversion : add trust_remote_code for orig model run [no ci] (ggml-org#16751)

  This commit adds the trust_remote_code=True argument when loading models using
  AutoConfig, AutoTokenizer, and AutoModelForCausalLM in the run-original-model
  script.

  The motivation for this is that some models require custom code to be loaded
  properly, and setting trust_remote_code=True avoids a prompt asking for user
  confirmation:
  ```console
  (venv) $ make causal-run-original-model
  The repository /path/to/model contains custom code which must be executed to
  correctly load the model. You can inspect the repository content at
  /path/to/model. Do you wish to run the custom code? [y/N] N
  ```
  Having this as the default seems like a safe choice, as we have to clone or
  download the models we convert and would be expected to run any custom code
  they have.

* webui: support q URL parameter (ggml-org#16728)

  * webui: support q URL parameter

    Fixes ggml-org#16722. I've checked that it works with Firefox's AI tools.

  * webui: apply suggestions from code review

    Co-authored-by: Aleksander Grygier <[email protected]>

  * chore: update webui static build

  ---------

  Co-authored-by: Aleksander Grygier <[email protected]>

* CUDA: use CUB for arbitrary size argsort (ggml-org#16754)

* ggml: fix CUDA grid launch condition for large block_nums.y in binbcast (ggml-org#16742)

  * Fix CUDA grid launch condition for large block_nums.y

  * add backend ops test

  * reduce test repetitions

* convert : avoid dequantizing mxfp4 for GPT-OSS (ggml-org#16756)

* vulkan: Optimize SSM_SCAN (ggml-org#16645)

* vulkan: delete dead code (ggml-org#16732)

  ggml_vk_create_buffer_temp is not used anywhere, and it is the only caller of
  ggml_vk_pool_malloc.

  Signed-off-by: Giuseppe Scrivano <[email protected]>

* model : set res->t_embd in PLaMo2 models (ggml-org#16766)

---------

Signed-off-by: Giuseppe Scrivano <[email protected]>
Co-authored-by: Daniel Bevenius <[email protected]>
Co-authored-by: Florian Badie <[email protected]>
Co-authored-by: Aleksander Grygier <[email protected]>
Co-authored-by: Aman Gupta <[email protected]>
Co-authored-by: leejet <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: Jeff Bolz <[email protected]>
Co-authored-by: Giuseppe Scrivano <[email protected]>
Co-authored-by: Shunta Saito <[email protected]>
1 parent 0f34149 commit 4c67f3d

File tree

12 files changed (+169, −109 lines)

convert_hf_to_gguf.py

Lines changed: 7 additions & 0 deletions
@@ -8943,6 +8943,13 @@ def set_vocab(self):
 class GptOssModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT_OSS
 
+    # TODO: remove once MXFP4 is supported more generally
+    def dequant_model(self):
+        quant_config = self.hparams.get("quantization_config")
+        if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
+            return
+        return super().dequant_model()
+
     def transform_nibble_layout(self, tensor):
         assert tensor.dtype == torch.uint8
         assert tensor.shape[-1] == 16
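With this change, GPT-OSS checkpoints that ship MXFP4 weights keep their packed tensors; whether a checkpoint takes this path is decided by the quantization_config block of its config.json (the same field dequant_model inspects via self.hparams). As a rough illustration, not part of the converter and with a placeholder path, the same check could be done against a config.json directly:

```python
import json
from pathlib import Path

# Placeholder path; point this at a local GPT-OSS checkout.
config = json.loads(Path("/path/to/gpt-oss/config.json").read_text())

quant_config = config.get("quantization_config")
if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
    # Keep the packed MXFP4 tensors instead of dequantizing them to floats.
    print("MXFP4 checkpoint detected: skipping dequantization")
else:
    print("No MXFP4 quantization_config: dequantize as usual")
```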

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 4 additions & 4 deletions
@@ -138,7 +138,7 @@ def fn(_m, input, output):
         "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
     )
 
-config = AutoConfig.from_pretrained(model_path)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 print("Model type: ", config.model_type)
 print("Vocab size: ", config.vocab_size)
@@ -148,8 +148,8 @@ def fn(_m, input, output):
 print("EOS token id: ", config.eos_token_id)
 
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
-tokenizer = AutoTokenizer.from_pretrained(model_path)
-config = AutoConfig.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
 
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
@@ -171,7 +171,7 @@ def fn(_m, input, output):
         exit(1)
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload"
+        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True
    )
 
 for name, module in model.named_modules():
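For context, the loading pattern the script now uses looks roughly like this. A minimal sketch, not the script itself: the model path is a placeholder, and device_map="auto" assumes the accelerate package is installed.

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_path = "/path/to/model"  # placeholder

# trust_remote_code=True lets repositories that ship custom modeling code load
# without the interactive "Do you wish to run the custom code? [y/N]" prompt.
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map="auto", trust_remote_code=True
)

print("Model type:", config.model_type)
print("Vocab size:", config.vocab_size)
```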

ggml/src/ggml-cuda/argsort.cu

Lines changed: 100 additions & 4 deletions
@@ -1,5 +1,81 @@
 #include "argsort.cuh"
 
+#ifdef GGML_CUDA_USE_CUB
+#    include <cub/cub.cuh>
+using namespace cub;
+#endif  // GGML_CUDA_USE_CUB
+
+static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
+    const int col = blockIdx.x * blockDim.x + threadIdx.x;
+    const int row = blockIdx.y;
+
+    if (col < ncols && row < nrows) {
+        indices[row * ncols + col] = col;
+    }
+}
+
+static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx <= nrows) {
+        offsets[idx] = idx * ncols;
+    }
+}
+
+#ifdef GGML_CUDA_USE_CUB
+static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+                                     const float * x,
+                                     int * dst,
+                                     const int ncols,
+                                     const int nrows,
+                                     ggml_sort_order order,
+                                     cudaStream_t stream) {
+    ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
+    ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
+
+    int * temp_indices = temp_indices_alloc.get();
+    float * temp_keys = temp_keys_alloc.get();
+    int * d_offsets = offsets_alloc.get();
+
+    static const int block_size = 256;
+    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
+    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
+
+    const dim3 offset_grid((nrows + block_size - 1) / block_size);
+    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);
+
+    cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);
+
+    size_t temp_storage_bytes = 0;
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys,  // keys (in-place)
+                                            temp_indices, dst,                                  // values (indices)
+                                            ncols * nrows, nrows,                               // num items, num segments
+                                            d_offsets, d_offsets + 1, 0, sizeof(float) * 8,     // all bits
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+                                                      dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
+                                                      sizeof(float) * 8, stream);
+    }
+
+    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+    void * d_temp_storage = temp_storage_alloc.get();
+
+    if (order == GGML_SORT_ORDER_ASC) {
+        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+                                            ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
+                                            stream);
+    } else {
+        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+                                                      temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
+                                                      0, sizeof(float) * 8, stream);
+    }
+}
+#endif  // GGML_CUDA_USE_CUB
+
+// Bitonic sort implementation
 template<typename T>
 static inline __device__ void ggml_cuda_swap(T & a, T & b) {
     T tmp = a;
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
     return n;
 }
 
-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+static void argsort_f32_i32_cuda_bitonic(const float * x,
+                                         int * dst,
+                                         const int ncols,
+                                         const int nrows,
+                                         ggml_sort_order order,
+                                         cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
     const int ncols_pad = next_power_of_2(ncols);
 
@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
 
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ABORT("fatal error");
     }
@@ -100,5 +183,18 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+#ifdef GGML_CUDA_USE_CUB
+    const int ncols_pad = next_power_of_2(ncols);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    if (shared_mem > max_shared_mem || ncols > 1024) {
+        ggml_cuda_pool & pool = ctx.pool();
+        argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    } else {
+        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+    }
+#else
+    argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#endif
 }
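Functionally, both paths compute an independent argsort of each row of the input: the bitonic kernel covers rows of up to 1024 columns whose padded size fits in shared memory, and the CUB segmented radix sort handles everything else. A minimal NumPy sketch of the expected result, reference semantics only and not the CUDA code (tie ordering may differ from the GPU kernels):

```python
import numpy as np

def argsort_rows(x: np.ndarray, ascending: bool = True) -> np.ndarray:
    """Per-row argsort of a 2D array, mirroring what GGML_OP_ARGSORT returns."""
    order = np.argsort(x, axis=-1, kind="stable")
    return order if ascending else order[:, ::-1]

x = np.random.rand(4, 3000).astype(np.float32)  # ncols > 1024 would take the CUB path on CUDA
print(argsort_rows(x)[:, :5])
```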

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
     const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
     const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
 
-    if (block_nums.z > 65535) {
+    if (block_nums.z > 65535 || block_nums.y > 65535) {
         int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
         const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
         const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
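The extra check matters because CUDA caps gridDim.y and gridDim.z at 65535 (only gridDim.x can be larger), so oversized y dimensions must also fall back to the flattened 1D launch, which recovers the 4D indices by division and modulo. A conceptual sketch of that index recovery in Python, not the kernel code; the real kernel uses precomputed fastdiv constants instead of `//` and `%`:

```python
def unflatten_index(idx: int, ne0: int, ne1: int, ne2: int):
    """Recover (i3, i2, i1, i0) from a flattened element index,
    mirroring the prod_012 / prod_01 factors set up in launch_bin_bcast_pack."""
    prod_012 = ne0 * ne1 * ne2
    prod_01 = ne0 * ne1
    i3 = idx // prod_012
    i2 = (idx % prod_012) // prod_01
    i1 = (idx % prod_01) // ne0
    i0 = idx % ne0
    return i3, i2, i1, i0

print(unflatten_index(123456, ne0=32, ne1=16, ne2=8))
```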

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 1 deletion
@@ -3642,8 +3642,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM:
             return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGSORT:
-            // TODO: Support arbitrary column width
+#ifndef GGML_CUDA_USE_CUB
             return op->src[0]->ne[0] <= 1024;
+#else
+            return true;
+#endif
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_GROUP_NORM:

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 80 deletions
@@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 #define GGML_VK_MAX_NODES 8192
 
-#define MAX_VK_BUFFERS 256
-
 #define VK_CHECK(err, msg) \
     do { \
         vk::Result err_ = (err); \
@@ -1311,7 +1309,6 @@ struct ggml_vk_garbage_collector {
     std::vector<vk_semaphore> tl_semaphores;
     std::vector<vk_semaphore> semaphores;
     std::vector<vk::Event> events;
-    std::vector<vk_buffer> temp_buffers;
     std::vector<vk_context> contexts;
 };
 
@@ -1482,8 +1479,6 @@ struct ggml_backend_vk_context {
     // and set to true after the buffer contents are consumed.
     bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
 
-    vk_buffer buffer_pool[MAX_VK_BUFFERS];
-
     vk_context_ref compute_ctx;
     vk_context_ref transfer_ctx;
 
@@ -3623,8 +3618,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
+    if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
+        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
+    }
 
     ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
 
@@ -5144,71 +5144,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
     return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
 }
 
-static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
-    VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
-    VK_LOG_MEMORY("ggml_vk_pool_malloc");
-
-    int best_i = -1;
-    size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
-    int worst_i = -1;
-    size_t worst_size = 0; //largest unused buffer seen so far
-    for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
-        vk_buffer &b = ctx->buffer_pool[i];
-        if (b != nullptr && b->size >= size && b->size < best_size) {
-            best_i = i;
-            best_size = b->size;
-        }
-        if (b != nullptr && b->size > worst_size) {
-            worst_i = i;
-            worst_size = b->size;
-        }
-    }
-    if(best_i != -1) {
-        //found the smallest buffer that fits our needs
-        vk_buffer b = ctx->buffer_pool[best_i];
-        ctx->buffer_pool[best_i].reset();
-        return b;
-    }
-    if(worst_i != -1) {
-        //no buffer that fits our needs, resize largest one to save memory
-        vk_buffer& b = ctx->buffer_pool[worst_i];
-        ggml_vk_destroy_buffer(b);
-    }
-
-    return ggml_vk_create_buffer_device(ctx->device, size);
-}
-
-static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
-    VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
-    for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
-        vk_buffer& b = ctx->buffer_pool[i];
-        if (b == nullptr) {
-            b = buffer;
-            return;
-        }
-    }
-    std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
-    ggml_vk_destroy_buffer(buffer);
-}
-
-// Returns an available temporary buffer that may only be used temporarily, it will be reused
-static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
-    // Try to find existing temp buffer with enough capacity
-    for (auto& buffer : ctx->gc.temp_buffers) {
-        if (buffer->size >= size) {
-            return buffer;
-        }
-    }
-
-    VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
-
-    // Otherwise create new buffer
-    vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
-    ctx->gc.temp_buffers.push_back(buf);
-
-    return buf;
-}
-
 static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
     VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
     vk_buffer buf = ggml_vk_create_buffer(device, size,
@@ -11789,10 +11724,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
 // Clean up after graph processing is done
 static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
     VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
-    for (auto& buffer : ctx->gc.temp_buffers) {
-        ggml_vk_pool_free(ctx, buffer);
-    }
-    ctx->gc.temp_buffers.clear();
     ctx->prealloc_y_last_pipeline_used = {};
 
     ctx->unsynced_nodes_written.clear();
@@ -11835,10 +11766,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ggml_vk_destroy_buffer(ctx->prealloc_split_k);
     ctx->prealloc_y_last_pipeline_used = nullptr;
 
-    for (auto& buffer : ctx->buffer_pool) {
-        ggml_vk_destroy_buffer(buffer);
-    }
-
     ctx->prealloc_size_x = 0;
     ctx->prealloc_size_y = 0;
     ctx->prealloc_size_split_k = 0;
