diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3e824e658bb..337b88a9bc3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -104,8 +104,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return GGML_BACKEND_SPLIT_AXIS_MIRRORED; }; - auto get_split_granularity = [&](ggml_backend_meta_split_axis split_axis) -> int64_t { - const int64_t blck_size = split_axis == GGML_BACKEND_SPLIT_AXIS_1 && tensor->ne[1] % 256 == 0 ? 256 : 32; + auto get_split_granularity = [&](int64_t blck_size) -> int64_t { // attention if (std::regex_match(tensor->name, pattern_q_weight) || std::regex_match(tensor->name, pattern_q_bias) || @@ -140,7 +139,8 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str split_state.axis = get_split_axis(); if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { const int64_t ne_full = tensor->ne[split_state.axis]; - const int64_t granularity = get_split_granularity(split_state.axis); + const int64_t blck_size = ggml_blck_size(tensor->type); + const int64_t granularity = get_split_granularity(blck_size); GGML_ASSERT(ne_full % granularity == 0); const float * tensor_split = ud->model->tensor_split(); std::vector tensor_split_scan;