Skip to content

Commit 6a36b22

Browse files
committed
Revert derive changes + R2C-based approach start
1 parent 6e0e14c commit 6a36b22

File tree

12 files changed

+99
-179
lines changed

12 files changed

+99
-179
lines changed

examples/single-source/src/main.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
extern crate alloc;
1717

18+
use rc::utils::shared::r#static::ThreadBlockShared;
19+
1820
#[cfg(not(target_os = "cuda"))]
1921
// Empty entry point: the preceding `cfg` attribute keeps this stub only when
// the target OS is not `cuda`.
fn main() {}
2022

@@ -49,19 +51,21 @@ pub fn kernel<'a, T: rc::common::RustToCuda>(
4951
#[kernel(pass = SafeDeviceCopy, jit)] _v @ _w: &'a core::sync::atomic::AtomicU64,
5052
#[kernel(pass = LendRustToCuda)] _: Wrapper<T>,
5153
#[kernel(pass = SafeDeviceCopy)] Tuple(_s, mut __t): Tuple,
54+
#[kernel(pass = LendRustToCuda)] shared3: ThreadBlockShared<u32>,
5255
) where
5356
<T as rc::common::RustToCuda>::CudaRepresentation: rc::safety::StackOnly,
5457
{
55-
use rc::device::ThreadBlockShared;
56-
5758
let shared: ThreadBlockShared<[Tuple; 3]> = ThreadBlockShared::new_uninit();
5859
let shared2: ThreadBlockShared<[Tuple; 3]> = ThreadBlockShared::new_uninit();
5960

6061
unsafe {
61-
(*shared.get().cast::<Tuple>().add(1)).0 = 42;
62+
(*shared.as_mut_ptr().cast::<Tuple>().add(1)).0 = 42;
63+
}
64+
unsafe {
65+
(*shared2.as_mut_ptr().cast::<Tuple>().add(2)).1 = 24;
6266
}
6367
unsafe {
64-
(*shared2.get().cast::<Tuple>().add(2)).1 = 24;
68+
*shared3.as_mut_ptr() = 12;
6569
}
6670
}
6771

