@@ -32,31 +32,38 @@ use error::emit_ptx_build_error;
3232use ptx_compiler_sys:: NvptxError ;
3333
3434pub fn check_kernel ( tokens : TokenStream ) -> TokenStream {
35- proc_macro_error:: set_dummy ( quote ! {
36- "ERROR in this PTX compilation"
37- } ) ;
35+ proc_macro_error:: set_dummy ( quote ! { :: core:: result:: Result :: Err ( ( ) ) } ) ;
3836
3937 let CheckKernelConfig {
38+ kernel_hash,
4039 args,
4140 crate_name,
4241 crate_path,
4342 } = match syn:: parse_macro_input:: parse ( tokens) {
4443 Ok ( config) => config,
4544 Err ( err) => {
4645 abort_call_site ! (
47- "check_kernel!(ARGS NAME PATH) expects ARGS identifier, NAME and PATH string \
48- literals: {:?}",
46+ "check_kernel!(HASH ARGS NAME PATH) expects HASH and ARGS identifiers, annd NAME \
47+ and PATH string literals: {:?}",
4948 err
5049 )
5150 } ,
5251 } ;
5352
5453 let kernel_ptx = compile_kernel ( & args, & crate_name, & crate_path, Specialisation :: Check ) ;
5554
56- match kernel_ptx {
57- Some ( kernel_ptx) => quote ! ( #kernel_ptx) . into ( ) ,
58- None => quote ! ( "ERROR in this PTX compilation" ) . into ( ) ,
59- }
55+ let Some ( kernel_ptx) = kernel_ptx else {
56+ return quote ! ( :: core:: result:: Result :: Err ( ( ) ) ) . into ( )
57+ } ;
58+
59+ check_kernel_ptx_and_report (
60+ & kernel_ptx,
61+ Specialisation :: Check ,
62+ & kernel_hash,
63+ & HashMap :: new ( ) ,
64+ ) ;
65+
66+ quote ! ( :: core:: result:: Result :: Ok ( ( ) ) ) . into ( )
6067}
6168
6269#[ allow( clippy:: module_name_repetitions, clippy:: too_many_lines) ]
@@ -77,9 +84,9 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
7784 Ok ( config) => config,
7885 Err ( err) => {
7986 abort_call_site ! (
80- "link_kernel!(KERNEL ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL and \
81- ARGS identifiers, NAME and PATH string literals, SPECIALISATION and LINTS \
82- tokens: {:?}",
87+ "link_kernel!(KERNEL HASH ARGS NAME PATH SPECIALISATION LINTS,*) expects KERNEL, \
88+ HASH, and ARGS identifiers, NAME and PATH string literals, and SPECIALISATION \
89+ and LINTS tokens: {:?}",
8390 err
8491 )
8592 } ,
@@ -206,88 +213,162 @@ pub fn link_kernel(tokens: TokenStream) -> TokenStream {
206213 kernel_ptx. replace_range ( type_layout_start..type_layout_end, "" ) ;
207214 }
208215
209- let ( result, error_log, info_log, version, drop) =
210- check_kernel_ptx ( & kernel_ptx, & specialisation, & kernel_hash, & ptx_lint_levels) ;
216+ check_kernel_ptx_and_report (
217+ & kernel_ptx,
218+ Specialisation :: Link ( & specialisation) ,
219+ & kernel_hash,
220+ & ptx_lint_levels,
221+ ) ;
222+
223+ ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
224+ }
225+
226+ #[ allow( clippy:: too_many_lines) ]
227+ fn check_kernel_ptx_and_report (
228+ kernel_ptx : & str ,
229+ specialisation : Specialisation ,
230+ kernel_hash : & proc_macro2:: Ident ,
231+ ptx_lint_levels : & HashMap < PtxLint , LintLevel > ,
232+ ) {
233+ let ( result, error_log, info_log, binary, version, drop) =
234+ check_kernel_ptx ( kernel_ptx, specialisation, kernel_hash, ptx_lint_levels) ;
211235
212236 let ptx_compiler = match & version {
213237 Ok ( ( major, minor) ) => format ! ( "PTX compiler v{major}.{minor}" ) ,
214238 Err ( _) => String :: from ( "PTX compiler" ) ,
215239 } ;
216240
217- // TODO: allow user to select
218- // - warn on double
219- // - warn on float
220- // - warn on spills
221- // - verbose warn
222- // - warnings as errors
223- // - show PTX source if warning or error
224-
225241 let mut errors = String :: new ( ) ;
242+
226243 if let Err ( err) = drop {
227244 let _ = errors. write_fmt ( format_args ! ( "Error dropping the {ptx_compiler}: {err}\n " ) ) ;
228245 }
246+
229247 if let Err ( err) = version {
230248 let _ = errors. write_fmt ( format_args ! (
231249 "Error fetching the version of the {ptx_compiler}: {err}\n "
232250 ) ) ;
233251 }
234- if let ( Ok ( Some ( _) ) , _) | ( _, Ok ( Some ( _) ) ) = ( & info_log, & error_log) {
252+
253+ let ptx_source_code = {
235254 let mut max_lines = kernel_ptx. chars ( ) . filter ( |c| * c == '\n' ) . count ( ) + 1 ;
236255 let mut indent = 0 ;
237256 while max_lines > 0 {
238257 max_lines /= 10 ;
239258 indent += 1 ;
240259 }
241260
242- emit_call_site_warning ! (
261+ format ! (
243262 "PTX source code:\n {}" ,
244263 kernel_ptx
245264 . lines( )
246265 . enumerate( )
247266 . map( |( i, l) | format!( "{:indent$}| {l}" , i + 1 ) )
248267 . collect:: <Vec <_>>( )
249268 . join( "\n " )
250- ) ;
269+ )
270+ } ;
271+
272+ match binary {
273+ Ok ( None ) => ( ) ,
274+ Ok ( Some ( binary) ) => {
275+ if ptx_lint_levels
276+ . get ( & PtxLint :: DumpBinary )
277+ . map_or ( false , |level| * level > LintLevel :: Allow )
278+ {
279+ const HEX : [ char ; 16 ] = [
280+ '0' , '1' , '2' , '3' , '4' , '5' , '6' , '7' , '8' , '9' , 'a' , 'b' , 'c' , 'd' , 'e' , 'f' ,
281+ ] ;
282+
283+ let mut binary_hex = String :: with_capacity ( binary. len ( ) * 2 ) ;
284+ for byte in binary {
285+ binary_hex. push ( HEX [ usize:: from ( byte >> 4 ) ] ) ;
286+ binary_hex. push ( HEX [ usize:: from ( byte & 0x0F ) ] ) ;
287+ }
288+
289+ if ptx_lint_levels
290+ . get ( & PtxLint :: DumpBinary )
291+ . map_or ( false , |level| * level > LintLevel :: Warn )
292+ {
293+ emit_call_site_error ! (
294+ "{} compiled binary:\n {}\n \n {}" ,
295+ ptx_compiler,
296+ binary_hex,
297+ ptx_source_code
298+ ) ;
299+ } else {
300+ emit_call_site_warning ! (
301+ "{} compiled binary:\n {}\n \n {}" ,
302+ ptx_compiler,
303+ binary_hex,
304+ ptx_source_code
305+ ) ;
306+ }
307+ }
308+ } ,
309+ Err ( err) => {
310+ let _ = errors. write_fmt ( format_args ! (
311+ "Error fetching the compiled binary from {ptx_compiler}: {err}\n "
312+ ) ) ;
313+ } ,
251314 }
315+
252316 match info_log {
253317 Ok ( None ) => ( ) ,
254- Ok ( Some ( info_log) ) => emit_call_site_warning ! ( "{ptx_compiler} info log:\n {}" , info_log) ,
318+ Ok ( Some ( info_log) ) => emit_call_site_warning ! (
319+ "{} info log:\n {}\n {}" ,
320+ ptx_compiler,
321+ info_log,
322+ ptx_source_code
323+ ) ,
255324 Err ( err) => {
256325 let _ = errors. write_fmt ( format_args ! (
257326 "Error fetching the info log of the {ptx_compiler}: {err}\n "
258327 ) ) ;
259328 } ,
260329 } ;
261- match error_log {
262- Ok ( None ) => ( ) ,
263- Ok ( Some ( error_log) ) => emit_call_site_error ! ( "{ptx_compiler} error log:\n {}" , error_log) ,
330+
331+ let error_log = match error_log {
332+ Ok ( None ) => String :: new ( ) ,
333+ Ok ( Some ( error_log) ) => {
334+ format ! ( "{ptx_compiler} error log:\n {error_log}\n {ptx_source_code}" )
335+ } ,
264336 Err ( err) => {
265337 let _ = errors. write_fmt ( format_args ! (
266338 "Error fetching the error log of the {ptx_compiler}: {err}\n "
267339 ) ) ;
340+ String :: new ( )
268341 } ,
269342 } ;
343+
270344 if let Err ( err) = result {
271345 let _ = errors. write_fmt ( format_args ! ( "Error compiling the PTX source code: {err}\n " ) ) ;
272346 }
273- if !errors. is_empty ( ) {
274- abort_call_site ! ( "{}" , errors) ;
275- }
276347
277- ( quote ! { const PTX_STR : & ' static str = #kernel_ptx; #( #type_layouts) * } ) . into ( )
348+ if !error_log. is_empty ( ) || !errors. is_empty ( ) {
349+ abort_call_site ! (
350+ "{error_log}{}{errors}" ,
351+ if !error_log. is_empty( ) && !errors. is_empty( ) {
352+ "\n \n "
353+ } else {
354+ ""
355+ }
356+ ) ;
357+ }
278358}
279359
280360#[ allow( clippy:: type_complexity) ]
281361#[ allow( clippy:: too_many_lines) ]
282362fn check_kernel_ptx (
283363 kernel_ptx : & str ,
284- specialisation : & str ,
364+ specialisation : Specialisation ,
285365 kernel_hash : & proc_macro2:: Ident ,
286366 ptx_lint_levels : & HashMap < PtxLint , LintLevel > ,
287367) -> (
288368 Result < ( ) , NvptxError > ,
289369 Result < Option < String > , NvptxError > ,
290370 Result < Option < String > , NvptxError > ,
371+ Result < Option < Vec < u8 > > , NvptxError > ,
291372 Result < ( u32 , u32 ) , NvptxError > ,
292373 Result < ( ) , NvptxError > ,
293374) {
@@ -306,14 +387,15 @@ fn check_kernel_ptx(
306387 } ;
307388
308389 let result = ( || {
309- let kernel_name = if specialisation. is_empty ( ) {
310- format ! ( "{kernel_hash}_kernel" )
311- } else {
312- format ! (
390+ let kernel_name = match specialisation {
391+ Specialisation :: Check => format ! ( "{kernel_hash}_chECK" ) ,
392+ Specialisation :: Link ( "" ) => format ! ( "{kernel_hash}_kernel" ) ,
393+ Specialisation :: Link ( specialisation ) => format ! (
313394 "{kernel_hash}_kernel_{:016x}" ,
314395 seahash:: hash( specialisation. as_bytes( ) )
315- )
396+ ) ,
316397 } ;
398+
317399 let mut options = vec ! [
318400 CString :: new( "--entry" ) . unwrap( ) ,
319401 CString :: new( kernel_name) . unwrap( ) ,
@@ -450,6 +532,39 @@ fn check_kernel_ptx(
450532 Ok ( Some ( String :: from_utf8_lossy ( & info_log) . into_owned ( ) ) )
451533 } ) ( ) ;
452534
535+ let binary = ( || {
536+ if result. is_err ( ) {
537+ return Ok ( None ) ;
538+ }
539+
540+ let mut binary_size = 0 ;
541+
542+ NvptxError :: try_err_from ( unsafe {
543+ ptx_compiler_sys:: nvPTXCompilerGetCompiledProgramSize (
544+ compiler,
545+ addr_of_mut ! ( binary_size) ,
546+ )
547+ } ) ?;
548+
549+ if binary_size == 0 {
550+ return Ok ( None ) ;
551+ }
552+
553+ #[ allow( clippy:: cast_possible_truncation) ]
554+ let mut binary: Vec < u8 > = Vec :: with_capacity ( binary_size as usize ) ;
555+
556+ NvptxError :: try_err_from ( unsafe {
557+ ptx_compiler_sys:: nvPTXCompilerGetCompiledProgram ( compiler, binary. as_mut_ptr ( ) . cast ( ) )
558+ } ) ?;
559+
560+ #[ allow( clippy:: cast_possible_truncation) ]
561+ unsafe {
562+ binary. set_len ( binary_size as usize ) ;
563+ }
564+
565+ Ok ( Some ( binary) )
566+ } ) ( ) ;
567+
453568 let version = ( || {
454569 let mut major = 0 ;
455570 let mut minor = 0 ;
@@ -468,7 +583,7 @@ fn check_kernel_ptx(
468583 } )
469584 } ;
470585
471- ( result, error_log, info_log, version, drop)
586+ ( result, error_log, info_log, binary , version, drop)
472587}
473588
474589fn compile_kernel (
0 commit comments