diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 63ce81074667..f66cddf37d21 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -4,17 +4,6 @@ import triton.language as tl -def next_power_of_2(n): - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n - - def num_warps(N): if N < 2048: return 4 @@ -24,7 +13,7 @@ def num_warps(N): @triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) -@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])}) @triton.jit def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): row = tl.program_id(0) @@ -49,7 +38,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): @triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) -@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])}) @triton.jit def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): row = tl.program_id(0)