diff --git a/crates/ty_ide/src/hover.rs b/crates/ty_ide/src/hover.rs index 1c2db06ca3e86..8b58f840be22d 100644 --- a/crates/ty_ide/src/hover.rs +++ b/crates/ty_ide/src/hover.rs @@ -303,7 +303,7 @@ mod tests { "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" def my_func( a, b @@ -358,7 +358,7 @@ mod tests { "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" def my_func( a, b @@ -427,7 +427,7 @@ mod tests { "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" --------------------------------------------- This is such a great class!! @@ -489,7 +489,7 @@ mod tests { "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" --------------------------------------------- This is such a great class!! @@ -664,7 +664,7 @@ mod tests { "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" --------------------------------------------- This is such a great class!! @@ -729,7 +729,7 @@ mod tests { "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" bound method MyClass.my_method( a, b @@ -2045,7 +2045,7 @@ def ab(a: int, *, c: int): ) .unwrap(); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" --------------------------------------------- The cool lib_py module! @@ -2599,7 +2599,7 @@ def function(): ) .unwrap(); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" --------------------------------------------- The cool lib_py module! @@ -2999,7 +2999,7 @@ def function(): "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" int --------------------------------------------- This is the docs for this value @@ -3088,7 +3088,7 @@ def function(): "#, ); - assert_snapshot!(test.hover(), @" + assert_snapshot!(test.hover(), @r" int --------------------------------------------- This is the docs for this value @@ -4588,6 +4588,115 @@ def function(): "); } + #[test] + fn hover_multi_inference() { + let test = cursor_test( + r#" + def list1[T](x: T) -> list[T]: + return [x] + + def f(x: int, y: int) -> list[int] | list[str]: + return list1(x + y) + "#, + ); + + assert_snapshot!(test.hover(), @r" + int + --------------------------------------------- + ```python + int + ``` + --------------------------------------------- + info[hover]: Hovered content is + --> main.py:6:18 + | + 5 | def f(x: int, y: int) -> list[int] | list[str]: + 6 | return list1(x + y) + | ^- Cursor offset + | | + | source + | + "); + + let test = cursor_test( + r#" + def f(x: int, y: int) -> list[int] | list[str]: + return [x + y] + "#, + ); + + assert_snapshot!(test.hover(), @r" + int + --------------------------------------------- + ```python + int + ``` + --------------------------------------------- + info[hover]: Hovered content is + --> main.py:3:13 + | + 2 | def f(x: int, y: int) -> list[int] | list[str]: + 3 | return [x + y] + | ^- Cursor offset + | | + | source + | + "); + + let test = cursor_test( + r#" + def list1[T](x: T) -> list[T]: + return [x] + + def f(x: int, y: int) -> list[int] | list[str]: + return (_ := list1(x + y)) + "#, + ); + + assert_snapshot!(test.hover(), @r" + list[int] + --------------------------------------------- + ```python + list[int] + ``` + --------------------------------------------- + info[hover]: Hovered content is + --> main.py:6:13 + | + 5 | def f(x: int, y: int) -> list[int] | list[str]: + 6 | return (_ := list1(x + y)) + | ^- Cursor offset + | | + | source + | + "); + + let test = cursor_test( + r#" + def f(x: int, y: int) -> list[int] | list[str]: + return (_ := [x + y]) + "#, + ); + + assert_snapshot!(test.hover(), @r" + list[int] + --------------------------------------------- + ```python + list[int] + ``` + --------------------------------------------- + info[hover]: Hovered content is + --> main.py:3:13 + | + 2 | def f(x: int, y: int) -> list[int] | list[str]: + 3 | return (_ := [x + y]) + | ^- Cursor offset + | | + | source + | + "); + } + #[test] fn hover_submodule_import_from_use() { let test = CursorTest::builder() diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index a5c8fed0aae7b..bb717b4eb09ba 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -4977,27 +4977,22 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; let mut try_narrow = |narrowed_ty| { - let mut speculated_bindings = bindings.clone(); let narrowed_tcx = TypeContext::new(Some(narrowed_ty)); - // We silence diagnostics until we successfully narrow to a specific type. - let was_in_multi_inference = self.context.set_multi_inference(true); + let mut speculative_bindings = bindings.clone(); + let mut speculative_builder = self.speculate(); // Attempt to infer the argument types using the narrowed type context. - self.infer_all_argument_types( + speculative_builder.infer_all_argument_types( ast_arguments.clone(), argument_types, infer_argument_ty, bindings, narrowed_tcx, - MultiInferenceState::Ignore, ); - // Restore the multi-inference state. - self.context.set_multi_inference(was_in_multi_inference); - // Ensure the argument types match their annotated types. - if speculated_bindings + if speculative_bindings .check_types_impl( db, &constraints, @@ -5007,35 +5002,25 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) .is_err() { + speculative_builder.discard(); return None; } - // Ensure the inferred return type is assignable to the (narrowed) declared type. + // Ensure the inferred return type is assignable to the narrowed declared type. // // TODO: Checking assignability against the full declared type could help avoid // cases where the constraint solver is not smart enough to solve complex unions. // We should see revisit this after the new constraint solver is implemented. - if !speculated_bindings + if !speculative_bindings .return_type(db) .is_assignable_to(db, narrowed_ty) { + speculative_builder.discard(); return None; } // Successfully narrowed to an element of the union. - // - // If necessary, infer the argument types again with diagnostics enabled. - if !was_in_multi_inference { - self.infer_all_argument_types( - ast_arguments.clone(), - argument_types, - infer_argument_ty, - bindings, - narrowed_tcx, - MultiInferenceState::Intersect, - ); - } - + self.extend(speculative_builder); Some(bindings.check_types_impl( db, &constraints, @@ -5075,7 +5060,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { infer_argument_ty, bindings, call_expression_tcx, - MultiInferenceState::Intersect, ); bindings.check_types_impl( @@ -5099,7 +5083,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { infer_argument_ty: &mut dyn FnMut(&mut Self, ArgExpr<'db, '_>) -> Type<'db>, bindings: &Bindings<'db>, call_expression_tcx: TypeContext<'db>, - multi_inference_state: MultiInferenceState, ) { debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len()); @@ -5135,7 +5118,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .flatten() .collect::>(); - let old_multi_inference_state = self.set_multi_inference_state(multi_inference_state); + // Each type is a valid independent inference of the given argument, and we may require + // different permutations of argument types to correctly perform argument expansion during + // overload evaluation, so we take the intersection of all the types we inferred for each + // argument. + let old_multi_inference_state = + self.set_multi_inference_state(MultiInferenceState::Intersect); for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { let ast_argument = match ast_argument { @@ -5258,11 +5246,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { continue; } - // Each type is a valid independent inference of the given argument, and we may require different - // permutations of argument types to correctly perform argument expansion during overload evaluation, - // so we take the intersection of all the types we inferred for each argument. - // - // TODO: intersecting the inferred argument types is correct for unions of + // TODO: Intersecting the inferred argument types is correct for unions of // callables, since the argument must satisfy each callable, but it's not clear // that it's correct for an intersection of callables, or for a case where // different overloads provide different type context; unioning may be more @@ -5968,41 +5952,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; let mut try_narrow = |narrowed_ty| { - let narrowed_tcx = TypeContext::new(Some(narrowed_ty)); - - // We silence diagnostics until we successfully narrow to a specific type. - let prev_multi_inference = self.set_multi_inference_state(MultiInferenceState::Ignore); - let was_in_multi_inference = self.context.set_multi_inference(true); + let mut speculative_builder = self.speculate(); // Attempt to infer the collection literal using the narrowed type context. - let inferred_ty = self.infer_collection_literal_impl( + let inferred_ty = speculative_builder.infer_collection_literal_impl( collection_class, elts, infer_elt_expression, - narrowed_tcx, + TypeContext::new(Some(narrowed_ty)), )?; - // Restore the multi-inference state. - self.context.set_multi_inference(was_in_multi_inference); - self.set_multi_inference_state(prev_multi_inference); - - // Ensure the inferred return type is assignable to the (narrowed) declared type. + // Ensure the inferred return type is assignable to the narrowed declared type. if !inferred_ty.is_assignable_to(db, narrowed_ty) { + speculative_builder.discard(); return None; } // Successfully narrowed to an element of the union. - // - // If necessary, infer the collection literal again with diagnostics enabled. - if !was_in_multi_inference { - self.infer_collection_literal_impl( - collection_class, - elts, - infer_elt_expression, - narrowed_tcx, - ); - } - + self.extend(speculative_builder); Some(inferred_ty) }; @@ -8854,6 +8821,79 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ScopeInference { expressions, extra } } + + /// Returns a fresh [`TypeInferenceBuilder`] for the current scope that can be used + /// to speculatively infer expressions during multi-inference. + /// + /// The inference results can be merged into the current inference region using + /// [`TypeInferenceBuilder::extend`], or ignored using [`TypeInferenceBuilder::discard`]. + fn speculate(&mut self) -> Self { + TypeInferenceBuilder::new(self.db(), self.region, self.index, self.module()) + } + + /// Extend the current region with the results of a speculative [`TypeInferenceBuilder`]. + fn extend(&mut self, other: Self) { + let Self { + context, + expressions, + string_annotations, + scope, + bindings, + declarations, + deferred, + cycle_recovery, + dataclass_field_specifiers: _, + + // Ignored; only relevant to definition regions + undecorated_type: _, + + // builder only state + all_definitely_bound: _, + typevar_binding_context: _, + inference_flags: _, + deferred_state: _, + multi_inference_state: _, + inner_expression_inference_state: _, + inferring_vararg_annotation: _, + called_functions: _, + index: _, + region: _, + return_types_and_ranges: _, + } = other; + + let diagnostics = context.finish(); + let _ = scope; + + assert!( + declarations.is_empty(), + "speculative `TypeInferenceBuilder` should only be used for expression inference" + ); + assert!( + deferred.is_empty(), + "speculative `TypeInferenceBuilder` should only be used for expression inference" + ); + + self.expressions.extend(expressions.iter()); + self.context.extend(&diagnostics); + self.extend_cycle_recovery(cycle_recovery); + self.string_annotations + .extend(string_annotations.iter().copied()); + + if !matches!(self.region, InferenceRegion::Scope(..)) { + self.bindings.extend( + bindings.iter().map(|(def, ty)| (*def, *ty)), + self.multi_inference_state, + ); + } + } + + /// Ignore the results of this [`TypeInferenceBuilder`]. + /// + /// Note that dropping a [`TypeInferenceBuilder`] without calling this method will result + /// in a panic. + fn discard(self) { + let _ = self.context.finish(); + } } /// Manages the inference of a given expression.