@@ -89,10 +93,10 @@ mod host {
8993
mod cuda_prelude {
9094
use core::arch::nvptx;
9195

92-
use rc::device::utils;
96+
use rc::device::alloc::PTXAllocator;
9397

9498
#[global_allocator]
95-
static _GLOBAL_ALLOCATOR: utils::PTXAllocator = utils::PTXAllocator;
99+
static _GLOBAL_ALLOCATOR: PTXAllocator = PTXAllocator;
96100

97101
#[panic_handler]
98102
fn panic(_: &::core::panic::PanicInfo) -> ! {

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/kernel_func.rs

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,16 @@ pub(super) fn quote_kernel_func(
1313
generic_wrapper_where_clause,
1414
..
1515
}: &DeclGenerics,
16-
inputs @ FunctionInputs {
17-
func_inputs,
18-
func_input_cuda_types,
19-
}: &FunctionInputs,
16+
inputs @ FunctionInputs { func_inputs, .. }: &FunctionInputs,
2017
fn_ident @ FuncIdent { func_ident, .. }: &FuncIdent,
2118
func_params: &[syn::Ident],
2219
func_attrs: &[syn::Attribute],
2320
macro_type_ids: &[syn::Ident],
2421
) -> TokenStream {
2522
let new_func_inputs = func_inputs
2623
.iter()
27-
.zip(func_input_cuda_types.iter())
2824
.enumerate()
29-
.map(|(i, (arg, (cuda_type, _)))| match arg {
25+
.map(|(i, arg)| match arg {
3026
syn::FnArg::Typed(syn::PatType {
3127
attrs,
3228
pat,
@@ -50,16 +46,6 @@ pub(super) fn quote_kernel_func(
5046
quote! {
5147
#(#attrs)* #pat #colon_token #and_token #lifetime #mutability #syn_type
5248
}
53-
} else if matches!(cuda_type, InputCudaType::ThreadBlockShared) {
54-
if let syn::Type::Slice(_) = &**ty {
55-
quote! { #(#attrs)* #pat #colon_token
56-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
57-
}
58-
} else {
59-
quote! { #(#attrs)* #pat #colon_token
60-
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
61-
}
62-
}
6349
} else {
6450
quote! { #(#attrs)* #pat #colon_token #syn_type }
6551
}
@@ -183,7 +169,6 @@ fn generate_raw_func_input_wrap(
183169
) }
184170
}
185171
},
186-
InputCudaType::ThreadBlockShared => inner,
187172
},
188173
syn::FnArg::Receiver(_) => unreachable!(),
189174
},

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/kernel_func_async/async_func_types.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,6 @@ pub(super) fn generate_async_func_types(
4646
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
4747
>
4848
},
49-
InputCudaType::ThreadBlockShared => if let syn::Type::Slice(_) = &**ty {
50-
quote! {
51-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
52-
}
53-
} else {
54-
quote! {
55-
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
56-
}
57-
},
5849
};
5950

6051
if let syn::Type::Reference(syn::TypeReference {

rust-cuda-derive/src/kernel/wrapper/generate/cpu_linker_macro/kernel_func_async/launch_types.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,6 @@ pub(in super::super) fn generate_launch_types(
4747
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
4848
>
4949
},
50-
InputCudaType::ThreadBlockShared => {
51-
if let syn::Type::Slice(_) = &**ty {
52-
quote::quote_spanned! { ty.span()=>
53-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
54-
}
55-
} else {
56-
quote::quote_spanned! { ty.span()=>
57-
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
58-
}
59-
}
60-
},
6150
};
6251

6352
cpu_func_types_launch.push(

rust-cuda-derive/src/kernel/wrapper/generate/cpu_wrapper.rs

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ pub(in super::super) fn quote_cpu_wrapper(
9191
}
9292
}
9393

94-
#[allow(clippy::too_many_lines)]
9594
fn generate_new_func_inputs_decl(
9695
crate_path: &syn::Path,
9796
KernelConfig { args, .. }: &KernelConfig,
@@ -133,16 +132,6 @@ fn generate_new_func_inputs_decl(
133132
mutability: *mutability,
134133
elem: syn_type,
135134
}))
136-
} else if matches!(cuda_mode, InputCudaType::ThreadBlockShared) {
137-
if let syn::Type::Slice(_) = &**ty {
138-
syn::parse_quote!(
139-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
140-
)
141-
} else {
142-
syn::parse_quote!(
143-
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
144-
)
145-
}
146135
} else {
147136
syn_type
148137
}
@@ -166,15 +155,6 @@ fn generate_new_func_inputs_decl(
166155
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
167156
>
168157
),
169-
InputCudaType::ThreadBlockShared => if let syn::Type::Slice(_) = &**ty {
170-
syn::parse_quote!(
171-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
172-
)
173-
} else {
174-
syn::parse_quote!(
175-
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
176-
)
177-
},
178158
};
179159

180160
if let syn::Type::Reference(syn::TypeReference {

rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ pub(in super::super) fn quote_cuda_wrapper(
3333
})
3434
.collect::<Vec<_>>();
3535

36-
let mut shared_slice = Vec::new();
37-
3836
let ptx_func_input_unwrap = func_inputs
3937
.iter().zip(func_input_cuda_types.iter()).enumerate()
4038
.rev()
@@ -92,24 +90,7 @@ pub(in super::super) fn quote_cuda_wrapper(
9290
#pat, |#pat: #syn_type| { #inner },
9391
)
9492
}
95-
},
96-
InputCudaType::ThreadBlockShared => if let syn::Type::Slice(syn::TypeSlice { elem, .. }) = &**ty {
97-
shared_slice.push(elem);
98-
99-
quote! {
100-
#ptx_jit_load;
101-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice::with_uninit(
102-
#pat, |#pat: #syn_type| { #inner },
103-
)
104-
}
105-
} else {
106-
quote! {
107-
#ptx_jit_load;
108-
#crate_path::utils::shared::r#static::ThreadBlockShared::with_uninit(
109-
#pat, |#pat: #syn_type| { #inner },
110-
)
111-
}
112-
},
93+
}
11394
}
11495
},
11596
syn::FnArg::Receiver(_) => unreachable!(),
@@ -205,17 +186,6 @@ fn specialise_ptx_func_inputs(
205186
<#syn_type as #crate_path::common::RustToCuda>::CudaRepresentation
206187
>
207188
},
208-
InputCudaType::ThreadBlockShared => {
209-
if let syn::Type::Slice(_) = &**ty {
210-
quote::quote_spanned! { ty.span()=>
211-
#crate_path::utils::shared::slice::ThreadBlockSharedSlice<#syn_type>
212-
}
213-
} else {
214-
quote::quote_spanned! { ty.span()=>
215-
#crate_path::utils::shared::r#static::ThreadBlockShared<#syn_type>
216-
}
217-
}
218-
},
219189
};
220190

