Skip to content

Commit 6e0e14c

Browse files
committed
First steps towards better shared memory, including dynamic
1 parent 6739fd0 commit 6e0e14c

File tree

18 files changed

+441
-147
lines changed

18 files changed

+441
-147
lines changed

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/kernel_func.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,20 @@ pub(super) fn quote_kernel_func(
1313
generic_wrapper_where_clause,
1414
..
1515
}: &DeclGenerics,
16-
inputs @ FunctionInputs { func_inputs, .. }: &FunctionInputs,
16+
inputs @ FunctionInputs {
17+
func_inputs,
18+
func_input_cuda_types,
19+
}: &FunctionInputs,
1720
fn_ident @ FuncIdent { func_ident, .. }: &FuncIdent,
1821
func_params: &[syn::Ident],
1922
func_attrs: &[syn::Attribute],
2023
macro_type_ids: &[syn::Ident],
2124
) -> TokenStream {
2225
let new_func_inputs = func_inputs
2326
.iter()
27+
.zip(func_input_cuda_types.iter())
2428
.enumerate()
25-
.map(|(i, arg)| match arg {
29+
.map(|(i, (arg, (cuda_type, _)))| match arg {
2630
syn::FnArg::Typed(syn::PatType {
2731
attrs,
2832
pat,
@@ -46,6 +50,16 @@ pub(super) fn quote_kernel_func(
4650
quote! {
4751
#(#attrs)* #pat #colon_token #and_token #lifetime #mutability #syn_type
4852
}
53+
} else if matches!(cuda_type, InputCudaType::ThreadBlockShared) {
54+
if let syn::Type::Slice(_) = &**ty {
55+
quote! { #(#attrs)* #pat #colon_token
56+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
57+
}
58+
} else {
59+
quote! { #(#attrs)* #pat #colon_token
60+
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
61+
}
62+
}
4963
} else {
5064
quote! { #(#attrs)* #pat #colon_token #syn_type }
5165
}
@@ -169,6 +183,7 @@ fn generate_raw_func_input_wrap(
169183
) }
170184
}
171185
},
186+
InputCudaType::ThreadBlockShared => inner,
172187
},
173188
syn::FnArg::Receiver(_) => unreachable!(),
174189
},

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/kernel_func_async/async_func_types.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ pub(super) fn generate_async_func_types(
4646
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
4747
>
4848
},
49+
InputCudaType::ThreadBlockShared => if let syn::Type::Slice(_) = &**ty {
50+
quote! {
51+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
52+
}
53+
} else {
54+
quote! {
55+
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
56+
}
57+
},
4958
};
5059

5160
if let syn::Type::Reference(syn::TypeReference {

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/kernel_func_async/launch_types.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@ pub(in super::super) fn generate_launch_types(
4747
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
4848
>
4949
},
50+
InputCudaType::ThreadBlockShared => {
51+
if let syn::Type::Slice(_) = &**ty {
52+
quote::quote_spanned! { ty.span()=>
53+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
54+
}
55+
} else {
56+
quote::quote_spanned! { ty.span()=>
57+
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
58+
}
59+
}
60+
},
5061
};
5162

5263
cpu_func_types_launch.push(

rust-cuda-derive/src/kernel/wrapper/generate/cpu_wrapper.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ pub(in super::super) fn quote_cpu_wrapper(
9191
}
9292
}
9393

94+
#[allow(clippy::too_many_lines)]
9495
fn generate_new_func_inputs_decl(
9596
crate_path: &syn::Path,
9697
KernelConfig { args, .. }: &KernelConfig,
@@ -132,6 +133,16 @@ fn generate_new_func_inputs_decl(
132133
mutability: *mutability,
133134
elem: syn_type,
134135
}))
136+
} else if matches!(cuda_mode, InputCudaType::ThreadBlockShared) {
137+
if let syn::Type::Slice(_) = &**ty {
138+
syn::parse_quote!(
139+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
140+
)
141+
} else {
142+
syn::parse_quote!(
143+
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
144+
)
145+
}
135146
} else {
136147
syn_type
137148
}
@@ -155,6 +166,15 @@ fn generate_new_func_inputs_decl(
155166
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
156167
>
157168
),
169+
InputCudaType::ThreadBlockShared => if let syn::Type::Slice(_) = &**ty {
170+
syn::parse_quote!(
171+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
172+
)
173+
} else {
174+
syn::parse_quote!(
175+
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
176+
)
177+
},
158178
};
159179

