Skip to content

Commit ec66984

Browse files
committed
Some progress on shared slices
1 parent 6a36b22 commit ec66984

File tree

5 files changed

+160
-81
lines changed

5 files changed

+160
-81
lines changed

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
any(all(not(feature = "host"), target_os = "cuda"), doc),
1717
feature(asm_const)
1818
)]
19+
#![cfg_attr(target_os = "cuda", feature(ptr_metadata))]
1920
#![cfg_attr(any(feature = "alloc", doc), feature(allocator_api))]
2021
#![feature(doc_cfg)]
2122
#![feature(cfg_version)]

src/safety/stack_only.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ mod sealed {
3737
impl<T> !StackOnly for &mut T {}
3838

3939
impl<T: 'static> !StackOnly for crate::utils::shared::r#static::ThreadBlockShared<T> {}
40-
// impl<T: 'static> !StackOnly for
41-
// crate::utils::shared::slice::ThreadBlockSharedSlice<T> {}
40+
impl<T: 'static + ~const const_type_layout::TypeGraphLayout> !StackOnly
41+
for crate::utils::shared::slice::ThreadBlockSharedSlice<T>
42+
{
43+
}
4244

4345
impl<T> StackOnly for core::marker::PhantomData<T> {}
4446
}

src/utils/shared/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
// pub mod slice;
1+
pub mod slice;
22
pub mod r#static;

src/utils/shared/slice.rs

Lines changed: 119 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,151 @@
1+
#[cfg(not(target_os = "cuda"))]
2+
use core::marker::PhantomData;
3+
4+
use const_type_layout::TypeGraphLayout;
15
use rustacuda_core::DeviceCopy;
26

7+
use crate::common::{CudaAsRust, DeviceAccessible, RustToCuda};
8+
9+
#[cfg(not(target_os = "cuda"))]
10+
#[allow(clippy::module_name_repetitions)]
11+
#[repr(transparent)]
12+
pub struct ThreadBlockSharedSlice<T: 'static + ~const TypeGraphLayout> {
13+
len: usize,
14+
marker: PhantomData<T>,
15+
}
16+
17+
#[cfg(target_os = "cuda")]
318
#[allow(clippy::module_name_repetitions)]
19+
#[repr(transparent)]
20+
pub struct ThreadBlockSharedSlice<T: 'static + ~const TypeGraphLayout> {
21+
shared: *mut [T],
22+
}
23+
24+
#[doc(hidden)]
425
#[derive(TypeLayout)]
26+
#[layout(bound = "T: 'static + ~const TypeGraphLayout")]
527
#[repr(C)]
6-
pub struct ThreadBlockSharedSlice<T: 'static> {
28+
pub struct ThreadBlockSharedSliceCudaRepresentation<T: 'static + ~const TypeGraphLayout> {
729
len: usize,
8-
byte_offset: usize,
30+
// Note: uses a zero-element array instead of PhantomData here so that
31+
// TypeLayout can still observe T's layout
932
marker: [T; 0],
1033
}
1134

12-
unsafe impl<T: 'static> DeviceCopy for ThreadBlockSharedSlice<T> {}
35+
unsafe impl<T: 'static + ~const TypeGraphLayout> DeviceCopy
36+
for ThreadBlockSharedSliceCudaRepresentation<T>
37+
{
38+
}
1339

14-
#[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))]
15-
#[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))]
16-
impl<T: 'static> ThreadBlockSharedSlice<T> {
40+
// #[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))]
41+
// #[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))]
42+
impl<T: 'static + ~const TypeGraphLayout> ThreadBlockSharedSlice<T> {
43+
#[cfg(any(not(target_os = "cuda"), doc))]
44+
#[doc(cfg(not(target_os = "cuda")))]
1745
#[must_use]
18-
pub fn with_len(len: usize) -> Self {
46+
pub fn new_uninit_with_len(len: usize) -> Self {
1947
Self {
2048
len,
21-
byte_offset: 0,
22-
marker: [],
49+
marker: PhantomData::<T>,
2350
}
2451
}
2552

53+
#[cfg(not(target_os = "cuda"))]
2654
#[must_use]
2755
pub fn len(&self) -> usize {
2856
self.len
2957
}
3058

59+
#[cfg(target_os = "cuda")]
60+
#[must_use]
61+
pub fn len(&self) -> usize {
62+
core::ptr::metadata(self.shared)
63+
}
64+
3165
#[must_use]
3266
pub fn is_empty(&self) -> bool {
33-
self.len == 0
67+
self.len() == 0
68+
}
69+
70+
#[cfg(any(target_os = "cuda", doc))]
71+
#[doc(cfg(target_os = "cuda"))]
72+
#[must_use]
73+
pub fn as_mut_slice_ptr(&self) -> *mut [T] {
74+
self.shared
75+
}
76+
77+
#[cfg(any(target_os = "cuda", doc))]
78+
#[doc(cfg(target_os = "cuda"))]
79+
#[must_use]
80+
pub fn as_mut_ptr(&self) -> *mut T {
81+
self.shared.cast()
3482
}
3583
}
3684

