@@ -108,10 +108,10 @@ LinearPerDeviceState init_kernel(PerDeviceFFHandle handle,
 
 void forward_kernel(cudaStream_t stream,
                     LinearPerDeviceState const &m,
-                    void const *input_ptr,
-                    void *output_ptr,
-                    void const *weight_ptr,
-                    void const *bias_ptr,
+                    float const *input_ptr,
+                    float *output_ptr,
+                    float const *weight_ptr,
+                    float const *bias_ptr,
                     int in_dim,
                     int out_dim,
                     int batch_size) {
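
Note on the signature change: the typed parameters pin these kernels to fp32 buffers, which appears to be the point of the commit, since a wrongly-typed buffer now fails at compile time instead of silently corrupting results at runtime. The cuBLAS/cuDNN entry points called below still take untyped pointers, hence the explicit (void *) casts that show up at the call sites in the following hunks. An illustrative example (not from the commit) of what the typed signature buys:

    // Illustrative declarations only; with a typed parameter the compiler
    // rejects wrongly-typed buffers that a void* parameter would accept.
    void takes_untyped(void const *input);
    void takes_typed(float const *input);

    void caller(double const *wrong_buffer) {
      takes_untyped(wrong_buffer);  // compiles: silent type confusion
      // takes_typed(wrong_buffer); // error: no double* -> float* conversion
    }
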
@@ -135,14 +135,14 @@ void forward_kernel(cudaStream_t stream,
                            batch_size,
                            in_dim,
                            &alpha,
-                           weight_ptr,
+                           (void *)weight_ptr,
                            weight_type,
                            in_dim,
-                           input_ptr,
+                           (void *)input_ptr,
                            input_type,
                            in_dim,
                            &beta,
-                           output_ptr,
+                           (void *)output_ptr,
                            output_type,
                            out_dim,
                            compute_type,
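
The call patched above is cublasGemmEx, whose A/B/C parameters are untyped (void const * / void *), so the now-typed pointers are adapted back at the API boundary. Note that a C-style (void *) cast on a float const * also drops const; (void const *) would match the A and B parameters without doing so. For orientation, a minimal sketch of the same product Y = W^T * X in pure fp32, with the transpose flags this column-major layout implies; checkCUBLAS is assumed to be the project's status-checking macro:

    #include <cublas_v2.h>

    // Sketch mirroring the call above: W is in_dim x out_dim, X is
    // in_dim x batch_size, Y is out_dim x batch_size, all column-major fp32.
    void linear_forward_gemm_f32(cublasHandle_t handle,
                                 float const *weight_ptr,
                                 float const *input_ptr,
                                 float *output_ptr,
                                 int in_dim, int out_dim, int batch_size) {
      float alpha = 1.0f, beta = 0.0f;
      checkCUBLAS(cublasGemmEx(handle,
                               CUBLAS_OP_T, CUBLAS_OP_N,
                               out_dim, batch_size, in_dim,
                               &alpha,
                               (void const *)weight_ptr, CUDA_R_32F, in_dim,
                               (void const *)input_ptr, CUDA_R_32F, in_dim,
                               &beta,
                               (void *)output_ptr, CUDA_R_32F, out_dim,
                               CUBLAS_COMPUTE_32F,
                               CUBLAS_GEMM_DEFAULT));
    }
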
@@ -156,14 +156,14 @@ void forward_kernel(cudaStream_t stream,
                            batch_size,
                            1,
                            &alpha,
-                           bias_ptr,
+                           (void *)bias_ptr,
                            weight_type,
                            1,
-                           m.one_ptr,
+                           (void *)m.one_ptr,
                            CUDA_R_32F,
                            1,
                            &alpha,
-                           output_ptr,
+                           (void *)output_ptr,
                            output_type,
                            out_dim,
                            compute_type,
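
This second GemmEx is the bias broadcast: a rank-1 update with k = 1 against m.one_ptr, which is assumed to point at a device vector of batch_size ones. In matrix form:

    Y <- Y + b * 1^T,   i.e.  Y[i][j] += bias[i]  for every column j
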
@@ -174,10 +174,10 @@ void forward_kernel(cudaStream_t stream,
                                     m.actiDesc,
                                     &alpha,
                                     m.outputTensor,
-                                    output_ptr,
+                                    (void *)output_ptr,
                                     &beta,
                                     m.outputTensor,
-                                    output_ptr));
+                                    (void *)output_ptr));
   } else if (m.activation == Activation::GELU) {
     size_t elements = size_t_from_int(out_dim) * size_t_from_int(batch_size);
     constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI)
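
B is the leading constant of the tanh approximation of GELU. The elementwise kernel launched for it sits outside this hunk; a minimal sketch of the standard approximation, with the usual cubic coefficient 0.044715 (assumed here, not visible in the diff):

    // gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    __device__ float gelu_approx(float x) {
      constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI)
      constexpr float C = 0.044715f;           // cubic-term coefficient
      return 0.5f * x * (1.0f + tanhf(B * (x + C * x * x * x)));
    }
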
@@ -191,13 +191,13 @@ void forward_kernel(cudaStream_t stream,
 
 void backward_kernel(cudaStream_t stream,
                      LinearPerDeviceState const &m,
-                     void const *input_ptr,
-                     void *input_grad_ptr,
-                     void const *output_ptr,
-                     void *output_grad_ptr,
-                     void const *kernel_ptr,
-                     void *kernel_grad_ptr,
-                     void *bias_grad_ptr,
+                     float const *input_ptr,
+                     float *input_grad_ptr,
+                     float const *output_ptr,
+                     float *output_grad_ptr,
+                     float const *kernel_ptr,
+                     float *kernel_grad_ptr,
+                     float *bias_grad_ptr,
                      int in_dim,
                      int out_dim,
                      int batch_size) {
@@ -216,11 +216,17 @@ void backward_kernel(cudaStream_t stream,
   int output_size = out_dim * batch_size;
   if (m.activation.has_value()) {
     if (m.activation == Activation::RELU) {
-      relu_backward_kernel(
-          m.output_type, output_grad_ptr, output_ptr, output_size, stream);
+      relu_backward_kernel(m.output_type,
+                           (void *)output_grad_ptr,
+                           (void *)output_ptr,
+                           output_size,
+                           stream);
     } else if (m.activation == Activation::SIGMOID) {
-      sigmoid_backward_kernel(
-          m.output_type, output_grad_ptr, output_ptr, output_size, stream);
+      sigmoid_backward_kernel(m.output_type,
+                              (void *)output_grad_ptr,
+                              (void *)output_ptr,
+                              output_size,
+                              stream);
     } else {
       // TODO: only support relu and sigmoid for now
       assert(false && "Unsupported activation for Linear");
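
relu_backward_kernel and sigmoid_backward_kernel are project helpers that scale the incoming gradient by the activation's derivative, expressed through the saved output (the forward pass applied the activation in place, so the pre-activation values are gone). A hypothetical fp32 version of the ReLU case, to make the math concrete:

    // Illustrative only: grad *= 1[y > 0], using the saved output y.
    // Valid because relu(x) > 0 exactly when x > 0.
    __global__ void relu_backward_f32(float *output_grad,
                                      float const *output,
                                      int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        output_grad[i] = (output[i] > 0.0f) ? output_grad[i] : 0.0f;
      }
    }
    // Sigmoid analog: output_grad[i] *= output[i] * (1.0f - output[i]);
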
@@ -235,14 +241,14 @@ void backward_kernel(cudaStream_t stream,
                            out_dim,
                            batch_size,
                            &alpha,
-                           input_ptr,
+                           (void *)input_ptr,
                            input_type,
                            in_dim,
-                           output_grad_ptr,
+                           (void *)output_grad_ptr,
                            output_type,
                            out_dim,
                            &alpha,
-                           kernel_grad_ptr,
+                           (void *)kernel_grad_ptr,
                            weight_type,
                            in_dim,
                            compute_type,
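
This hunk is the weight-gradient GEMM. With the forward pass computing Y = W^T * X, the gradient accumulated here is

    dW += X * dY^T    // (in_dim x batch_size) * (batch_size x out_dim)

Both scaling arguments are &alpha, so the product accumulates into kernel_grad_ptr rather than overwriting it.
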
@@ -261,12 +267,12 @@ void backward_kernel(cudaStream_t stream,
                            in_dim,
                            out_dim,
                            &alpha,
-                           (float *)kernel_grad_ptr,
+                           kernel_grad_ptr,
                            in_dim,
                            &lambda,
-                           (float *)kernel_ptr,
+                           kernel_ptr,
                            in_dim,
-                           (float *)kernel_grad_ptr,
+                           kernel_grad_ptr,
                            in_dim));
     } else {
       assert(false && "Only L2 regularization is supported");
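
The L2 branch can drop its (float *) casts entirely: the pointers now arrive typed, and cublasSgeam takes float pointers directly. The call itself is a scaled matrix sum that folds the regularizer into the gradient:

    kernel_grad = alpha * kernel_grad + lambda * kernel
    // gradient of (lambda / 2) * ||W||_F^2 is lambda * W
    // (assuming lambda here already absorbs any 1/2 convention)
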
@@ -284,14 +290,14 @@ void backward_kernel(cudaStream_t stream,
                            out_dim,
                            batch_size,
                            &alpha,
-                           m.one_ptr,
+                           (void *)m.one_ptr,
                            CUDA_R_32F,
                            1,
-                           output_grad_ptr,
+                           (void *)output_grad_ptr,
                            output_type,
                            out_dim,
                            &alpha,
-                           bias_grad_ptr,
+                           (void *)bias_grad_ptr,
                            weight_type,
                            1,
                            compute_type,
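
The bias gradient reuses the same ones vector from the forward pass, this time with k = batch_size, so the GEMM reduces the output gradient over the batch:

    bias_grad[i] += sum over b of output_grad[i][b],   i.e.  db += dY * 1
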
@@ -307,14 +313,14 @@ void backward_kernel(cudaStream_t stream,
                            batch_size,
                            out_dim,
                            &alpha,
-                           kernel_ptr,
+                           (void *)kernel_ptr,
                            weight_type,
                            in_dim,
-                           output_grad_ptr,
+                           (void *)output_grad_ptr,
                            output_type,
                            out_dim,
                            &alpha,
-                           input_grad_ptr,
+                           (void *)input_grad_ptr,
                            input_type,
                            in_dim,
                            compute_type,
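
Finally, the input gradient: with Y = W^T * X, back-propagation through the layer multiplies by W itself, again accumulating (the beta argument is &alpha):

    dX += W * dY    // (in_dim x out_dim) * (out_dim x batch_size)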