/* **************************************************************************************************
 * Copyright (C) 2025 Intel Corporation, All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ***************************************************************************************************/

#include <exception>
#include <iostream>
#include <memory>
#include <random>
#include <vector>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include "cutlass/util/device_memory.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// #include "cutlass/gemm/device/gemm_sparse.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/kernel/tile_scheduler.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/tensor_view_io.h"

// We compile all models with -fvisibility=hidden. Any symbols that need to be
// exposed in the final shared library must be declared with PT_EXPORT to make
// them visible.
#ifdef __GNUC__  // Applies to any compiler with GNU extensions (clang and g++)
#define PT_EXPORT __attribute__((__visibility__("default")))
#else
#ifdef _WIN32
#define PT_EXPORT __declspec(dllexport)
#else
#define PT_EXPORT
#endif
#endif

using namespace cute;

#define CUTLASS_CHECK(status)                                                      \
{                                                                                  \
  cutlass::Status error = status;                                                  \
  if (error != cutlass::Status::kSuccess) {                                        \
    auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " +             \
        cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__);        \
    throw std::runtime_error(msg);                                                 \
  }                                                                                \
}

// Used as pass-through functor in EVT just for type casting / rounding
template <typename T>
struct identity_op {
  CUTLASS_HOST_DEVICE
  T operator()(T val) const { return val; }
};


using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp,
    cute::Shape<cute::_256, cute::_256, cute::_32>,
    cute::Shape<cute::_1, cute::_1, cute::_1>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,
    float, cutlass::layout::RowMajor, 4,
    float, cutlass::layout::RowMajor, 4,
    cutlass::epilogue::collective::EpilogueScheduleAuto,
    cutlass::epilogue::fusion::LinearCombination<
      float,
      float,
      float,
      float
    >
  >::CollectiveOp;

using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Xe20, cutlass::arch::OpClassTensorOp,
    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,
    cutlass::bfloat16_t, cutlass::layout::ColumnMajor, 8,
    float,
    cute::Shape<cute::_256, cute::_256, cute::_32>,
    cute::Shape<cute::_1, cute::_1, cute::_1>,
    cutlass::gemm::collective::StageCountAuto,
    cutlass::gemm::collective::KernelScheduleAuto
  >::CollectiveOp;

// Gemm operator cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8
using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int,int,int,int>,
    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_mainloop,
    cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_epilogue,
    cutlass::gemm::PersistentScheduler>;
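
// Note: the configuration name encodes the parameters chosen in the builders
// above: Xe20 architecture, tensor-op math class, bf16 operands, a 256x256
// workgroup tile with a 32-wide K tile ("256x256_32"), column-major ("nn")
// A and B, and 8-element operand alignment ("align8").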

// Define named type
struct cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8 :
public cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_base { };

using cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type = cutlass::gemm::device::GemmUniversalAdapter<cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8>;

// When workspace_size is not a nullptr, populates the requested workspace_size and returns.
// Otherwise, computes the Gemm kernel using the given workspace ptr.
extern "C" {
PT_EXPORT int sycl_tla_gemm_xe20_bf16(const uint16_t* X, const uint16_t* W, uint16_t* Y, const int M, const int N, const int K, const int B, const int lda, const int ldb, const int ldc, const int ldd, const int X_offset, const int W_offset, const int Y_offset, const uint8_t swizzle, size_t* workspace_size, uint8_t* workspace, sycl::queue* stream) {
  try {
  using ElementComputeEpilogue = cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::ElementAccumulator;
  using coord_t = cutlass::gemm::GemmCoord::Index;
  static cutlass::KernelHardwareInfo hw_info;
  if (hw_info.sm_count == 0) {
    hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0);
    CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
  }

  // Initialize GemmUniversal3xInstance arguments using constructor
  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,  // GemmUniversalMode mode
    {
      static_cast<coord_t>(M),
      static_cast<coord_t>(N),
      static_cast<coord_t>(K),
      static_cast<coord_t>(B)
    },  // ProblemShape problem_shape
    {
      (cutlass::bfloat16_t*)(X + X_offset),  // ElementA const* ptr_A
      cute::make_tuple(cute::Int<1>{}, int64_t(lda), int64_t(0)),  // StrideA dA (column-major: stride_m=1, stride_k=lda, batch=0)
      (cutlass::bfloat16_t*)(W + W_offset),  // ElementB const* ptr_B
      cute::make_tuple(int64_t(ldb), cute::Int<1>{}, int64_t(0)),  // StrideB dB (column-major: stride_n=ldb, stride_k=1, batch=0)
    },  // MainloopArguments mainloop

    // see https://tinyurl.com/4rk89z48
    {
      {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},  // thread, typename FusionCallbacks::Arguments (EVT) or ThreadEpilogueOp::Params (non-EVT)
      nullptr,  // ElementC const* ptr_C
      cute::make_tuple(int64_t(0), cute::Int<1>{}, int64_t(0)),  // StrideC dC (row-major: stride_m=0 since C is unused, stride_n=1, batch=0)
      (float*)(Y + Y_offset),  // ElementD ptr_D (output is float, not bfloat16)
      cute::make_tuple(int64_t(ldd), cute::Int<1>{}, int64_t(0)),  // StrideD dD (row-major: stride_m=ldd, stride_n=1, batch=0)
    },  // EpilogueArguments epilogue
    hw_info
  };
  arguments.scheduler.max_swizzle_size = swizzle;
  cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8_device_type gemm_op;
  if (workspace_size) {
    *workspace_size = gemm_op.get_workspace_size(arguments);
    return 0;
  }
  // Check for null pointers only after the workspace-size query above, since
  // querying the workspace size doesn't require valid data pointers.
#ifndef CUTLASS_BACKEND_DISABLE_CHECKS
  {
    auto status = gemm_op.can_implement(arguments);
    CUTLASS_CHECK(status);
  }
#endif
#ifdef CUTLASS_DEBUG_TRACE_LEVEL
#if CUTLASS_DEBUG_TRACE_LEVEL == 1
  {
    // Report the maximum number of active blocks per SM for the kernel when
    // CUTLASS_DEBUG_TRACE_LEVEL == 1; no print statement is needed here, the
    // printing happens inside the function.
    gemm_op.maximum_active_blocks();
  }
#endif
#endif
  {
    auto status = gemm_op.initialize(arguments, workspace, stream);
    CUTLASS_CHECK(status);
  }
  {
    auto status = gemm_op(stream);
    CUTLASS_CHECK(status);
  }
  }
  catch (std::exception& e) {
    std::cerr << "Runtime error: " << e.what() << std::endl;
    return -1;
  }
  catch (...) {
    return -1;
  }
  return 0;
}
}
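
// Illustrative two-phase usage of the entry point above (a sketch only; the
// helper name run_sycl_tla_gemm_example and the problem sizes are hypothetical
// and not part of the generated code; SYCL declarations are assumed to be
// available via the CUTLASS SYCL headers included above). Phase 1 passes a
// non-null workspace_size to query the required workspace; phase 2 passes
// workspace_size == nullptr together with a device allocation of that size to
// launch the kernel on the caller's SYCL queue.
#ifdef SYCL_TLA_GEMM_EXAMPLE
static int run_sycl_tla_gemm_example() {
  sycl::queue q{sycl::gpu_selector_v};
  const int M = 512, N = 512, K = 512, B = 1;
  // Column-major A (MxK) and B (KxN), row-major D (MxN), matching the strides
  // constructed inside sycl_tla_gemm_xe20_bf16.
  auto* X = sycl::malloc_device<uint16_t>(size_t(M) * K, q);
  auto* W = sycl::malloc_device<uint16_t>(size_t(K) * N, q);
  auto* Y = sycl::malloc_device<uint16_t>(size_t(M) * N * 2, q);  // D is float, so 2x uint16_t

  // Phase 1: query the workspace size; data pointers are not dereferenced.
  size_t workspace_size = 0;
  sycl_tla_gemm_xe20_bf16(nullptr, nullptr, nullptr, M, N, K, B,
                          /*lda=*/M, /*ldb=*/K, /*ldc=*/N, /*ldd=*/N,
                          0, 0, 0, /*swizzle=*/1, &workspace_size, nullptr, &q);

  // Phase 2: allocate the workspace (if any) and run the GEMM.
  uint8_t* workspace = workspace_size ? sycl::malloc_device<uint8_t>(workspace_size, q) : nullptr;
  int rc = sycl_tla_gemm_xe20_bf16(X, W, Y, M, N, K, B, M, K, N, N,
                                   0, 0, 0, 1, nullptr, workspace, &q);
  q.wait();

  sycl::free(workspace, q);
  sycl::free(X, q);
  sycl::free(W, q);
  sycl::free(Y, q);
  return rc;
}
#endif  // SYCL_TLA_GEMM_EXAMPLE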

// configuration name: cutlass3x_xe20_tensorop_gemm_bf16_256x256_32x0_nn_align8