|
| 1 | +#[cfg(not(target_os = "cuda"))] |
| 2 | +use core::marker::PhantomData; |
| 3 | + |
| 4 | +use const_type_layout::TypeGraphLayout; |
1 | 5 | use rustacuda_core::DeviceCopy; |
2 | 6 |
|
| 7 | +use crate::common::{CudaAsRust, DeviceAccessible, RustToCuda}; |
| 8 | + |
| 9 | +#[cfg(not(target_os = "cuda"))] |
| 10 | +#[allow(clippy::module_name_repetitions)] |
| 11 | +#[repr(transparent)] |
| 12 | +pub struct ThreadBlockSharedSlice<T: 'static + ~const TypeGraphLayout> { |
| 13 | + len: usize, |
| 14 | + marker: PhantomData<T>, |
| 15 | +} |
| 16 | + |
| 17 | +#[cfg(target_os = "cuda")] |
3 | 18 | #[allow(clippy::module_name_repetitions)] |
| 19 | +#[repr(transparent)] |
| 20 | +pub struct ThreadBlockSharedSlice<T: 'static + ~const TypeGraphLayout> { |
| 21 | + shared: *mut [T], |
| 22 | +} |
| 23 | + |
| 24 | +#[doc(hidden)] |
4 | 25 | #[derive(TypeLayout)] |
| 26 | +#[layout(bound = "T: 'static + ~const TypeGraphLayout")] |
5 | 27 | #[repr(C)] |
6 | | -pub struct ThreadBlockSharedSlice<T: 'static> { |
| 28 | +pub struct ThreadBlockSharedSliceCudaRepresentation<T: 'static + ~const TypeGraphLayout> { |
7 | 29 | len: usize, |
8 | | - byte_offset: usize, |
| 30 | + // Note: uses a zero-element array instead of PhantomData here so that |
| 31 | + // TypeLayout can still observe T's layout |
9 | 32 | marker: [T; 0], |
10 | 33 | } |
11 | 34 |
|
12 | | -unsafe impl<T: 'static> DeviceCopy for ThreadBlockSharedSlice<T> {} |
| 35 | +unsafe impl<T: 'static + ~const TypeGraphLayout> DeviceCopy |
| 36 | + for ThreadBlockSharedSliceCudaRepresentation<T> |
| 37 | +{ |
| 38 | +} |
13 | 39 |
|
14 | | -#[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))] |
15 | | -#[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))] |
16 | | -impl<T: 'static> ThreadBlockSharedSlice<T> { |
| 40 | +// #[cfg(not(any(all(not(feature = "host"), target_os = "cuda"), doc)))] |
| 41 | +// #[doc(cfg(not(all(not(feature = "host"), target_os = "cuda"))))] |
| 42 | +impl<T: 'static + ~const TypeGraphLayout> ThreadBlockSharedSlice<T> { |
| 43 | + #[cfg(any(not(target_os = "cuda"), doc))] |
| 44 | + #[doc(cfg(not(target_os = "cuda")))] |
17 | 45 | #[must_use] |
18 | | - pub fn with_len(len: usize) -> Self { |
| 46 | + pub fn new_uninit_with_len(len: usize) -> Self { |
19 | 47 | Self { |
20 | 48 | len, |
21 | | - byte_offset: 0, |
22 | | - marker: [], |
| 49 | + marker: PhantomData::<T>, |
23 | 50 | } |
24 | 51 | } |
25 | 52 |
|
| 53 | + #[cfg(not(target_os = "cuda"))] |
26 | 54 | #[must_use] |
27 | 55 | pub fn len(&self) -> usize { |
28 | 56 | self.len |
29 | 57 | } |
30 | 58 |
|
| 59 | + #[cfg(target_os = "cuda")] |
| 60 | + #[must_use] |
| 61 | + pub fn len(&self) -> usize { |
| 62 | + core::ptr::metadata(self.shared) |
| 63 | + } |
| 64 | + |
31 | 65 | #[must_use] |
32 | 66 | pub fn is_empty(&self) -> bool { |
33 | | - self.len == 0 |
| 67 | + self.len() == 0 |
| 68 | + } |
| 69 | + |
| 70 | + #[cfg(any(target_os = "cuda", doc))] |
| 71 | + #[doc(cfg(target_os = "cuda"))] |
| 72 | + #[must_use] |
| 73 | + pub fn as_mut_slice_ptr(&self) -> *mut [T] { |
| 74 | + self.shared |
| 75 | + } |
| 76 | + |
| 77 | + #[cfg(any(target_os = "cuda", doc))] |
| 78 | + #[doc(cfg(target_os = "cuda"))] |
| 79 | + #[must_use] |
| 80 | + pub fn as_mut_ptr(&self) -> *mut T { |
| 81 | + self.shared.cast() |
34 | 82 | } |
35 | 83 | } |
36 | 84 |
|
37 | | -#[cfg(all(not(feature = "host"), target_os = "cuda"))] |
38 | | -#[doc(cfg(all(not(feature = "host"), target_os = "cuda")))] |
39 | | -impl<T: 'static> ThreadBlockSharedSlice<T> { |
40 | | - /// # Safety |
41 | | - /// |
42 | | - /// The thread-block shared dynamic memory must be initialised once and |
43 | | - /// only once per kernel. |
44 | | - pub unsafe fn init() { |
45 | | - unsafe { |
46 | | - core::arch::asm!( |
47 | | - ".shared .align {align} .b8 rust_cuda_dynamic_shared[];", |
48 | | - align = const(core::mem::align_of::<T>()), |
49 | | - ); |
50 | | - } |
| 85 | +unsafe impl<T: 'static + ~const TypeGraphLayout> RustToCuda for ThreadBlockSharedSlice<T> { |
| 86 | + #[cfg(feature = "host")] |
| 87 | + #[doc(cfg(feature = "host"))] |
| 88 | + type CudaAllocation = crate::host::NullCudaAlloc; |
| 89 | + type CudaRepresentation = ThreadBlockSharedSliceCudaRepresentation<T>; |
| 90 | + |
| 91 | + #[cfg(feature = "host")] |
| 92 | + #[doc(cfg(feature = "host"))] |
| 93 | + unsafe fn borrow<A: crate::host::CudaAlloc>( |
| 94 | + &self, |
| 95 | + alloc: A, |
| 96 | + ) -> rustacuda::error::CudaResult<( |
| 97 | + DeviceAccessible<Self::CudaRepresentation>, |
| 98 | + crate::host::CombinedCudaAlloc<Self::CudaAllocation, A>, |
| 99 | + )> { |
| 100 | + Ok(( |
| 101 | + DeviceAccessible::from(ThreadBlockSharedSliceCudaRepresentation { |
| 102 | + len: self.len, |
| 103 | + marker: [], |
| 104 | + }), |
| 105 | + crate::host::CombinedCudaAlloc::new(crate::host::NullCudaAlloc, alloc), |
| 106 | + )) |
51 | 107 | } |
52 | 108 |
|
53 | | - /// # Safety |
54 | | - /// |
55 | | - /// Exposing the [`ThreadBlockSharedSlice`] must be preceded by exactly one |
56 | | - /// call to [`ThreadBlockSharedSlice::init`] for the type `T` amongst |
57 | | - /// all `ThreadBlockSharedSlice<T>` that has the largest alignment. |
58 | | - pub unsafe fn with_uninit<F: FnOnce(*mut [T]) -> Q, Q>(self, inner: F) -> Q { |
59 | | - let base: *mut u8; |
60 | | - |
61 | | - unsafe { |
62 | | - core::arch::asm!( |
63 | | - "cvta.shared.u64 {reg}, rust_cuda_dynamic_shared;", |
64 | | - reg = out(reg64) base, |
65 | | - ); |
66 | | - } |
| 109 | + #[cfg(feature = "host")] |
| 110 | + #[doc(cfg(feature = "host"))] |
| 111 | + unsafe fn restore<A: crate::host::CudaAlloc>( |
| 112 | + &mut self, |
| 113 | + alloc: crate::host::CombinedCudaAlloc<Self::CudaAllocation, A>, |
| 114 | + ) -> rustacuda::error::CudaResult<A> { |
| 115 | + let (_null, alloc): (crate::host::NullCudaAlloc, A) = alloc.split(); |
| 116 | + |
| 117 | + Ok(alloc) |
| 118 | + } |
| 119 | +} |
| 120 | + |
| 121 | +unsafe impl<T: 'static + ~const TypeGraphLayout> CudaAsRust |
| 122 | + for ThreadBlockSharedSliceCudaRepresentation<T> |
| 123 | +{ |
| 124 | + type RustRepresentation = ThreadBlockSharedSlice<T>; |
| 125 | + |
| 126 | + #[cfg(any(not(feature = "host"), doc))] |
| 127 | + #[doc(cfg(not(feature = "host")))] |
| 128 | + unsafe fn as_rust(_this: &DeviceAccessible<Self>) -> Self::RustRepresentation { |
| 129 | + todo!() |
| 130 | + |
| 131 | + // unsafe { |
| 132 | + // core::arch::asm!( |
| 133 | + // ".shared .align {align} .b8 rust_cuda_dynamic_shared[];", |
| 134 | + // align = const(core::mem::align_of::<T>()), |
| 135 | + // ); |
| 136 | + // } |
| 137 | + |
| 138 | + // let base: *mut u8; |
67 | 139 |
|
68 | | - let slice = |
69 | | - core::ptr::slice_from_raw_parts_mut(base.add(self.byte_offset).cast(), self.len); |
| 140 | + // unsafe { |
| 141 | + // core::arch::asm!( |
| 142 | + // "cvta.shared.u64 {reg}, rust_cuda_dynamic_shared;", |
| 143 | + // reg = out(reg64) base, |
| 144 | + // ); |
| 145 | + // } |
70 | 146 |
|
71 | | - inner(slice) |
| 147 | + // let slice = core::ptr::slice_from_raw_parts_mut( |
| 148 | + // base.add(self.byte_offset).cast(), self.len, |
| 149 | + // ); |
72 | 150 | } |
73 | 151 | } |
0 commit comments