221191
let ty = if let syn::Type::Reference(syn::TypeReference {

rust-cuda-derive/src/kernel/wrapper/inputs/attribute.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,10 @@ impl syn::parse::Parse for KernelInputAttribute {
1919
let cuda_type = match &*mode.to_string() {
2020
"SafeDeviceCopy" => InputCudaType::SafeDeviceCopy,
2121
"LendRustToCuda" => InputCudaType::LendRustToCuda,
22-
"ThreadBlockShared" => InputCudaType::ThreadBlockShared,
2322
_ => abort!(
2423
mode.span(),
25-
"Unexpected CUDA transfer mode `{}`: Expected `SafeDeviceCopy`, \
26-
`LendRustToCuda`, or `ThreadBlockShared`.",
24+
"Unexpected CUDA transfer mode `{:?}`: Expected `SafeDeviceCopy` or \
25+
`LendRustToCuda`.",
2726
mode
2827
),
2928
};
@@ -62,7 +61,7 @@ impl syn::parse::Parse for KernelInputAttribute {
6261
},
6362
_ => abort!(
6463
ident.span(),
65-
"Unexpected kernel attribute `{}`: Expected `pass` or `jit`.",
64+
"Unexpected kernel attribute `{:?}`: Expected `pass` or `jit`.",
6665
ident
6766
),
6867
}

rust-cuda-derive/src/kernel/wrapper/inputs/mod.rs

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ pub(super) struct FunctionInputs {
1212
pub(super) func_input_cuda_types: Vec<(InputCudaType, InputPtxJit)>,
1313
}
1414

15-
#[allow(clippy::too_many_lines)]
1615
pub(super) fn parse_function_inputs(
1716
func: &syn::ItemFn,
1817
generic_params: &mut syn::punctuated::Punctuated<syn::GenericParam, syn::token::Comma>,
@@ -54,25 +53,9 @@ pub(super) fn parse_function_inputs(
5453

5554
for attr in attrs {
5655
match attr {
57-
KernelInputAttribute::PassType(span, pass_type)
56+
KernelInputAttribute::PassType(_span, pass_type)
5857
if cuda_type.is_none() =>
5958
{
60-
if matches!(pass_type, InputCudaType::ThreadBlockShared)
61-
&& !matches!(
62-
&**ty,
63-
syn::Type::Ptr(syn::TypePtr {
64-
mutability: Some(_),
65-
..
66-
})
67-
)
68-
{
69-
abort!(
70-
span,
71-
"Only mutable pointer types can be shared in a \
72-
thread block."
73-
);
74-
}
75-
7659
cuda_type = Some(pass_type);
7760
},
7861
KernelInputAttribute::PassType(span, _pass_type) => {
@@ -224,17 +207,6 @@ fn ensure_reference_type_lifetime(
224207
elem,
225208
}))
226209
},
227-
ty @ syn::Type::Ptr(syn::TypePtr { elem, .. }) => {
228-
if matches!(cuda_type, InputCudaType::ThreadBlockShared) {
229-
if let syn::Type::Slice(syn::TypeSlice { elem, .. }) = &**elem {
230-
elem.clone()
231-
} else {
232-
elem.clone()
233-
}
234-
} else {
235-
Box::new(ty.clone())
236-
}
237-
},
238210
ty => {
239211
if matches!(cuda_type, InputCudaType::LendRustToCuda) {
240212
generic_params.insert(

rust-cuda-derive/src/kernel/wrapper/mod.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,7 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
205205
.func_inputs
206206
.iter_mut()
207207
.zip(&func_params)
208-
.zip(&func_inputs.func_input_cuda_types)
209-
.zip(&func.sig.inputs)
210-
.map(|(((arg, ident), (cuda_type, _)), arg_orig)| match arg {
208+
.map(|(arg, ident)| match arg {
211209
syn::FnArg::Typed(syn::PatType {
212210
attrs,
213211
colon_token,
@@ -227,12 +225,6 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
227225
ty: ty.clone(),
228226
});
229227

230-
if matches!(cuda_type, InputCudaType::ThreadBlockShared) {
231-
if let syn::FnArg::Typed(syn::PatType { ty: ty_orig, .. }) = arg_orig {
232-
*ty = ty_orig.clone();
233-
}
234-
}
235-
236228
std::mem::replace(arg, ident_fn_arg)
237229
},
238230
syn::FnArg::Receiver(_) => unreachable!(),
@@ -292,7 +284,6 @@ pub fn kernel(attr: TokenStream, func: TokenStream) -> TokenStream {
292284
enum InputCudaType {
293285
SafeDeviceCopy,
294286
LendRustToCuda,
295-
ThreadBlockShared,
296287
}
297288

298289
/// Boolean newtype stored alongside each kernel input's `InputCudaType` in
/// `FunctionInputs::func_input_cuda_types`.
///
/// NOTE(review): presumably `true` when the input carries the `jit` kernel
/// attribute — confirm against the attribute parser.
struct InputPtxJit(bool);

src/safety/stack_only.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,9 @@ mod sealed {
3636
impl<T> !StackOnly for &T {}
3737
impl<T> !StackOnly for &mut T {}
3838

39+
// Negative impl: `ThreadBlockShared<T>` is explicitly excluded from
// `StackOnly`, in the same way as the `&T` and `&mut T` impls in this module.
impl<T: 'static> !StackOnly for crate::utils::shared::r#static::ThreadBlockShared<T> {}
40+
// impl<T: 'static> !StackOnly for
41+
// crate::utils::shared::slice::ThreadBlockSharedSlice<T> {}
42+
3943
impl<T> StackOnly for core::marker::PhantomData<T> {}
4044
}

0 commit comments

Comments
 (0)