@@ -21,6 +21,7 @@ use ptx_builder::{
2121use super :: {
2222 lints:: { LintLevel , PtxLint } ,
2323 utils:: skip_kernel_compilation,
24+ KERNEL_TYPE_USE_END_CANARY , KERNEL_TYPE_USE_START_CANARY ,
2425} ;
2526
2627mod config;
@@ -66,14 +67,14 @@ pub fn check_kernel(tokens: TokenStream) -> TokenStream {
6667 quote ! ( :: core:: result:: Result :: Ok ( ( ) ) ) . into ( )
6768}
6869
69- #[ allow( clippy:: module_name_repetitions, clippy :: too_many_lines ) ]
70+ #[ allow( clippy:: module_name_repetitions) ]
7071pub fn link_kernel ( tokens : TokenStream ) -> TokenStream {
7172 proc_macro_error:: set_dummy ( quote ! {
7273 const PTX_STR : & ' static str = "ERROR in this PTX compilation" ;
7374 } ) ;
7475
7576 let LinkKernelConfig {
76- kernel,
77+ kernel : _kernel ,
7778 kernel_hash,
7879 args,
7980 crate_name,
@@ -111,116 +112,110 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
111112 . into ( )
112113 } ;
113114
114- let kernel_layout_name = if specialisation. is_empty ( ) {
115- format ! ( "{kernel}_type_layout_kernel" )
116- } else {
117- format ! (
118- "{kernel}_type_layout_kernel_{:016x}" ,
119- seahash:: hash( specialisation. as_bytes( ) )
120- )
121- } ;
115+ let type_layouts = extract_ptx_kernel_layout ( & mut kernel_ptx) ;
116+ remove_kernel_type_use_from_ptx ( & mut kernel_ptx) ;
122117
123- let mut type_layouts = Vec :: new ( ) ;
118+ check_kernel_ptx_and_report (
119+ & kernel_ptx,
120+ Specialisation :: Link ( & specialisation) ,
121+ & kernel_hash,
122+ & ptx_lint_levels,
123+ ) ;
124+
125+ ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
126+ }
124127
125- let type_layout_start_pattern = format ! ( "\n \t // .globl\t {kernel_layout_name}" ) ;
128+ fn extract_ptx_kernel_layout ( kernel_ptx : & mut String ) -> Vec < proc_macro2:: TokenStream > {
129+ const BEFORE_PARAM_PATTERN : & str = "\n .global .align 1 .b8 " ;
130+ const PARAM_LEN_PATTERN : & str = "[" ;
131+ const LEN_BYTES_PATTERN : & str = "] = {" ;
132+ const AFTER_BYTES_PATTERN : & str = "};\n " ;
133+ const BYTES_PARAM_PATTERN : & str = "};" ;
126134
127- if let Some ( type_layout_start) = kernel_ptx. find ( & type_layout_start_pattern) {
128- const BEFORE_PARAM_PATTERN : & str = "\n .global .align 1 .b8 " ;
129- const PARAM_LEN_PATTERN : & str = "[" ;
130- const LEN_BYTES_PATTERN : & str = "] = {" ;
131- const AFTER_BYTES_PATTERN : & str = "};\n " ;
132- const BYTES_PARAM_PATTERN : & str = "};" ;
135+ let mut type_layouts = Vec :: new ( ) ;
133136
134- let after_type_layout_start = type_layout_start + type_layout_start_pattern. len ( ) ;
137+ while let Some ( type_layout_start) = kernel_ptx. find ( BEFORE_PARAM_PATTERN ) {
138+ let param_start = type_layout_start + BEFORE_PARAM_PATTERN . len ( ) ;
135139
136- let Some ( type_layout_middle) = kernel_ptx[ after_type_layout_start..]
137- . find ( & format ! ( ".visible .entry {kernel_layout_name}" ) ) . map ( |i| after_type_layout_start + i)
138- else {
140+ let Some ( len_start_offset) = kernel_ptx[ param_start..] . find ( PARAM_LEN_PATTERN ) else {
139141 abort_call_site ! (
140- "Kernel compilation generated invalid PTX: incomplete type layout information "
142+ "Kernel compilation generated invalid PTX: missing type layout data "
141143 )
142144 } ;
145+ let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN . len ( ) ;
143146
144- let mut next_type_layout = after_type_layout_start;
147+ let Some ( bytes_start_offset) = kernel_ptx[ len_start..] . find ( LEN_BYTES_PATTERN ) else {
148+ abort_call_site ! (
149+ "Kernel compilation generated invalid PTX: missing type layout length"
150+ )
151+ } ;
152+ let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN . len ( ) ;
145153
146- while let Some ( param_start_offset) =
147- kernel_ptx[ next_type_layout..type_layout_middle] . find ( BEFORE_PARAM_PATTERN )
148- {
149- let param_start = next_type_layout + param_start_offset + BEFORE_PARAM_PATTERN . len ( ) ;
154+ let Some ( bytes_end_offset) = kernel_ptx[ bytes_start..] . find ( AFTER_BYTES_PATTERN ) else {
155+ abort_call_site ! (
156+ "Kernel compilation generated invalid PTX: invalid type layout data"
157+ )
158+ } ;
159+ let param = & kernel_ptx[ param_start..( param_start + len_start_offset) ] ;
160+ let len = & kernel_ptx[ len_start..( len_start + bytes_start_offset) ] ;
161+ let bytes = & kernel_ptx[ bytes_start..( bytes_start + bytes_end_offset) ] ;
150162
151- if let Some ( len_start_offset) =
152- kernel_ptx[ param_start..type_layout_middle] . find ( PARAM_LEN_PATTERN )
153- {
154- let len_start = param_start + len_start_offset + PARAM_LEN_PATTERN . len ( ) ;
163+ let param = quote:: format_ident!( "{}" , param) ;
155164
156- if let Some ( bytes_start_offset) =
157- kernel_ptx[ len_start..type_layout_middle] . find ( LEN_BYTES_PATTERN )
158- {
159- let bytes_start = len_start + bytes_start_offset + LEN_BYTES_PATTERN . len ( ) ;
165+ let Ok ( len) = len. parse :: < usize > ( ) else {
166+ abort_call_site ! (
167+ "Kernel compilation generated invalid PTX: invalid type layout length"
168+ )
169+ } ;
170+ let Ok ( bytes) = bytes. split ( ", " ) . map ( std:: str:: FromStr :: from_str) . collect :: < Result < Vec < u8 > , _ > > ( ) else {
171+ abort_call_site ! (
172+ "Kernel compilation generated invalid PTX: invalid type layout byte"
173+ )
174+ } ;
160175
161- if let Some ( bytes_end_offset) =
162- kernel_ptx[ bytes_start..type_layout_middle] . find ( AFTER_BYTES_PATTERN )
163- {
164- let param = & kernel_ptx[ param_start..( param_start + len_start_offset) ] ;
165- let len = & kernel_ptx[ len_start..( len_start + bytes_start_offset) ] ;
166- let bytes = & kernel_ptx[ bytes_start..( bytes_start + bytes_end_offset) ] ;
167-
168- let param = quote:: format_ident!( "{}" , param) ;
169-
170- let Ok ( len) = len. parse :: < usize > ( ) else {
171- abort_call_site ! (
172- "Kernel compilation generated invalid PTX: invalid type layout length"
173- )
174- } ;
175- let Ok ( bytes) = bytes. split ( ", " ) . map ( std:: str:: FromStr :: from_str) . collect :: < Result < Vec < u8 > , _ > > ( ) else {
176- abort_call_site ! (
177- "Kernel compilation generated invalid PTX: invalid type layout byte"
178- )
179- } ;
180-
181- if bytes. len ( ) != len {
182- abort_call_site ! (
183- "Kernel compilation generated invalid PTX: type layout length \
184- mismatch"
185- ) ;
186- }
187-
188- let byte_str = syn:: LitByteStr :: new ( & bytes, proc_macro2:: Span :: call_site ( ) ) ;
189-
190- type_layouts. push ( quote ! {
191- const #param: & [ u8 ; #len] = #byte_str;
192- } ) ;
193-
194- next_type_layout =
195- bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN . len ( ) ;
196- } else {
197- next_type_layout = bytes_start;
198- }
199- } else {
200- next_type_layout = len_start;
201- }
202- } else {
203- next_type_layout = param_start;
204- }
176+ if bytes. len ( ) != len {
177+ abort_call_site ! (
178+ "Kernel compilation generated invalid PTX: type layout length mismatch"
179+ ) ;
205180 }
206181
207- let Some ( type_layout_end) = kernel_ptx[ type_layout_middle..] . find ( '}' ) . map ( |i| {
208- type_layout_middle + i + '}' . len_utf8 ( )
209- } ) else {
210- abort_call_site ! ( "Kernel compilation generated invalid PTX" )
211- } ;
182+ let byte_str = syn:: LitByteStr :: new ( & bytes, proc_macro2:: Span :: call_site ( ) ) ;
183+
184+ type_layouts. push ( quote ! {
185+ const #param: & [ u8 ; #len] = #byte_str;
186+ } ) ;
187+
188+ let type_layout_end = bytes_start + bytes_end_offset + BYTES_PARAM_PATTERN . len ( ) ;
212189
213190 kernel_ptx. replace_range ( type_layout_start..type_layout_end, "" ) ;
214191 }
215192
216- check_kernel_ptx_and_report (
217- & kernel_ptx,
218- Specialisation :: Link ( & specialisation) ,
219- & kernel_hash,
220- & ptx_lint_levels,
221- ) ;
193+ type_layouts
194+ }
222195
223- ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
196+ fn remove_kernel_type_use_from_ptx ( kernel_ptx : & mut String ) {
197+ while let Some ( kernel_type_layout_start) = kernel_ptx. find ( KERNEL_TYPE_USE_START_CANARY ) {
198+ let kernel_type_layout_start = kernel_ptx[ ..kernel_type_layout_start]
199+ . rfind ( '\n' )
200+ . unwrap_or ( kernel_type_layout_start) ;
201+
202+ let Some ( kernel_type_layout_end_offset) = kernel_ptx[
203+ kernel_type_layout_start..
204+ ] . find ( KERNEL_TYPE_USE_END_CANARY ) else {
205+ abort_call_site ! (
206+ "Kernel compilation generated invalid PTX: incomplete type layout use section"
207+ ) ;
208+ } ;
209+
210+ let kernel_type_layout_end_offset = kernel_type_layout_end_offset
211+ + kernel_ptx[ kernel_type_layout_start + kernel_type_layout_end_offset..]
212+ . find ( '\n' )
213+ . unwrap_or ( KERNEL_TYPE_USE_END_CANARY . len ( ) ) ;
214+
215+ let kernel_type_layout_end = kernel_type_layout_start + kernel_type_layout_end_offset;
216+
217+ kernel_ptx. replace_range ( kernel_type_layout_start..kernel_type_layout_end, "" ) ;
218+ }
224219}
225220
226221#[ allow( clippy:: too_many_lines) ]
0 commit comments