Skip to content

Commit

Permalink
Auto merge of #18016 - IvarWithoutBones:wrap-return-ty-local-result, …
Browse files Browse the repository at this point in the history
…r=Veykril

fix: use Result type aliases in "Wrap return type in Result" assist

This commit makes the "Wrap return type in Result" assist prefer type aliases of standard library type when the are in scope, use at least one generic parameter, and have the name `Result`.

The last restriction was made in an attempt to avoid false assumptions about which type the user is referring to, but that might be overly strict. We could also do something like this, in order of priority:
* Use the alias named "Result".
* Use any alias if only a single one is in scope, otherwise:
* Use the standard library type.

This is easy to add if others feel differently that is appropriate, just let me know.

Fixes #17796
  • Loading branch information
bors committed Sep 2, 2024
2 parents 6341e88 + 4e9d17b commit d534cc6
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 16 deletions.
17 changes: 14 additions & 3 deletions crates/hir/src/semantics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use hir_def::{
hir::Expr,
lower::LowerCtx,
nameres::MacroSubNs,
path::ModPath,
resolver::{self, HasResolver, Resolver, TypeNs},
type_ref::Mutability,
AsMacroCall, DefWithBodyId, FunctionId, MacroId, TraitId, VariantId,
Expand Down Expand Up @@ -46,9 +47,9 @@ use crate::{
source_analyzer::{resolve_hir_path, SourceAnalyzer},
Access, Adjust, Adjustment, Adt, AutoBorrow, BindingMode, BuiltinAttr, Callable, Const,
ConstParam, Crate, DeriveHelper, Enum, Field, Function, HasSource, HirFileId, Impl, InFile,
Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path, ScopeDef,
Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias, TypeParam, Union,
Variant, VariantDef,
ItemInNs, Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path,
ScopeDef, Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias,
TypeParam, Union, Variant, VariantDef,
};

const CONTINUE_NO_BREAKS: ControlFlow<Infallible, ()> = ControlFlow::Continue(());
Expand Down Expand Up @@ -1384,6 +1385,16 @@ impl<'db> SemanticsImpl<'db> {
self.analyze(path.syntax())?.resolve_path(self.db, path)
}

pub fn resolve_mod_path(
&self,
scope: &SyntaxNode,
path: &ModPath,
) -> Option<impl Iterator<Item = ItemInNs>> {
let analyze = self.analyze(scope)?;
let items = analyze.resolver.resolve_module_path_in_items(self.db.upcast(), path);
Some(items.iter_items().map(|(item, _)| item.into()))
}

fn resolve_variant(&self, record_lit: ast::RecordExpr) -> Option<VariantId> {
self.analyze(record_lit.syntax())?.resolve_variant(self.db, record_lit)
}
Expand Down
293 changes: 280 additions & 13 deletions crates/ide-assists/src/handlers/wrap_return_type_in_result.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::iter;

use hir::HasSource;
use ide_db::{
famous_defs::FamousDefs,
syntax_helpers::node_ext::{for_each_tail_expr, walk_expr},
};
use itertools::Itertools;
use syntax::{
ast::{self, make, Expr},
match_ast, ted, AstNode,
ast::{self, make, Expr, HasGenericParams},
match_ast, ted, AstNode, ToSmolStr,
};

use crate::{AssistContext, AssistId, AssistKind, Assists};
Expand Down Expand Up @@ -39,25 +41,22 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
};

let type_ref = &ret_type.ty()?;
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
let result_enum =
let core_result =
FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;

if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_result) {
// The return type is already wrapped in a Result
cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result);
return None;
}

let new_result_ty =
make::ext::ty_result(type_ref.clone(), make::ty_placeholder()).clone_for_update();
let generic_args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast)?;
let last_genarg = generic_args.generic_args().last()?;

acc.add(
AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite),
"Wrap return type in Result",
type_ref.syntax().text_range(),
|edit| {
let new_result_ty = result_type(ctx, &core_result, type_ref).clone_for_update();
let body = edit.make_mut(ast::Expr::BlockExpr(body));

let mut exprs_to_wrap = Vec::new();
Expand All @@ -81,16 +80,72 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext<
}

let old_result_ty = edit.make_mut(type_ref.clone());

ted::replace(old_result_ty.syntax(), new_result_ty.syntax());

if let Some(cap) = ctx.config.snippet_cap {
edit.add_placeholder_snippet(cap, last_genarg);
// Add a placeholder snippet at the first generic argument that doesn't equal the return type.
// This is normally the error type, but that may not be the case when we inserted a type alias.
let args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast);
let error_type_arg = args.and_then(|list| {
list.generic_args().find(|arg| match arg {
ast::GenericArg::TypeArg(_) => arg.syntax().text() != type_ref.syntax().text(),
ast::GenericArg::LifetimeArg(_) => false,
_ => true,
})
});
if let Some(error_type_arg) = error_type_arg {
if let Some(cap) = ctx.config.snippet_cap {
edit.add_placeholder_snippet(cap, error_type_arg);
}
}
},
)
}