160180
if let syn::Type::Reference(syn::TypeReference {

rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ pub(in super::super) fn quote_cuda_wrapper(
3333
})
3434
.collect::<Vec<_>>();
3535

36+
let mut shared_slice = Vec::new();
37+
3638
let ptx_func_input_unwrap = func_inputs
3739
.iter().zip(func_input_cuda_types.iter()).enumerate()
3840
.rev()
@@ -90,7 +92,24 @@ pub(in super::super) fn quote_cuda_wrapper(
9092
#pat, |#pat: #syn_type| { #inner },
9193
)
9294
}
93-
}
95+
},
96+
InputCudaType::ThreadBlockShared => if let syn::Type::Slice(syn::TypeSlice { elem, .. }) = &**ty {
97+
shared_slice.push(elem);
98+
99+
quote! {
100+
#ptx_jit_load;
101+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice::with_uninit(
102+
#pat, |#pat: #syn_type| { #inner },
103+
)
104+
}
105+
} else {
106+
quote! {
107+
#ptx_jit_load;
108+
#crate_path::utils::shared::r#static::ThreadBlockShared::with_uninit(
109+
#pat, |#pat: #syn_type| { #inner },
110+
)
111+
}
112+
},
94113
}
95114
},
96115
syn::FnArg::Receiver(_) => unreachable!(),
@@ -186,6 +205,17 @@ fn specialise_ptx_func_inputs(
186205
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
187206
>
188207
},
208+
InputCudaType::ThreadBlockShared => {
209+
if let syn::Type::Slice(_) = &**ty {
210+
quote::quote_spanned! { ty.span()=>
211+
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
212+
}
213+
} else {
214+
quote::quote_spanned! { ty.span()=>
215+
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
216+
}
217+
}
218+
},
189219
};
190220

