Skip to content

Commit 7a258bf

Browse files
committed
Review: format + micro-optimizations
1 parent 81eb590 commit 7a258bf

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
3232
#pragma unroll
3333
for (int i = 0; i < n_experts; i += WARP_SIZE) {
3434
const int expert = i + threadIdx.x;
35-
logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
35+
logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
3636
}
3737

38-
float max_val = -INFINITY;
38+
float max_val = logits_r[0];
3939

4040
#pragma unroll
41-
for (int i = 0; i < experts_per_thread; i++) {
41+
for (int i = 1; i < experts_per_thread; i++) {
4242
const float val = logits_r[i];
4343
max_val = max(val, max_val);
4444
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4408,12 +4408,10 @@ struct test_argsort : public test_case {
44084408
};
44094409

44104410
struct test_topk_moe: public test_case {
4411-
44124411
const std::array<int64_t, 4> ne;
44134412
const int n_expert_used;
44144413
test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1)
4415-
: ne(ne),
4416-
n_expert_used(n_expert_used) {
4414+
: ne(ne), n_expert_used(n_expert_used) {
44174415
GGML_ASSERT(n_expert_used <= ne[0]);
44184416
}
44194417

0 commit comments

Comments
 (0)