
Commit a182865

flexflow#1409: change the datatype for the linear kernels away from void *
1 parent d1a15f3 commit a182865

File tree (3 files changed, 59 additions, 53 deletions):

lib/kernels/include/kernels/linear_kernels.h
lib/kernels/src/cuda/ops/linear_kernels.cu
lib/local-execution/src/ops/linear.cc

lib/kernels/include/kernels/linear_kernels.h

Lines changed: 11 additions & 11 deletions
```diff
@@ -50,23 +50,23 @@ bool use_activation(Activation activation);
 
 void forward_kernel(ffStream_t stream,
                     LinearPerDeviceState const &m,
-                    void const *input_ptr,
-                    void *output_ptr,
-                    void const *filter_ptr,
-                    void const *bias_ptr,
+                    float const *input_ptr,
+                    float *output_ptr,
+                    float const *filter_ptr,
+                    float const *bias_ptr,
                     int in_dim,
                     int out_dim,
                     int batch_size);
 
 void backward_kernel(ffStream_t stream,
                      LinearPerDeviceState const &m,
-                     void const *input_ptr,
-                     void *input_grad_ptr,
-                     void const *output_ptr,
-                     void *output_grad_ptr,
-                     void const *kernel_ptr,
-                     void *kernel_grad_ptr,
-                     void *bias_ptr,
+                     float const *input_ptr,
+                     float *input_grad_ptr,
+                     float const *output_ptr,
+                     float *output_grad_ptr,
+                     float const *kernel_ptr,
+                     float *kernel_grad_ptr,
+                     float *bias_ptr,
                      int in_dim,
                      int out_dim,
                      int batch_size);
```
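Context for the change: with typed float * parameters, a pointer-type mismatch at a call site fails at compile time instead of silently reinterpreting memory at runtime. A minimal caller sketch under that reading (the buffer setup is hypothetical, and `m` and `stream` are assumed to be configured elsewhere, with the declarations from linear_kernels.h in scope):

```cpp
#include <cuda_runtime.h>

// Hypothetical call site for the new typed signature; passing, say, a
// double * buffer here is now a compile-time error.
void run_linear_forward(ffStream_t stream,
                        LinearPerDeviceState const &m,
                        int in_dim, int out_dim, int batch_size) {
  float *input, *output, *filter, *bias;
  cudaMalloc(&input, sizeof(float) * in_dim * batch_size);
  cudaMalloc(&output, sizeof(float) * out_dim * batch_size);
  cudaMalloc(&filter, sizeof(float) * in_dim * out_dim);
  cudaMalloc(&bias, sizeof(float) * out_dim);
  forward_kernel(stream, m, input, output, filter, bias,
                 in_dim, out_dim, batch_size);
}
```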

lib/kernels/src/cuda/ops/linear_kernels.cu

Lines changed: 41 additions & 35 deletions
```diff
@@ -108,10 +108,10 @@ LinearPerDeviceState init_kernel(PerDeviceFFHandle handle,
 
 void forward_kernel(cudaStream_t stream,
                     LinearPerDeviceState const &m,
-                    void const *input_ptr,
-                    void *output_ptr,
-                    void const *weight_ptr,
-                    void const *bias_ptr,
+                    float const *input_ptr,
+                    float *output_ptr,
+                    float const *weight_ptr,
+                    float const *bias_ptr,
                     int in_dim,
                     int out_dim,
                     int batch_size) {
@@ -135,14 +135,14 @@ void forward_kernel(cudaStream_t stream,
                                batch_size,
                                in_dim,
                                &alpha,
-                               weight_ptr,
+                               (void *)weight_ptr,
                                weight_type,
                                in_dim,
-                               input_ptr,
+                               (void *)input_ptr,
                                input_type,
                                in_dim,
                                &beta,
-                               output_ptr,
+                               (void *)output_ptr,
                                output_type,
                                out_dim,
                                compute_type,
```
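The (void *) casts added above mark where the typed arguments cross back into cuBLAS's type-erased, mixed-precision API. For reference, the cublasGemmEx prototype (CUDA 11+, where the compute type became cublasComputeType_t; earlier releases used cudaDataType_t there):

```cpp
cublasStatus_t cublasGemmEx(cublasHandle_t handle,
                            cublasOperation_t transa, cublasOperation_t transb,
                            int m, int n, int k,
                            const void *alpha,
                            const void *A, cudaDataType_t Atype, int lda,
                            const void *B, cudaDataType_t Btype, int ldb,
                            const void *beta,
                            void *C, cudaDataType_t Ctype, int ldc,
                            cublasComputeType_t computeType,
                            cublasGemmAlgo_t algo);
```

Strictly, float * and float const * convert to these void pointer types implicitly in C++, so the casts here are stylistic rather than required; they make the type-erasure boundary visible now that the surrounding code is typed.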
```diff
@@ -156,14 +156,14 @@ void forward_kernel(cudaStream_t stream,
                                batch_size,
                                1,
                                &alpha,
-                               bias_ptr,
+                               (void *)bias_ptr,
                                weight_type,
                                1,
-                               m.one_ptr,
+                               (void *)m.one_ptr,
                                CUDA_R_32F,
                                1,
                                &alpha,
-                               output_ptr,
+                               (void *)output_ptr,
                                output_type,
                                out_dim,
                                compute_type,
@@ -174,10 +174,10 @@ void forward_kernel(cudaStream_t stream,
                                       m.actiDesc,
                                       &alpha,
                                       m.outputTensor,
-                                      output_ptr,
+                                      (void *)output_ptr,
                                       &beta,
                                       m.outputTensor,
-                                      output_ptr));
+                                      (void *)output_ptr));
   } else if (m.activation == Activation::GELU) {
     size_t elements = size_t_from_int(out_dim) * size_t_from_int(batch_size);
     constexpr float B = 0.7978845608028654f; // sqrt(2.0/M_PI)
```
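The constant B = 0.7978845608028654f is √(2/π), the leading coefficient of the standard tanh approximation of GELU. A sketch of the elementwise formula it feeds (the actual kernel body lies outside this hunk; 0.044715 is the usual companion constant):

```cuda
__device__ float gelu_approx(float x) {
  constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI)
  constexpr float C = 0.044715f;           // standard tanh-approximation constant
  // GELU(x) ≈ 0.5 * x * (1 + tanh(B * (x + C * x^3)))
  return 0.5f * x * (1.0f + tanhf(B * (x + C * x * x * x)));
}
```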
```diff
@@ -191,13 +191,13 @@ void forward_kernel(cudaStream_t stream,
 
 void backward_kernel(cudaStream_t stream,
                      LinearPerDeviceState const &m,
-                     void const *input_ptr,
-                     void *input_grad_ptr,
-                     void const *output_ptr,
-                     void *output_grad_ptr,
-                     void const *kernel_ptr,
-                     void *kernel_grad_ptr,
-                     void *bias_grad_ptr,
+                     float const *input_ptr,
+                     float *input_grad_ptr,
+                     float const *output_ptr,
+                     float *output_grad_ptr,
+                     float const *kernel_ptr,
+                     float *kernel_grad_ptr,
+                     float *bias_grad_ptr,
                      int in_dim,
                      int out_dim,
                      int batch_size) {
@@ -216,11 +216,17 @@ void backward_kernel(cudaStream_t stream,
   int output_size = out_dim * batch_size;
   if (m.activation.has_value()) {
     if (m.activation == Activation::RELU) {
-      relu_backward_kernel(
-          m.output_type, output_grad_ptr, output_ptr, output_size, stream);
+      relu_backward_kernel(m.output_type,
+                           (void *)output_grad_ptr,
+                           (void *)output_ptr,
+                           output_size,
+                           stream);
     } else if (m.activation == Activation::SIGMOID) {
-      sigmoid_backward_kernel(
-          m.output_type, output_grad_ptr, output_ptr, output_size, stream);
+      sigmoid_backward_kernel(m.output_type,
+                              (void *)output_grad_ptr,
+                              (void *)output_ptr,
+                              output_size,
+                              stream);
     } else {
       // TODO: only support relu and sigmoid for now
       assert(false && "Unsupported activation for Linear");
```
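Unlike the cuBLAS casts, these (void *) casts are load-bearing: output_ptr is now float const *, and the helper kernels evidently take non-const void * (an inference from the casts themselves), so dropping const requires an explicit conversion. For orientation, the ReLU backward rule such a helper implements, as a standalone sketch rather than FlexFlow's actual kernel:

```cuda
// Gradient flows through only where the forward output was positive.
__global__ void relu_backward_sketch(float *output_grad,
                                     float const *output,
                                     int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    output_grad[i] = (output[i] > 0.0f) ? output_grad[i] : 0.0f;
  }
}
```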
```diff
@@ -235,14 +241,14 @@ void backward_kernel(cudaStream_t stream,
                                out_dim,
                                batch_size,
                                &alpha,
-                               input_ptr,
+                               (void *)input_ptr,
                                input_type,
                                in_dim,
-                               output_grad_ptr,
+                               (void *)output_grad_ptr,
                                output_type,
                                out_dim,
                                &alpha,
-                               kernel_grad_ptr,
+                               (void *)kernel_grad_ptr,
                                weight_type,
                                in_dim,
                                compute_type,
```
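For orientation: judging from the leading dimensions in this call (in_dim for the input and the kernel gradient, out_dim for the output gradient), this GemmEx appears to accumulate the weight gradient ∂L/∂W = x · (∂L/∂y)ᵀ, an in_dim × out_dim product of the saved input with the incoming output gradient; the transpose flags and the m/n/k arguments sit just above this hunk.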
```diff
@@ -261,12 +267,12 @@ void backward_kernel(cudaStream_t stream,
                              in_dim,
                              out_dim,
                              &alpha,
-                             (float *)kernel_grad_ptr,
+                             kernel_grad_ptr,
                              in_dim,
                              &lambda,
-                             (float *)kernel_ptr,
+                             kernel_ptr,
                              in_dim,
-                             (float *)kernel_grad_ptr,
+                             kernel_grad_ptr,
                              in_dim));
   } else {
     assert(false && "Only L2 regularization is supported");
```
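Here the change runs in the opposite direction: cublasSgeam is the typed single-precision API, so once the arguments are float * to begin with, the old casts away from void * can simply be dropped. Its prototype, for contrast with GemmEx:

```cpp
cublasStatus_t cublasSgeam(cublasHandle_t handle,
                           cublasOperation_t transa, cublasOperation_t transb,
                           int m, int n,
                           const float *alpha, const float *A, int lda,
                           const float *beta, const float *B, int ldb,
                           float *C, int ldc);
```

The call computes kernel_grad = alpha · kernel_grad + lambda · kernel in place, which is the extra gradient term an L2 penalty contributes (with the constant factor folded into lambda).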
```diff
@@ -284,14 +290,14 @@ void backward_kernel(cudaStream_t stream,
                                out_dim,
                                batch_size,
                                &alpha,
-                               m.one_ptr,
+                               (void *)m.one_ptr,
                                CUDA_R_32F,
                                1,
-                               output_grad_ptr,
+                               (void *)output_grad_ptr,
                                output_type,
                                out_dim,
                                &alpha,
-                               bias_grad_ptr,
+                               (void *)bias_grad_ptr,
                                weight_type,
                                1,
                                compute_type,
```
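This GemmEx reduces the output gradient over the batch by multiplying it with the all-ones vector m.one_ptr (note the vector's fixed CUDA_R_32F type, which is also why it now needs a cast): the bias gradient is ∂L/∂b = Σₙ ∂L/∂yₙ. A CPU reference of that reduction, assuming cuBLAS's column-major out_dim × batch_size layout and alpha = 1 (the same alpha is passed in the beta slot, so the result accumulates):

```cpp
// Illustrative only: the GemmEx above performs this reduction on the GPU.
void bias_grad_reference(float *bias_grad, float const *output_grad,
                         int out_dim, int batch_size) {
  for (int n = 0; n < batch_size; ++n) {
    for (int o = 0; o < out_dim; ++o) {
      bias_grad[o] += output_grad[n * out_dim + o];
    }
  }
}
```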
```diff
@@ -307,14 +313,14 @@ void backward_kernel(cudaStream_t stream,
                                batch_size,
                                out_dim,
                                &alpha,
-                               kernel_ptr,
+                               (void *)kernel_ptr,
                                weight_type,
                                in_dim,
-                               output_grad_ptr,
+                               (void *)output_grad_ptr,
                                output_type,
                                out_dim,
                                &alpha,
-                               input_grad_ptr,
+                               (void *)input_grad_ptr,
                                input_type,
                                in_dim,
                                compute_type,
```

lib/local-execution/src/ops/linear.cc

Lines changed: 7 additions & 7 deletions
```diff
@@ -148,13 +148,13 @@ static std::optional<float>
                  profiling,
                  "[Linear] backward_time = {:.2lf}ms\n",
                  per_device_state,
-                 (void *)input.get_float_ptr(),
-                 (void *)input_grad.get_float_ptr(),
-                 (void *)output.get_float_ptr(),
-                 (void *)output_grad.get_float_ptr(),
-                 (void *)weight.get_float_ptr(),
-                 (void *)weight_grad.get_float_ptr(),
-                 (void *)bias_ptr,
+                 input.get_float_ptr(),
+                 (float *)input_grad.get_float_ptr(),
+                 output.get_float_ptr(),
+                 (float *)output_grad.get_float_ptr(),
+                 weight.get_float_ptr(),
+                 (float *)weight_grad.get_float_ptr(),
+                 (float *)bias_ptr,
                  in_dim,
                  out_dim,
                  batch_size);
```
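At this call site the read-only accessors now pass their results straight through, while the writable gradient buffers still carry (float *) casts. That pattern suggests get_float_ptr() returns float const * here, so only const-ness is being cast away; this is an inference from the diff, not a confirmed API detail. Under that reading, a const_cast would say the same thing more explicitly:

```cpp
// Hypothetical equivalent of the C-style casts above: const_cast documents
// that only const-ness changes, never the pointee type.
float *input_grad_ptr = const_cast<float *>(input_grad.get_float_ptr());
```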
