Skip to content

Commit 95c373f

Browse files
authored
[FastMath] Disable default TVM fastmath intrinsic dispatch and add explicit fastmath op to invoke (#875)
* Add fast math operations for CUDA: exp, exp10, log, log2, log10, tan, cos, and sin (#865) * Refactor fast math operation definitions for consistency and readability in CUDA code. Consolidated multiple definitions into single lines and improved formatting in related test files for better clarity. * Remove unnecessary pass configurations for warp specialization and TMA lowering in fast math operation tests for CUDA. This simplifies the test setup while maintaining the focus on fast math functionality. * Update fastmath tests to reflect that tl.* intrinsics generate no fastmath versions and disable cache in main execution. * Fix formatting in fastmath test comments for clarity on tl.* intrinsics behavior. * Add precision comparison tool for CUDA operations This commit introduces a new Python script and CUDA source file for a precision comparison tool that evaluates the accuracy of various CUDA operations (including division, reciprocal, exponential, logarithmic, and trigonometric functions) across different implementations: CUDA Precise, CUDA Fast, Triton, Triton LibDevice, and TileLang. The tool generates test data, executes the operations, and summarizes the error statistics for each implementation against a double precision reference. Additionally, a README file is added to document the results of the comparisons for various operations. * Add precision comparison tool for CUDA operations This commit introduces a new precision comparison tool implemented in Python and CUDA, designed to evaluate the accuracy of various mathematical operations (division, reciprocal, exponential, logarithmic, trigonometric, square root, etc.) across different frameworks including CUDA Precise/Fast, Triton, Triton LibDevice, PyTorch, and TileLang. The tool includes functionality for generating test data, executing operations, and summarizing error statistics for each implementation. Additionally, it provides a comprehensive README with error metrics for each operation tested.
1 parent 56f7494 commit 95c373f

File tree

10 files changed

+1450
-1
lines changed

10 files changed

+1450
-1
lines changed

3rdparty/tvm

Submodule tvm updated from 7a71ee3 to 883e96b

maint/precision/README.md

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
=== div ===
2+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
3+
------------------------------------------------------------------------------------------
4+
FP32 Precise vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
5+
Triton LibDevice vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
6+
TileLang vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
7+
PyTorch vs Double max abs: 1.219e-04, mean abs: 8.916e-08, max rel: 5.952e-08, mean rel: 2.152e-08
8+
Triton vs Double max abs: 2.605e-04, mean abs: 1.285e-07, max rel: 1.455e-07, mean rel: 3.175e-08
9+
TileLang Fastmath vs Double max abs: 2.605e-04, mean abs: 1.285e-07, max rel: 1.455e-07, mean rel: 3.175e-08
10+
CUDA Fast vs Double max abs: 2.605e-04, mean abs: 1.285e-07, max rel: 1.455e-07, mean rel: 3.175e-08
11+
12+
=== reciprocal ===
13+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
14+
------------------------------------------------------------------------------------------
15+
FP32 Precise vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
16+
Triton LibDevice vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
17+
TileLang vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
18+
PyTorch vs Double max abs: 3.039e-05, mean abs: 4.418e-08, max rel: 5.960e-08, mean rel: 2.235e-08
19+
Triton vs Double max abs: 4.470e-05, mean abs: 4.886e-08, max rel: 9.699e-08, mean rel: 2.461e-08
20+
TileLang Fastmath vs Double max abs: 4.470e-05, mean abs: 4.886e-08, max rel: 9.699e-08, mean rel: 2.461e-08
21+
CUDA Fast vs Double max abs: 4.470e-05, mean abs: 4.886e-08, max rel: 9.699e-08, mean rel: 2.461e-08
22+
23+
=== exp ===
24+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
25+
------------------------------------------------------------------------------------------
26+
FP32 Precise vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
27+
Triton LibDevice vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
28+
TileLang vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
29+
PyTorch vs Double max abs: 5.494e-06, mean abs: 2.153e-07, max rel: 1.483e-07, mean rel: 3.200e-08
30+
Triton vs Double max abs: 1.338e-05, mean abs: 5.023e-07, max rel: 2.641e-07, mean rel: 5.564e-08
31+
TileLang Fastmath vs Double max abs: 1.338e-05, mean abs: 5.023e-07, max rel: 2.641e-07, mean rel: 5.564e-08
32+
CUDA Fast vs Double max abs: 1.338e-05, mean abs: 5.023e-07, max rel: 2.641e-07, mean rel: 5.564e-08
33+
34+
=== log ===
35+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
36+
------------------------------------------------------------------------------------------
37+
FP32 Precise vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
38+
Triton LibDevice vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
39+
TileLang vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
40+
PyTorch vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
41+
Triton vs Double max abs: 2.684e-07, mean abs: 2.051e-08, max rel: 7.886e-08, mean rel: 2.297e-08
42+
TileLang Fastmath vs Double max abs: 9.087e-07, mean abs: 4.760e-08, max rel: 2.019e-02, mean rel: 3.183e-07
43+
CUDA Fast vs Double max abs: 9.087e-07, mean abs: 4.760e-08, max rel: 2.019e-02, mean rel: 3.183e-07
44+
45+
=== sin ===
46+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
47+
------------------------------------------------------------------------------------------
48+
FP32 Precise vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
49+
Triton LibDevice vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
50+
TileLang vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
51+
PyTorch vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
52+
Triton vs Double max abs: 7.731e-08, mean abs: 1.401e-08, max rel: 1.148e-07, mean rel: 2.492e-08
53+
TileLang Fastmath vs Double max abs: 6.463e-07, mean abs: 1.251e-07, max rel: 7.111e-02, mean rel: 1.425e-06
54+
CUDA Fast vs Double max abs: 6.463e-07, mean abs: 1.251e-07, max rel: 7.111e-02, mean rel: 1.425e-06
55+
56+
=== cos ===
57+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
58+
------------------------------------------------------------------------------------------
59+
FP32 Precise vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
60+
Triton LibDevice vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
61+
TileLang vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
62+
PyTorch vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
63+
Triton vs Double max abs: 8.668e-08, mean abs: 1.587e-08, max rel: 1.199e-07, mean rel: 2.513e-08
64+
TileLang Fastmath vs Double max abs: 4.006e-07, mean abs: 9.249e-08, max rel: 5.275e-02, mean rel: 7.307e-07
65+
CUDA Fast vs Double max abs: 4.006e-07, mean abs: 9.249e-08, max rel: 5.275e-02, mean rel: 7.307e-07
66+
67+
=== sqrt ===
68+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
69+
------------------------------------------------------------------------------------------
70+
FP32 Precise vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
71+
Triton LibDevice vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
72+
TileLang vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
73+
PyTorch vs Double max abs: 5.960e-08, mean abs: 2.554e-08, max rel: 5.960e-08, mean rel: 1.986e-08
74+
Triton vs Double max abs: 1.114e-07, mean abs: 2.947e-08, max rel: 9.962e-08, mean rel: 2.291e-08
75+
TileLang Fastmath vs Double max abs: 1.114e-07, mean abs: 2.947e-08, max rel: 9.962e-08, mean rel: 2.291e-08
76+
CUDA Fast vs Double max abs: 1.114e-07, mean abs: 2.947e-08, max rel: 9.962e-08, mean rel: 2.291e-08
77+
78+
=== tanh ===
79+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
80+
------------------------------------------------------------------------------------------
81+
FP32 Precise vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
82+
Triton LibDevice vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
83+
TileLang vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
84+
PyTorch vs Double max abs: 1.056e-07, mean abs: 1.636e-08, max rel: 1.966e-07, mean rel: 2.359e-08
85+
Triton vs Double max abs: 2.293e-07, mean abs: 3.965e-08, max rel: 6.204e-04, mean rel: 1.100e-07
86+
TileLang Fastmath vs Double max abs: 7.826e-06, mean abs: 1.384e-06, max rel: 1.081e-05, mean rel: 1.906e-06
87+
CUDA Fast vs Double max abs: 7.826e-06, mean abs: 1.384e-06, max rel: 1.081e-05, mean rel: 1.906e-06
88+
89+
=== rsqrt ===
90+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
91+
------------------------------------------------------------------------------------------
92+
FP32 Precise vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
93+
Triton LibDevice vs Double max abs: 9.535e-07, mean abs: 2.199e-08, max rel: 5.960e-08, mean rel: 2.315e-08
94+
TileLang vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
95+
PyTorch vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
96+
Triton vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
97+
TileLang Fastmath vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
98+
CUDA Fast vs Double max abs: 2.057e-06, mean abs: 2.798e-08, max rel: 1.224e-07, mean rel: 2.918e-08
99+
100+
=== inv_sqrt ===
101+
Implementation Max Abs Error Mean Abs Error Max Rel Error Mean Rel Error
102+
------------------------------------------------------------------------------------------
103+
FP32 Precise vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
104+
Triton LibDevice vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
105+
TileLang vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
106+
PyTorch vs Double max abs: 2.501e-06, mean abs: 2.911e-08, max rel: 8.939e-08, mean rel: 2.963e-08
107+
Triton vs Double max abs: 2.876e-06, mean abs: 3.443e-08, max rel: 1.536e-07, mean rel: 3.503e-08
108+
TileLang Fastmath vs Double max abs: 2.876e-06, mean abs: 3.443e-08, max rel: 1.536e-07, mean rel: 3.503e-08
109+
CUDA Fast vs Double max abs: 2.876e-06, mean abs: 3.171e-08, max rel: 1.250e-07, mean rel: 3.211e-08

0 commit comments

Comments
 (0)