37-
#[cfg(all(not(feature = "host"), target_os = "cuda"))]
38-
#[doc(cfg(all(not(feature = "host"), target_os = "cuda")))]
39-
impl<T: 'static> ThreadBlockSharedSlice<T> {
40-
/// # Safety
41-
///
42-
/// The thread-block shared dynamic memory must be initialised once and
43-
/// only once per kernel.
44-
pub unsafe fn init() {
45-
unsafe {
46-
core::arch::asm!(
47-
".shared .align {align} .b8 rust_cuda_dynamic_shared[];",
48-
align = const(core::mem::align_of::<T>()),
49-
);
50-
}
85+
unsafe impl<T: 'static + ~const TypeGraphLayout> RustToCuda for ThreadBlockSharedSlice<T> {
86+
#[cfg(feature = "host")]
87+
#[doc(cfg(feature = "host"))]
88+
type CudaAllocation = crate::host::NullCudaAlloc;
89+
type CudaRepresentation = ThreadBlockSharedSliceCudaRepresentation<T>;
90+
91+
#[cfg(feature = "host")]
92+
#[doc(cfg(feature = "host"))]
93+
unsafe fn borrow<A: crate::host::CudaAlloc>(
94+
&self,
95+
alloc: A,
96+
) -> rustacuda::error::CudaResult<(
97+
DeviceAccessible<Self::CudaRepresentation>,
98+
crate::host::CombinedCudaAlloc<Self::CudaAllocation, A>,
99+
)> {
100+
Ok((
101+
DeviceAccessible::from(ThreadBlockSharedSliceCudaRepresentation {
102+
len: self.len,
103+
marker: [],
104+
}),
105+
crate::host::CombinedCudaAlloc::new(crate::host::NullCudaAlloc, alloc),
106+
))
51107
}
52108

53-
/// # Safety
54-
///
55-
/// Exposing the [`ThreadBlockSharedSlice`] must be preceded by exactly one
56-
/// call to [`ThreadBlockSharedSlice::init`] for the type `T` amongst
57-
/// all `ThreadBlockSharedSlice<T>` that has the largest alignment.
58-
pub unsafe fn with_uninit<F: FnOnce(*mut [T]) -> Q, Q>(self, inner: F) -> Q {
59-
let base: *mut u8;
60-
61-
unsafe {
62-
core::arch::asm!(
63-
"cvta.shared.u64 {reg}, rust_cuda_dynamic_shared;",
64-
reg = out(reg64) base,
65-
);
66-
}
109+
#[cfg(feature = "host")]
110+
#[doc(cfg(feature = "host"))]
111+
unsafe fn restore<A: crate::host::CudaAlloc>(
112+
&mut self,
113+
alloc: crate::host::CombinedCudaAlloc<Self::CudaAllocation, A>,
114+
) -> rustacuda::error::CudaResult<A> {
115+
let (_null, alloc): (crate::host::NullCudaAlloc, A) = alloc.split();
116+
117+
Ok(alloc)
118+
}
119+
}
120+
121+
unsafe impl<T: 'static + ~const TypeGraphLayout> CudaAsRust
122+
for ThreadBlockSharedSliceCudaRepresentation<T>
123+
{
124+
type RustRepresentation = ThreadBlockSharedSlice<T>;
125+
126+
#[cfg(any(not(feature = "host"), doc))]
127+
#[doc(cfg(not(feature = "host")))]
128+
unsafe fn as_rust(_this: &DeviceAccessible<Self>) -> Self::RustRepresentation {
129+
todo!()
130+
131+
// unsafe {
132+
// core::arch::asm!(
133+
// ".shared .align {align} .b8 rust_cuda_dynamic_shared[];",
134+
// align = const(core::mem::align_of::<T>()),
135+
// );
136+
// }
137+
138+
// let base: *mut u8;
67139

68-
let slice =
69-
core::ptr::slice_from_raw_parts_mut(base.add(self.byte_offset).cast(), self.len);
140+
// unsafe {
141+
// core::arch::asm!(
142+
// "cvta.shared.u64 {reg}, rust_cuda_dynamic_shared;",
143+
// reg = out(reg64) base,
144+
// );
145+
// }
70146

71-
inner(slice)
147+
// let slice = core::ptr::slice_from_raw_parts_mut(
148+
// base.add(self.byte_offset).cast(), self.len,
149+
// );
72150
}
73151
}

src/utils/shared/static.rs

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,41 @@ pub struct ThreadBlockSharedCudaRepresentation<T: 'static> {
2929

3030
unsafe impl<T: 'static> DeviceCopy for ThreadBlockSharedCudaRepresentation<T> {}
3131

32+
impl<T: 'static> ThreadBlockShared<T> {
33+
#[cfg(not(target_os = "cuda"))]
34+
#[must_use]
35+
pub fn new_uninit() -> Self {
36+
Self {
37+
marker: PhantomData::<T>,
38+
}
39+
}
40+
41+
#[cfg(target_os = "cuda")]
42+
#[must_use]
43+
pub fn new_uninit() -> Self {
44+
let shared: *mut T;
45+
46+
unsafe {
47+
core::arch::asm!(
48+
".shared .align {align} .b8 {reg}_rust_cuda_static_shared[{size}];",
49+
"cvta.shared.u64 {reg}, {reg}_rust_cuda_static_shared;",
50+
reg = out(reg64) shared,
51+
align = const(core::mem::align_of::<T>()),
52+
size = const(core::mem::size_of::<T>()),
53+
);
54+
}
55+
56+
Self { shared }
57+
}
58+
59+
#[cfg(any(target_os = "cuda", doc))]
60+
#[doc(cfg(target_os = "cuda"))]
61+
#[must_use]
62+
pub fn as_mut_ptr(&self) -> *mut T {
63+
self.shared
64+
}
65+
}
66+
3267
unsafe impl<T: 'static + ~const TypeGraphLayout> RustToCuda for ThreadBlockShared<T> {
3368
#[cfg(feature = "host")]
3469
#[doc(cfg(feature = "host"))]
@@ -73,40 +108,3 @@ unsafe impl<T: 'static + ~const TypeGraphLayout> CudaAsRust
73108
ThreadBlockShared::new_uninit()
74109
}
75110
}
76-
77-
#[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))]
78-
#[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))]
79-
impl<T: 'static> ThreadBlockShared<T> {
80-
#[must_use]
81-
pub fn new_uninit() -> Self {
82-
Self {
83-
marker: PhantomData::<T>,
84-
}
85-
}
86-
}
87-
88-
#[cfg(any(all(not(feature = "host"), target_os = "cuda"), doc))]
89-
#[doc(cfg(all(not(feature = "host"), target_os = "cuda")))]
90-
impl<T: 'static> ThreadBlockShared<T> {
91-
#[must_use]
92-
pub fn new_uninit() -> Self {
93-
let shared: *mut T;
94-
95-
unsafe {
96-
core::arch::asm!(
97-
".shared .align {align} .b8 {reg}_rust_cuda_static_shared[{size}];",
98-
"cvta.shared.u64 {reg}, {reg}_rust_cuda_static_shared;",
99-
reg = out(reg64) shared,
100-
align = const(core::mem::align_of::<T>()),
101-
size = const(core::mem::size_of::<T>()),
102-
);
103-
}
104-
105-
Self { shared }
106-
}
107-
108-
#[must_use]
109-
pub fn as_mut_ptr(&self) -> *mut T {
110-
self.shared
111-
}
112-
}

0 commit comments

Comments
 (0)