Skip to content

Commit 63e0a0f

Browse files
authored
[Thrust] Increase static workspace size (#16937)
This PR increases the thrust workspace size, since in practice we found that the current workspace size can still be insufficient. Thrust sort may require larger workspace when the number of elements being sorted is large (e.g., in Llama3 that is 128k).
1 parent 3ff3daa commit 63e0a0f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/tvm/relax/backend/dispatch_sort_scan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,13 @@ def estimate_thrust_workspace_size(self, call: relax.Call) -> int:
227227
int32_byte_per_elem = DataType("int32").bits // 8
228228
num_elem = reduce(mul, input_shape, 1)
229229
input_size = num_elem * input_byte_per_elem
230-
# Most GPU algorithms take O(n) space or less, we choose 8N + 4MB as a safe estimation
230+
# Most GPU algorithms take O(n) space or less, we choose 8N + 8MB as a safe estimation
231231
# for algorithm workspace.
232232
# The current thrust sort implementation may need extra int64 and int32 arrays
233233
# for temporary data, so we further add this part to the workspace.
234234
return (
235235
8 * input_size
236-
+ 4 * 1024 * 1024
236+
+ 8 * 1024 * 1024
237237
+ num_elem * (int64_byte_per_elem + int32_byte_per_elem)
238238
)
239239

0 commit comments

Comments
 (0)