fn result_type(
ctx: &AssistContext<'_>,
core_result: &hir::Enum,
ret_type: &ast::Type,
) -> ast::Type {
// Try to find a Result<T, ...> type alias in the current scope (shadowing the default).
let result_path = hir::ModPath::from_segments(
hir::PathKind::Plain,
iter::once(hir::Name::new_symbol_root(hir::sym::Result.clone())),
);
let alias = ctx.sema.resolve_mod_path(ret_type.syntax(), &result_path).and_then(|def| {
def.filter_map(|def| match def.as_module_def()? {
hir::ModuleDef::TypeAlias(alias) => {
let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?;
(&enum_ty == core_result).then_some(alias)
}
_ => None,
})
.find_map(|alias| {
let mut inserted_ret_type = false;
let generic_params = alias
.source(ctx.db())?
.value
.generic_param_list()?
.generic_params()
.map(|param| match param {
// Replace the very first type parameter with the functions return type.
ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
inserted_ret_type = true;
ret_type.to_smolstr()
}
ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(),
_ => make::ty_placeholder().to_smolstr(),
})
.join(", ");

let name = alias.name(ctx.db());
let name = name.as_str();
Some(make::ty(&format!("{name}<{generic_params}>")))
})
});
// If there is no applicable alias in scope use the default Result type.
alias.unwrap_or_else(|| make::ext::ty_result(ret_type.clone(), make::ty_placeholder()))
}

fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
match e {
Expr::BreakExpr(break_expr) => {
Expand Down Expand Up @@ -998,4 +1053,216 @@ fn foo(the_field: u32) -> Result<u32, ${0:_}> {
"#,
);
}

#[test]
fn wrap_return_type_in_local_result_type() {
check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
type Result<T> = core::result::Result<T, ()>;
fn foo() -> i3$02 {
return 42i32;
}
"#,
r#"
type Result<T> = core::result::Result<T, ()>;
fn foo() -> Result<i32> {
return Ok(42i32);
}
"#,
);

check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
type Result2<T> = core::result::Result<T, ()>;
fn foo() -> i3$02 {
return 42i32;
}
"#,
r#"
type Result2<T> = core::result::Result<T, ()>;
fn foo() -> Result<i32, ${0:_}> {
return Ok(42i32);
}
"#,
);
}

#[test]
fn wrap_return_type_in_imported_local_result_type() {
check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
mod some_module {
pub type Result<T> = core::result::Result<T, ()>;
}
use some_module::Result;
fn foo() -> i3$02 {
return 42i32;
}
"#,
r#"
mod some_module {
pub type Result<T> = core::result::Result<T, ()>;
}
use some_module::Result;
fn foo() -> Result<i32> {
return Ok(42i32);
}
"#,
);

check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
mod some_module {
pub type Result<T> = core::result::Result<T, ()>;
}
use some_module::*;
fn foo() -> i3$02 {
return 42i32;
}
"#,
r#"
mod some_module {
pub type Result<T> = core::result::Result<T, ()>;
}
use some_module::*;
fn foo() -> Result<i32> {
return Ok(42i32);
}
"#,
);
}

#[test]
fn wrap_return_type_in_local_result_type_from_function_body() {
check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
fn foo() -> i3$02 {
type Result<T> = core::result::Result<T, ()>;
0
}
"#,
r#"
fn foo() -> Result<i32, ${0:_}> {
type Result<T> = core::result::Result<T, ()>;
Ok(0)
}
"#,
);
}

#[test]
fn wrap_return_type_in_local_result_type_already_using_alias() {
check_assist_not_applicable(
wrap_return_type_in_result,
r#"
//- minicore: result
pub type Result<T> = core::result::Result<T, ()>;
fn foo() -> Result<i3$02> {
return Ok(42i32);
}
"#,
);
}

#[test]
fn wrap_return_type_in_local_result_type_multiple_generics() {
check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
type Result<T, E> = core::result::Result<T, E>;
fn foo() -> i3$02 {
0
}
"#,
r#"
type Result<T, E> = core::result::Result<T, E>;
fn foo() -> Result<i32, ${0:_}> {
Ok(0)
}
"#,
);

check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
fn foo() -> i3$02 {
0
}
"#,
r#"
type Result<T, E> = core::result::Result<Foo<T, E>, ()>;
fn foo() -> Result<i32, ${0:_}> {
Ok(0)
}
"#,
);

check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
fn foo() -> i3$02 {
0
}
"#,
r#"
type Result<'a, T, E> = core::result::Result<Foo<T, E>, &'a ()>;
fn foo() -> Result<'_, i32, ${0:_}> {
Ok(0)
}
"#,
);

check_assist(
wrap_return_type_in_result,
r#"
//- minicore: result
type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
fn foo() -> i3$02 {
0
}
"#,
r#"
type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
fn foo() -> Result<i32, ${0:_}> {
Ok(0)
}
"#,
);
}
}

0 comments on commit d534cc6

Please sign in to comment.