1- use std:: fmt;
1+ use std:: { fmt, ops :: Not } ;
22
33use ast:: HasName ;
44use cfg:: { CfgAtom , CfgExpr } ;
@@ -15,6 +15,7 @@ use ide_db::{
1515 FilePosition , FxHashMap , FxHashSet , RootDatabase , SymbolKind ,
1616} ;
1717use itertools:: Itertools ;
18+ use smallvec:: SmallVec ;
1819use span:: { Edition , TextSize } ;
1920use stdx:: format_to;
2021use syntax:: {
@@ -30,6 +31,7 @@ pub struct Runnable {
3031 pub nav : NavigationTarget ,
3132 pub kind : RunnableKind ,
3233 pub cfg : Option < CfgExpr > ,
34+ pub update_test : UpdateTest ,
3335}
3436
3537#[ derive( Debug , Clone , Hash , PartialEq , Eq ) ]
@@ -334,14 +336,19 @@ pub(crate) fn runnable_fn(
334336 }
335337 } ;
336338
339+ let fn_source = def. source ( sema. db ) ?;
337340 let nav = NavigationTarget :: from_named (
338341 sema. db ,
339- def . source ( sema . db ) ? . as_ref ( ) . map ( |it| it as & dyn ast:: HasName ) ,
342+ fn_source . as_ref ( ) . map ( |it| it as & dyn ast:: HasName ) ,
340343 SymbolKind :: Function ,
341344 )
342345 . call_site ( ) ;
346+
347+ let file_range = fn_source. syntax ( ) . original_file_range_with_macro_call_body ( sema. db ) ;
348+ let update_test = TestDefs :: new ( sema, def. krate ( sema. db ) , file_range) . update_test ( ) ;
349+
343350 let cfg = def. attrs ( sema. db ) . cfg ( ) ;
344- Some ( Runnable { use_name_in_title : false , nav, kind, cfg } )
351+ Some ( Runnable { use_name_in_title : false , nav, kind, cfg, update_test } )
345352}
346353
347354pub ( crate ) fn runnable_mod (
@@ -366,7 +373,22 @@ pub(crate) fn runnable_mod(
366373 let attrs = def. attrs ( sema. db ) ;
367374 let cfg = attrs. cfg ( ) ;
368375 let nav = NavigationTarget :: from_module_to_decl ( sema. db , def) . call_site ( ) ;
369- Some ( Runnable { use_name_in_title : false , nav, kind : RunnableKind :: TestMod { path } , cfg } )
376+
377+ let file_range = {
378+ let src = def. definition_source ( sema. db ) ;
379+ let file_id = src. file_id . original_file ( sema. db ) ;
380+ let range = src. file_syntax ( sema. db ) . text_range ( ) ;
381+ hir:: FileRange { file_id, range }
382+ } ;
383+ let update_test = TestDefs :: new ( sema, def. krate ( ) , file_range) . update_test ( ) ;
384+
385+ Some ( Runnable {
386+ use_name_in_title : false ,
387+ nav,
388+ kind : RunnableKind :: TestMod { path } ,
389+ cfg,
390+ update_test,
391+ } )
370392}
371393
372394pub ( crate ) fn runnable_impl (
@@ -392,7 +414,17 @@ pub(crate) fn runnable_impl(
392414 test_id. retain ( |c| c != ' ' ) ;
393415 let test_id = TestId :: Path ( test_id) ;
394416
395- Some ( Runnable { use_name_in_title : false , nav, kind : RunnableKind :: DocTest { test_id } , cfg } )
417+ let impl_source =
418+ def. source ( sema. db ) ?. syntax ( ) . original_file_range_with_macro_call_body ( sema. db ) ;
419+ let update_test = TestDefs :: new ( sema, def. krate ( sema. db ) , impl_source) . update_test ( ) ;
420+
421+ Some ( Runnable {
422+ use_name_in_title : false ,
423+ nav,
424+ kind : RunnableKind :: DocTest { test_id } ,
425+ cfg,
426+ update_test,
427+ } )
396428}
397429
398430fn has_cfg_test ( attrs : AttrsWithOwner ) -> bool {
@@ -404,6 +436,8 @@ fn runnable_mod_outline_definition(
404436 sema : & Semantics < ' _ , RootDatabase > ,
405437 def : hir:: Module ,
406438) -> Option < Runnable > {
439+ def. as_source_file_id ( sema. db ) ?;
440+
407441 if !has_test_function_or_multiple_test_submodules ( sema, & def, has_cfg_test ( def. attrs ( sema. db ) ) )
408442 {
409443 return None ;
@@ -421,16 +455,22 @@ fn runnable_mod_outline_definition(
421455
422456 let attrs = def. attrs ( sema. db ) ;
423457 let cfg = attrs. cfg ( ) ;
424- if def. as_source_file_id ( sema. db ) . is_some ( ) {
425- Some ( Runnable {
426- use_name_in_title : false ,
427- nav : def. to_nav ( sema. db ) . call_site ( ) ,
428- kind : RunnableKind :: TestMod { path } ,
429- cfg,
430- } )
431- } else {
432- None
433- }
458+
459+ let file_range = {
460+ let src = def. definition_source ( sema. db ) ;
461+ let file_id = src. file_id . original_file ( sema. db ) ;
462+ let range = src. file_syntax ( sema. db ) . text_range ( ) ;
463+ hir:: FileRange { file_id, range }
464+ } ;
465+ let update_test = TestDefs :: new ( sema, def. krate ( ) , file_range) . update_test ( ) ;
466+
467+ Some ( Runnable {
468+ use_name_in_title : false ,
469+ nav : def. to_nav ( sema. db ) . call_site ( ) ,
470+ kind : RunnableKind :: TestMod { path } ,
471+ cfg,
472+ update_test,
473+ } )
434474}
435475
436476fn module_def_doctest ( db : & RootDatabase , def : Definition ) -> Option < Runnable > {
@@ -495,6 +535,7 @@ fn module_def_doctest(db: &RootDatabase, def: Definition) -> Option<Runnable> {
495535 nav,
496536 kind : RunnableKind :: DocTest { test_id } ,
497537 cfg : attrs. cfg ( ) ,
538+ update_test : UpdateTest :: default ( ) ,
498539 } ;
499540 Some ( res)
500541}
@@ -575,6 +616,106 @@ fn has_test_function_or_multiple_test_submodules(
575616 number_of_test_submodules > 1
576617}
577618
619+ struct TestDefs < ' a , ' b > ( & ' a Semantics < ' b , RootDatabase > , hir:: Crate , hir:: FileRange ) ;
620+
621+ #[ derive( Debug , Default , Clone , Copy , PartialEq , Eq , Hash ) ]
622+ pub struct UpdateTest {
623+ pub expect_test : bool ,
624+ pub insta : bool ,
625+ pub snapbox : bool ,
626+ }
627+
628+ impl UpdateTest {
629+ pub fn label ( & self ) -> Option < SmolStr > {
630+ let mut builder: SmallVec < [ _ ; 3 ] > = SmallVec :: new ( ) ;
631+ if self . expect_test {
632+ builder. push ( "Expect" ) ;
633+ }
634+ if self . insta {
635+ builder. push ( "Insta" ) ;
636+ }
637+ if self . snapbox {
638+ builder. push ( "Snapbox" ) ;
639+ }
640+
641+ let res: SmolStr = builder. join ( " + " ) . into ( ) ;
642+ res. is_empty ( ) . not ( ) . then_some ( res)
643+ }
644+ }
645+
646+ impl < ' a , ' b > TestDefs < ' a , ' b > {
647+ fn new (
648+ sema : & ' a Semantics < ' b , RootDatabase > ,
649+ current_krate : hir:: Crate ,
650+ file_range : hir:: FileRange ,
651+ ) -> Self {
652+ Self ( sema, current_krate, file_range)
653+ }
654+
655+ fn update_test ( & self ) -> UpdateTest {
656+ UpdateTest { expect_test : self . expect_test ( ) , insta : self . insta ( ) , snapbox : self . snapbox ( ) }
657+ }
658+
659+ fn expect_test ( & self ) -> bool {
660+ self . find_macro ( "expect_test:expect" ) || self . find_macro ( "expect_test::expect_file" )
661+ }
662+
663+ fn insta ( & self ) -> bool {
664+ self . find_macro ( "insta:assert_snapshot" )
665+ || self . find_macro ( "insta:assert_debug_snapshot" )
666+ || self . find_macro ( "insta:assert_display_snapshot" )
667+ || self . find_macro ( "insta:assert_json_snapshot" )
668+ || self . find_macro ( "insta:assert_yaml_snapshot" )
669+ || self . find_macro ( "insta:assert_ron_snapshot" )
670+ || self . find_macro ( "insta:assert_toml_snapshot" )
671+ || self . find_macro ( "insta:assert_csv_snapshot" )
672+ || self . find_macro ( "insta:assert_compact_json_snapshot" )
673+ || self . find_macro ( "insta:assert_compact_debug_snapshot" )
674+ || self . find_macro ( "insta:assert_binary_snapshot" )
675+ }
676+
677+ fn snapbox ( & self ) -> bool {
678+ self . find_macro ( "snapbox:assert_data_eq" )
679+ || self . find_macro ( "snapbox:file" )
680+ || self . find_macro ( "snapbox:str" )
681+ }
682+
683+ fn find_macro ( & self , path : & str ) -> bool {
684+ let Some ( hir:: ScopeDef :: ModuleDef ( hir:: ModuleDef :: Macro ( it) ) ) = self . find_def ( path) else {
685+ return false ;
686+ } ;
687+
688+ Definition :: Macro ( it)
689+ . usages ( self . 0 )
690+ . in_scope ( & SearchScope :: file_range ( self . 2 ) )
691+ . at_least_one ( )
692+ }
693+
694+ fn find_def ( & self , path : & str ) -> Option < hir:: ScopeDef > {
695+ let db = self . 0 . db ;
696+
697+ let mut path = path. split ( ':' ) ;
698+ let item = path. next_back ( ) ?;
699+ let krate = path. next ( ) ?;
700+ let dep = self . 1 . dependencies ( db) . into_iter ( ) . find ( |dep| dep. name . eq_ident ( krate) ) ?;
701+
702+ let mut module = dep. krate . root_module ( ) ;
703+ for segment in path {
704+ module = module. children ( db) . find_map ( |child| {
705+ let name = child. name ( db) ?;
706+ if name. eq_ident ( segment) {
707+ Some ( child)
708+ } else {
709+ None
710+ }
711+ } ) ?;
712+ }
713+
714+ let ( _, def) = module. scope ( db, None ) . into_iter ( ) . find ( |( name, _) | name. eq_ident ( item) ) ?;
715+ Some ( def)
716+ }
717+ }
718+
578719#[ cfg( test) ]
579720mod tests {
580721 use expect_test:: { expect, Expect } ;
@@ -1337,18 +1478,18 @@ mod tests {
13371478 file_id: FileId(
13381479 0,
13391480 ),
1340- full_range: 52..115 ,
1341- focus_range: 67..75 ,
1342- name: "foo_test ",
1481+ full_range: 121..185 ,
1482+ focus_range: 136..145 ,
1483+ name: "foo2_test ",
13431484 kind: Function,
13441485 },
13451486 NavigationTarget {
13461487 file_id: FileId(
13471488 0,
13481489 ),
1349- full_range: 121..185 ,
1350- focus_range: 136..145 ,
1351- name: "foo2_test ",
1490+ full_range: 52..115 ,
1491+ focus_range: 67..75 ,
1492+ name: "foo_test ",
13521493 kind: Function,
13531494 },
13541495 ]
0 commit comments