1
- use crate :: FnCtxt ;
1
+ use std:: cell:: OnceCell ;
2
+
3
+ use crate :: { errors, FnCtxt } ;
2
4
use rustc_data_structures:: {
3
5
graph:: { self , iterate:: DepthFirstSearch , vec_graph:: VecGraph } ,
4
6
unord:: { UnordBag , UnordMap , UnordSet } ,
5
7
} ;
8
+ use rustc_hir:: HirId ;
6
9
use rustc_infer:: infer:: { DefineOpaqueTypes , InferOk } ;
7
- use rustc_middle:: ty:: { self , Ty } ;
10
+ use rustc_middle:: ty:: { self , Ty , TyCtxt , TypeVisitable } ;
11
+ use rustc_session:: lint;
12
+ use rustc_span:: Span ;
8
13
9
14
#[ derive( Copy , Clone ) ]
10
15
pub enum DivergingFallbackBehavior {
@@ -251,7 +256,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
251
256
252
257
// Construct a coercion graph where an edge `A -> B` indicates
253
258
// a type variable is that is coerced
254
- let coercion_graph = self . create_coercion_graph ( ) ;
259
+ let ( coercion_graph, coercion_graph2 ) = self . create_coercion_graph ( ) ;
255
260
256
261
// Extract the unsolved type inference variable vids; note that some
257
262
// unsolved variables are integer/float variables and are excluded.
@@ -338,6 +343,7 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
338
343
// reach a member of N. If so, it falls back to `()`. Else
339
344
// `!`.
340
345
let mut diverging_fallback = UnordMap :: with_capacity ( diverging_vids. len ( ) ) ;
346
+ let unsafe_infer_vars = OnceCell :: new ( ) ;
341
347
for & diverging_vid in & diverging_vids {
342
348
let diverging_ty = Ty :: new_var ( self . tcx , diverging_vid) ;
343
349
let root_vid = self . root_var ( diverging_vid) ;
@@ -357,11 +363,35 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
357
363
output : infer_var_infos. items ( ) . any ( |info| info. output ) ,
358
364
} ;
359
365
366
+ let mut fallback_to = |ty| {
367
+ let unsafe_infer_vars = unsafe_infer_vars. get_or_init ( || {
368
+ let unsafe_infer_vars = compute_unsafe_infer_vars ( self . root_ctxt , self . body_id ) ;
369
+ debug ! ( ?unsafe_infer_vars) ;
370
+ unsafe_infer_vars
371
+ } ) ;
372
+
373
+ let affected_unsafe_infer_vars =
374
+ graph:: depth_first_search ( & coercion_graph2, root_vid)
375
+ . filter_map ( |x| unsafe_infer_vars. get ( & x) . copied ( ) )
376
+ . collect :: < Vec < _ > > ( ) ;
377
+
378
+ for ( hir_id, span) in affected_unsafe_infer_vars {
379
+ self . tcx . emit_node_span_lint (
380
+ lint:: builtin:: NEVER_TYPE_FALLBACK_FLOWING_INTO_UNSAFE ,
381
+ hir_id,
382
+ span,
383
+ errors:: NeverTypeFallbackFlowingIntoUnsafe { } ,
384
+ ) ;
385
+ }
386
+
387
+ diverging_fallback. insert ( diverging_ty, ty) ;
388
+ } ;
389
+
360
390
use DivergingFallbackBehavior :: * ;
361
391
match behavior {
362
392
FallbackToUnit => {
363
393
debug ! ( "fallback to () - legacy: {:?}" , diverging_vid) ;
364
- diverging_fallback . insert ( diverging_ty , self . tcx . types . unit ) ;
394
+ fallback_to ( self . tcx . types . unit ) ;
365
395
}
366
396
FallbackToNiko => {
367
397
if found_infer_var_info. self_in_trait && found_infer_var_info. output {
@@ -390,21 +420,21 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
390
420
// set, see the relationship finding module in
391
421
// compiler/rustc_trait_selection/src/traits/relationships.rs.
392
422
debug ! ( "fallback to () - found trait and projection: {:?}" , diverging_vid) ;
393
- diverging_fallback . insert ( diverging_ty , self . tcx . types . unit ) ;
423
+ fallback_to ( self . tcx . types . unit ) ;
394
424
} else if can_reach_non_diverging {
395
425
debug ! ( "fallback to () - reached non-diverging: {:?}" , diverging_vid) ;
396
- diverging_fallback . insert ( diverging_ty , self . tcx . types . unit ) ;
426
+ fallback_to ( self . tcx . types . unit ) ;
397
427
} else {
398
428
debug ! ( "fallback to ! - all diverging: {:?}" , diverging_vid) ;
399
- diverging_fallback . insert ( diverging_ty , self . tcx . types . never ) ;
429
+ fallback_to ( self . tcx . types . never ) ;
400
430
}
401
431
}
402
432
FallbackToNever => {
403
433
debug ! (
404
434
"fallback to ! - `rustc_never_type_mode = \" fallback_to_never\" )`: {:?}" ,
405
435
diverging_vid
406
436
) ;
407
- diverging_fallback . insert ( diverging_ty , self . tcx . types . never ) ;
437
+ fallback_to ( self . tcx . types . never ) ;
408
438
}
409
439
NoFallback => {
410
440
debug ! (
@@ -420,7 +450,9 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
420
450
421
451
/// Returns a graph whose nodes are (unresolved) inference variables and where
422
452
/// an edge `?A -> ?B` indicates that the variable `?A` is coerced to `?B`.
423
- fn create_coercion_graph ( & self ) -> VecGraph < ty:: TyVid > {
453
+ ///
454
+ /// The second element of the return tuple is a graph with edges in both directions.
455
+ fn create_coercion_graph ( & self ) -> ( VecGraph < ty:: TyVid > , VecGraph < ty:: TyVid > ) {
424
456
let pending_obligations = self . fulfillment_cx . borrow_mut ( ) . pending_obligations ( ) ;
425
457
debug ! ( "create_coercion_graph: pending_obligations={:?}" , pending_obligations) ;
426
458
let coercion_edges: Vec < ( ty:: TyVid , ty:: TyVid ) > = pending_obligations
@@ -454,11 +486,113 @@ impl<'tcx> FnCtxt<'_, 'tcx> {
454
486
. collect ( ) ;
455
487
debug ! ( "create_coercion_graph: coercion_edges={:?}" , coercion_edges) ;
456
488
let num_ty_vars = self . num_ty_vars ( ) ;
457
- VecGraph :: new ( num_ty_vars, coercion_edges)
489
+
490
+ // This essentially creates a non-directed graph.
491
+ // Ideally we wouldn't do it like this, but it works ig :\
492
+ let doubly_connected = VecGraph :: new (
493
+ num_ty_vars,
494
+ coercion_edges
495
+ . iter ( )
496
+ . copied ( )
497
+ . chain ( coercion_edges. iter ( ) . copied ( ) . map ( |( a, b) | ( b, a) ) )
498
+ . collect ( ) ,
499
+ ) ;
500
+
501
+ let normal = VecGraph :: new ( num_ty_vars, coercion_edges. clone ( ) ) ;
502
+
503
+ ( normal, doubly_connected)
458
504
}
459
505
460
506
/// If `ty` is an unresolved type variable, returns its root vid.
461
507
fn root_vid ( & self , ty : Ty < ' tcx > ) -> Option < ty:: TyVid > {
462
508
Some ( self . root_var ( self . shallow_resolve ( ty) . ty_vid ( ) ?) )
463
509
}
464
510
}
511
+
512
+ /// Finds all type variables which are passed to an `unsafe` function.
513
+ ///
514
+ /// For example, for this function `f`:
515
+ /// ```ignore (demonstrative)
516
+ /// fn f() {
517
+ /// unsafe {
518
+ /// let x /* ?X */ = core::mem::zeroed();
519
+ /// // ^^^^^^^^^^^^^^^^^^^ -- hir_id, span
520
+ ///
521
+ /// let y = core::mem::zeroed::<Option<_ /* ?Y */>>();
522
+ /// // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -- hir_id, span
523
+ /// }
524
+ /// }
525
+ /// ```
526
+ ///
527
+ /// Will return `{ id(?X) -> (hir_id, span) }`
528
+ fn compute_unsafe_infer_vars < ' a , ' tcx > (
529
+ root_ctxt : & ' a crate :: TypeckRootCtxt < ' tcx > ,
530
+ body_id : rustc_span:: def_id:: LocalDefId ,
531
+ ) -> UnordMap < ty:: TyVid , ( HirId , Span ) > {
532
+ use rustc_hir as hir;
533
+
534
+ let tcx = root_ctxt. infcx . tcx ;
535
+ let body_id = tcx. hir ( ) . maybe_body_owned_by ( body_id) . unwrap ( ) ;
536
+ let body = tcx. hir ( ) . body ( body_id) ;
537
+ let mut res = <_ >:: default ( ) ;
538
+
539
+ struct UnsafeInferVarsVisitor < ' a , ' tcx , ' r > {
540
+ root_ctxt : & ' a crate :: TypeckRootCtxt < ' tcx > ,
541
+ res : & ' r mut UnordMap < ty:: TyVid , ( HirId , Span ) > ,
542
+ }
543
+
544
+ use hir:: intravisit:: Visitor ;
545
+ impl hir:: intravisit:: Visitor < ' _ > for UnsafeInferVarsVisitor < ' _ , ' _ , ' _ > {
546
+ fn visit_expr ( & mut self , ex : & ' _ hir:: Expr < ' _ > ) {
547
+ // FIXME: method calls
548
+ if let hir:: ExprKind :: Call ( func, ..) = ex. kind {
549
+ let typeck_results = self . root_ctxt . typeck_results . borrow ( ) ;
550
+
551
+ let func_ty = typeck_results. expr_ty ( func) ;
552
+
553
+ // `is_fn` is required to ignore closures (which can't be unsafe)
554
+ if func_ty. is_fn ( )
555
+ && let sig = func_ty. fn_sig ( self . root_ctxt . infcx . tcx )
556
+ && let hir:: Unsafety :: Unsafe = sig. unsafety ( )
557
+ {
558
+ let mut collector =
559
+ InferVarCollector { hir_id : ex. hir_id , call_span : ex. span , res : self . res } ;
560
+
561
+ // Collect generic arguments of the function which are inference variables
562
+ typeck_results
563
+ . node_args ( ex. hir_id )
564
+ . types ( )
565
+ . for_each ( |t| t. visit_with ( & mut collector) ) ;
566
+
567
+ // Also check the return type, for cases like `(unsafe_fn::<_> as unsafe fn() -> _)()`
568
+ sig. output ( ) . visit_with ( & mut collector) ;
569
+ }
570
+ }
571
+
572
+ hir:: intravisit:: walk_expr ( self , ex) ;
573
+ }
574
+ }
575
+
576
+ struct InferVarCollector < ' r > {
577
+ hir_id : HirId ,
578
+ call_span : Span ,
579
+ res : & ' r mut UnordMap < ty:: TyVid , ( HirId , Span ) > ,
580
+ }
581
+
582
+ impl < ' tcx > ty:: TypeVisitor < TyCtxt < ' tcx > > for InferVarCollector < ' _ > {
583
+ fn visit_ty ( & mut self , t : Ty < ' tcx > ) {
584
+ if let Some ( vid) = t. ty_vid ( ) {
585
+ self . res . insert ( vid, ( self . hir_id , self . call_span ) ) ;
586
+ } else {
587
+ use ty:: TypeSuperVisitable as _;
588
+ t. super_visit_with ( self )
589
+ }
590
+ }
591
+ }
592
+
593
+ UnsafeInferVarsVisitor { root_ctxt, res : & mut res } . visit_expr ( & body. value ) ;
594
+
595
+ debug ! ( ?res, "collected the following unsafe vars for {body_id:?}" ) ;
596
+
597
+ res
598
+ }
0 commit comments