Skip to content

Commit 3d42856

Browse files
committed
Simplified the kernel parameter layout extraction from PTX
1 parent 83c965e commit 3d42856

File tree

4 files changed

+103
-104
lines changed

4 files changed

+103
-104
lines changed

rust-cuda-derive/src/kernel/link/mod.rs

Lines changed: 87 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use ptx_builder::{
2121
use super::{
2222
lints::{LintLevel, PtxLint},
2323
utils::skip_kernel_compilation,
24+
KERNEL_TYPE_USE_END_CANARY, KERNEL_TYPE_USE_START_CANARY,
2425
};
2526

2627
mod config;
@@ -66,14 +67,14 @@ pub fn check_kernel(tokens: TokenStream) -> TokenStream {
6667
quote!(::core::result::Result::Ok(())).into()
6768
}
6869

69-
#[allow(clippy::module_name_repetitions, clippy::too_many_lines)]
70+
#[allow(clippy::module_name_repetitions)]
7071
pub fn link_kernel(tokens: TokenStream) -> TokenStream {
7172
proc_macro_error::set_dummy(quote! {
7273
const PTX_STR: &'static str = "ERROR in this PTX compilation";
7374
});
7475

7576
let LinkKernelConfig {
76-
kernel,
77+
kernel: _kernel,
7778
kernel_hash,
7879
args,
7980
crate_name,
@@ -111,116 +112,110 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
111112
.into()
112113
};
113114

114-
let kernel_layout_name = if specialisation.is_empty() {
115-
format!("{kernel}_type_layout_kernel")
116-
} else {
117-
format!(
118-
"{kernel}_type_layout_kernel_{:016x}",
119-
seahash::hash(specialisation.as_bytes())
120-
)
121-
};
115+
let type_layouts = extract_ptx_kernel_layout(&mut kernel_ptx);
116+
remove_kernel_type_use_from_ptx(&mut kernel_ptx);
122117

123-
let mut type_layouts = Vec::new();
118+
check_kernel_ptx_and_report(
119+
&kernel_ptx,
120+
Specialisation::Link(&specialisation),
121+
&kernel_hash,
122+
&ptx_lint_levels,
123+
);
124+
125+
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
126+
}
124127

125-
let type_layout_start_pattern = format!("\n\t// .globl\t{kernel_layout_name}");
128+
fn extract_ptx_kernel_layout(kernel_ptx: &mut String) -> Vec<proc_macro2::TokenStream> {
129+
const BEFORE_PARAM_PATTERN: &str = "\n.global .align 1 .b8 ";
130+
const PARAM_LEN_PATTERN: &str = "[";
131+
const LEN_BYTES_PATTERN: &str = "] = {";
132+
const AFTER_BYTES_PATTERN: &str = "};\n";
133+
const BYTES_PARAM_PATTERN: &str = "};";
126134

127-
if let Some(type_layout_start) = kernel_ptx.find(&type_layout_start_pattern) {
128-
const BEFORE_PARAM_PATTERN: &str = "\n.global .align 1 .b8 ";
129-
const PARAM_LEN_PATTERN: &str = "[";
130-
const LEN_BYTES_PATTERN: &str = "] = {";
131-
const AFTER_BYTES_PATTERN: &str = "};\n";
132-
const BYTES_PARAM_PATTERN: &str = "};";
135+
let mut type_layouts = Vec::new();
133136

134-
let after_type_layout_start = type_layout_start + type_layout_start_pattern.len();
137+
while let Some(type_layout_start) = kernel_ptx.find(BEFORE_PARAM_PATTERN) {
138+
let param_start = type_layout_start + BEFORE_PARAM_PATTERN.len();
135139

136-
let Some(type_layout_middle) = kernel_ptx[after_type_layout_start..]
137-
.find(&format!(".visible .entry {kernel_layout_name}")).map(|i| after_type_layout_start + i)
138-
else {
140+
let Some(len_start_offset) = kernel_ptx[param_start..].find(PARAM_LEN_PATTERN) else {
139141
abort_call_site!(
140-
"Kernel compilation generated invalid PTX: incomplete type layout information"
142+
"Kernel compilation generated invalid PTX: missing type layout data"
141143
)
142144
};
145+
let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN.len();
143146

144-
let mut next_type_layout = after_type_layout_start;
147+
let Some(bytes_start_offset) = kernel_ptx[len_start..].find(LEN_BYTES_PATTERN) else {
148+
abort_call_site!(
149+
"Kernel compilation generated invalid PTX: missing type layout length"
150+
)
151+
};
152+
let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN.len();
145153

146-
while let Some(param_start_offset) =
147-
kernel_ptx[next_type_layout..type_layout_middle].find(BEFORE_PARAM_PATTERN)
148-
{
149-
let param_start = next_type_layout + param_start_offset + BEFORE_PARAM_PATTERN.len();
154+
let Some(bytes_end_offset) = kernel_ptx[bytes_start..].find(AFTER_BYTES_PATTERN) else {
155+
abort_call_site!(
156+
"Kernel compilation generated invalid PTX: invalid type layout data"
157+
)
158+
};
159+
let param = &kernel_ptx[param_start..(param_start + len_start_offset)];
160+
let len = &kernel_ptx[len_start..(len_start + bytes_start_offset)];
161+
let bytes = &kernel_ptx[bytes_start..(bytes_start + bytes_end_offset)];
150162

151-
if let Some(len_start_offset) =
152-
kernel_ptx[param_start..type_layout_middle].find(PARAM_LEN_PATTERN)
153-
{
154-
let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN.len();
163+
let param = quote::format_ident!("{}", param);
155164

156-
if let Some(bytes_start_offset) =
157-
kernel_ptx[len_start..type_layout_middle].find(LEN_BYTES_PATTERN)
158-
{
159-
let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN.len();
165+
let Ok(len) = len.parse::<usize>() else {
166+
abort_call_site!(
167+
"Kernel compilation generated invalid PTX: invalid type layout length"
168+
)
169+
};
170+
let Ok(bytes) = bytes.split(", ").map(std::str::FromStr::from_str).collect::<Result<Vec<u8>, _>>() else {
171+
abort_call_site!(
172+
"Kernel compilation generated invalid PTX: invalid type layout byte"
173+
)
174+
};
160175

161-
if let Some(bytes_end_offset) =
162-
kernel_ptx[bytes_start..type_layout_middle].find(AFTER_BYTES_PATTERN)
163-
{
164-
let param = &kernel_ptx[param_start..(param_start + len_start_offset)];
165-
let len = &kernel_ptx[len_start..(len_start + bytes_start_offset)];
166-
let bytes = &kernel_ptx[bytes_start..(bytes_start + bytes_end_offset)];
167-
168-
let param = quote::format_ident!("{}", param);
169-
170-
let Ok(len) = len.parse::<usize>() else {
171-
abort_call_site!(
172-
"Kernel compilation generated invalid PTX: invalid type layout length"
173-
)
174-
};
175-
let Ok(bytes) = bytes.split(", ").map(std::str::FromStr::from_str).collect::<Result<Vec<u8>, _>>() else {
176-
abort_call_site!(
177-
"Kernel compilation generated invalid PTX: invalid type layout byte"
178-
)
179-
};
180-
181-
if bytes.len() != len {
182-
abort_call_site!(
183-
"Kernel compilation generated invalid PTX: type layout length \
184-
mismatch"
185-
);
186-
}
187-
188-
let byte_str = syn::LitByteStr::new(&bytes, proc_macro2::Span::call_site());
189-
190-
type_layouts.push(quote! {
191-
const #param: &[u8; #len] = #byte_str;
192-
});
193-
194-
next_type_layout =
195-
bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN.len();
196-
} else {
197-
next_type_layout = bytes_start;
198-
}
199-
} else {
200-
next_type_layout = len_start;
201-
}
202-
} else {
203-
next_type_layout = param_start;
204-
}
176+
if bytes.len() != len {
177+
abort_call_site!(
178+
"Kernel compilation generated invalid PTX: type layout length mismatch"
179+
);
205180
}
206181

207-
let Some(type_layout_end) = kernel_ptx[type_layout_middle..].find('}').map(|i| {
208-
type_layout_middle + i + '}'.len_utf8()
209-
}) else {
210-
abort_call_site!("Kernel compilation generated invalid PTX")
211-
};
182+
let byte_str = syn::LitByteStr::new(&bytes, proc_macro2::Span::call_site());
183+
184+
type_layouts.push(quote! {
185+
const #param: &[u8; #len] = #byte_str;
186+
});
187+
188+
let type_layout_end = bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN.len();
212189

213190
kernel_ptx.replace_range(type_layout_start..type_layout_end, "");
214191
}
215192

216-
check_kernel_ptx_and_report(
217-
&kernel_ptx,
218-
Specialisation::Link(&specialisation),
219-
&kernel_hash,
220-
&ptx_lint_levels,
221-
);
193+
type_layouts
194+
}
222195

223-
(quote! { const PTX_STR: &'static str = #kernel_ptx; #(#type_layouts)* }).into()
196+
fn remove_kernel_type_use_from_ptx(kernel_ptx: &mut String) {
197+
while let Some(kernel_type_layout_start) = kernel_ptx.find(KERNEL_TYPE_USE_START_CANARY) {
198+
let kernel_type_layout_start = kernel_ptx[..kernel_type_layout_start]
199+
.rfind('\n')
200+
.unwrap_or(kernel_type_layout_start);
201+
202+
let Some(kernel_type_layout_end_offset) = kernel_ptx[
203+
kernel_type_layout_start..
204+
].find(KERNEL_TYPE_USE_END_CANARY) else {
205+
abort_call_site!(
206+
"Kernel compilation generated invalid PTX: incomplete type layout use section"
207+
);
208+
};
209+
210+
let kernel_type_layout_end_offset = kernel_type_layout_end_offset
211+
+ kernel_ptx[kernel_type_layout_start + kernel_type_layout_end_offset..]
212+
.find('\n')
213+
.unwrap_or(KERNEL_TYPE_USE_END_CANARY.len());
214+
215+
let kernel_type_layout_end = kernel_type_layout_start + kernel_type_layout_end_offset;
216+
217+
kernel_ptx.replace_range(kernel_type_layout_start..kernel_type_layout_end, "");
218+
}
224219
}
225220

226221
#[allow(clippy::too_many_lines)]

rust-cuda-derive/src/kernel/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ pub mod wrapper;
44

55
mod lints;
66
mod utils;
7+
8+
const KERNEL_TYPE_USE_START_CANARY: &str = "// <rust-cuda-kernel-param-type-use-start> //";
9+
const KERNEL_TYPE_USE_END_CANARY: &str = "// <rust-cuda-kernel-param-type-use-end> //";

rust-cuda-derive/src/kernel/wrapper/generate/cuda_wrapper.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ use proc_macro2::TokenStream;
22
use quote::quote_spanned;
33
use syn::spanned::Spanned;
44

5-
use super::super::{FuncIdent, FunctionInputs, InputCudaType, KernelConfig};
5+
use super::super::{
6+
super::{KERNEL_TYPE_USE_END_CANARY, KERNEL_TYPE_USE_START_CANARY},
7+
FuncIdent, FunctionInputs, InputCudaType, KernelConfig,
8+
};
69

710
#[allow(clippy::too_many_lines)]
811
pub(in super::super) fn quote_cuda_wrapper(
@@ -96,29 +99,27 @@ pub(in super::super) fn quote_cuda_wrapper(
9699
syn::FnArg::Receiver(_) => unreachable!(),
97100
});
98101

99-
let func_type_layout_ident = quote::format_ident!("{}_type_layout", func_ident);
100-
101102
quote! {
102103
#[cfg(target_os = "cuda")]
103104
#[#crate_path::device::specialise_kernel_entry(#args)]
104105
#[no_mangle]
105106
#(#func_attrs)*
106-
pub unsafe extern "ptx-kernel" fn #func_type_layout_ident(#(#func_params: &mut &[u8]),*) {
107+
pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ptx_func_inputs),*) {
108+
unsafe {
109+
::core::arch::asm!(#KERNEL_TYPE_USE_START_CANARY);
110+
}
107111
#(
108112
#[no_mangle]
109113
static #func_layout_params: [
110114
u8; #crate_path::const_type_layout::serialised_type_graph_len::<#ptx_func_types>()
111115
] = #crate_path::const_type_layout::serialise_type_graph::<#ptx_func_types>();
112116

113-
*#func_params = &#func_layout_params;
117+
unsafe { ::core::ptr::read_volatile(&#func_layout_params[0]) };
114118
)*
115-
}
119+
unsafe {
120+
::core::arch::asm!(#KERNEL_TYPE_USE_END_CANARY);
121+
}
116122

117-
#[cfg(target_os = "cuda")]
118-
#[#crate_path::device::specialise_kernel_entry(#args)]
119-
#[no_mangle]
120-
#(#func_attrs)*
121-
pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ptx_func_inputs),*) {
122123
#[deny(improper_ctypes)]
123124
mod __rust_cuda_ffi_safe_assert {
124125
use super::#args;

rust-cuda-ptx-jit/src/device.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ macro_rules! PtxJITConstLoad {
55
([$index:literal] => $reference:expr) => {
66
unsafe {
77
::core::arch::asm!(
8-
concat!("// <rust-cuda-ptx-jit-const-load-{}-", $index, "> //"),
8+
::core::concat!("// <rust-cuda-ptx-jit-const-load-{}-", $index, "> //"),
99
in(reg32) *($reference as *const _ as *const u32),
1010
)
1111
}

0 commit comments

Comments
 (0)