191221
let ty = if let syn::Type::Reference(syn::TypeReference {

rust-cuda-derive/src/kernel/wrapper/inputs/attribute.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ impl syn::parse::Parse for KernelInputAttribute {
1919
let cuda_type = match &*mode.to_string() {
2020
"SafeDeviceCopy" => InputCudaType::SafeDeviceCopy,
2121
"LendRustToCuda" => InputCudaType::LendRustToCuda,
22+
"ThreadBlockShared" => InputCudaType::ThreadBlockShared,
2223
_ => abort!(
2324
mode.span(),
24-
"Unexpected CUDA transfer mode `{:?}`: Expected `SafeDeviceCopy` or \
25-
`LendRustToCuda`.",
25+
"Unexpected CUDA transfer mode `{}`: Expected `SafeDeviceCopy`, \
26+
`LendRustToCuda`, or `ThreadBlockShared`.",
2627
mode
2728
),
2829
};
@@ -61,7 +62,7 @@ impl syn::parse::Parse for KernelInputAttribute {
6162
},
6263
_ => abort!(
6364
ident.span(),
64-
"Unexpected kernel attribute `{:?}`: Expected `pass` or `jit`.",
65+
"Unexpected kernel attribute `{}`: Expected `pass` or `jit`.",
6566
ident
6667
),
6768
}

rust-cuda-derive/src/kernel/wrapper/inputs/mod.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub(super) struct FunctionInputs {
1212
pub(super) func_input_cuda_types: Vec<(InputCudaType, InputPtxJit)>,
1313
}
1414

15+
#[allow(clippy::too_many_lines)]
1516
pub(super) fn parse_function_inputs(
1617
func: &syn::ItemFn,
1718
generic_params: &mut syn::punctuated::Punctuated<syn::GenericParam, syn::token::Comma>,
@@ -53,9 +54,25 @@ pub(super) fn parse_function_inputs(
5354

5455
for attr in attrs {
5556
match attr {
56-
KernelInputAttribute::PassType(_span, pass_type)
57+
KernelInputAttribute::PassType(span, pass_type)
5758
if cuda_type.is_none() =>
5859
{
60+
if matches!(pass_type, InputCudaType::ThreadBlockShared)
61+
&& !matches!(
62+
&**ty,
63+
syn::Type::Ptr(syn::TypePtr {
64+
mutability: Some(_),
65+
..
66+
})
67+
)
68+
{
69+
abort!(
70+
span,
71+
"Only mutable pointer types can be shared in a \
72+
thread block."
73+
);
74+
}
75+
5976
cuda_type = Some(pass_type);
6077
},
6178
KernelInputAttribute::PassType(span, _pass_type) => {
@@ -207,6 +224,17 @@ fn ensure_reference_type_lifetime(
207224
elem,
208225
}))
209226
},
227+
ty @ syn::Type::Ptr(syn::TypePtr { elem, .. }) => {
228+
if matches!(cuda_type, InputCudaType::ThreadBlockShared) {
229+
if let syn::Type::Slice(syn::TypeSlice { elem, .. }) = &**elem {
230+
elem.clone()
231+
} else {
232+
elem.clone()
233+
}
234+
} else {
235+
Box::new(ty.clone())
236+
}
237+
},
210238
ty => {
211239
if matches!(cuda_type, InputCudaType::LendRustToCuda) {
212240
generic_params.insert(

rust-cuda-derive/src/kernel/wrapper/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
205205
.func_inputs
206206
.iter_mut()
207207
.zip(&func_params)
208-
.map(|(arg, ident)| match arg {
208+
.zip(&func_inputs.func_input_cuda_types)
209+
.zip(&func.sig.inputs)
210+
.map(|(((arg, ident), (cuda_type, _)), arg_orig)| match arg {
209211
syn::FnArg::Typed(syn::PatType {
210212
attrs,
211213
colon_token,
@@ -225,6 +227,12 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
225227
ty: ty.clone(),
226228
});
227229

230+
if matches!(cuda_type, InputCudaType::ThreadBlockShared) {
231+
if let syn::FnArg::Typed(syn::PatType { ty: ty_orig, .. }) = arg_orig {
232+
*ty = ty_orig.clone();
233+
}
234+
}
235+
228236
std::mem::replace(arg, ident_fn_arg)
229237
},
230238
syn::FnArg::Receiver(_) => unreachable!(),
@@ -284,6 +292,7 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
284292
enum InputCudaType {
285293
SafeDeviceCopy,
286294
LendRustToCuda,
295+
ThreadBlockShared,
287296
}
288297

289298
struct InputPtxJit(bool);

src/device/alloc.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use alloc::alloc::{GlobalAlloc, Layout};
2+
#[cfg(target_os = "cuda")]
3+
use core::arch::nvptx;
4+
5+
/// Memory allocator using CUDA malloc/free
///
/// Unit struct whose [`GlobalAlloc`] implementation forwards every request
/// to `nvptx::malloc` and `nvptx::free`.
6+
pub struct PTXAllocator;
7+
8+
// NOTE(review): the `use core::arch::nvptx` import in this file is gated on
// `target_os = "cuda"`, but this impl carries no visible `cfg` gate -- confirm
// that the enclosing module is only compiled for the CUDA target.
unsafe impl GlobalAlloc for PTXAllocator {
9+
/// Requests `layout.size()` bytes from `nvptx::malloc` and returns the
/// resulting pointer cast to `*mut u8`.
// NOTE(review): `layout.align()` is never consulted here; this presumably
// relies on the device `malloc` returning memory that is aligned strictly
// enough for every requested layout -- TODO confirm.
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
10+
nvptx::malloc(layout.size()).cast()
11+
}
12+
13+
/// Passes `ptr` to `nvptx::free`; the layout argument is ignored.
unsafe fn dealloc(&self, ptr: *mut u8, _layout: Layout) {
14+
nvptx::free(ptr.cast());
15+
}
16+
}

0 commit comments

Comments
 (0)