Skip to content

Commit

Permalink
Generate one fix per statement for flake8-type-checking rules (#4915)
Browse files Browse the repository at this point in the history
  • Loading branch information
charliermarsh authored Jun 8, 2023
1 parent 5235977 commit 4b78141
Show file tree
Hide file tree
Showing 11 changed files with 731 additions and 304 deletions.
25 changes: 11 additions & 14 deletions crates/ruff/src/checkers/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5069,22 +5069,19 @@ impl<'a> Checker<'a> {
.copied()
.collect()
};
for binding_id in scope.binding_ids() {
let binding = &self.semantic_model.bindings[binding_id];

flake8_type_checking::rules::runtime_import_in_type_checking_block(
self,
binding,
&mut diagnostics,
);
flake8_type_checking::rules::runtime_import_in_type_checking_block(
self,
scope,
&mut diagnostics,
);

flake8_type_checking::rules::typing_only_runtime_import(
self,
binding,
&runtime_imports,
&mut diagnostics,
);
}
flake8_type_checking::rules::typing_only_runtime_import(
self,
scope,
&runtime_imports,
&mut diagnostics,
);
}

if self.enabled(Rule::UnusedImport) {
Expand Down
16 changes: 8 additions & 8 deletions crates/ruff/src/importer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,12 @@ impl<'a> Importer<'a> {
/// import statement.
pub(crate) fn runtime_import_edit(
&self,
import: &StmtImport,
import: &StmtImports,
at: TextSize,
) -> Result<RuntimeImportEdit> {
// Generate the modified import statement.
let content = autofix::codemods::retain_imports(
&[import.qualified_name],
&import.qualified_names,
import.stmt,
self.locator,
self.stylist,
Expand All @@ -114,13 +114,13 @@ impl<'a> Importer<'a> {
/// `TYPE_CHECKING` block.
pub(crate) fn typing_import_edit(
&self,
import: &StmtImport,
import: &StmtImports,
at: TextSize,
semantic_model: &SemanticModel,
) -> Result<TypingImportEdit> {
// Generate the modified import statement.
let content = autofix::codemods::retain_imports(
&[import.qualified_name],
&import.qualified_names,
import.stmt,
self.locator,
self.stylist,
Expand Down Expand Up @@ -442,12 +442,12 @@ impl<'a> ImportRequest<'a> {
}
}

/// An existing module or member import, located within an import statement.
pub(crate) struct StmtImport<'a> {
/// An existing list of module or member imports, located within an import statement.
pub(crate) struct StmtImports<'a> {
/// The import statement.
pub(crate) stmt: &'a Stmt,
/// The "full name" of the imported module or member.
pub(crate) qualified_name: &'a str,
/// The "qualified names" of the imported modules or members.
pub(crate) qualified_names: Vec<&'a str>,
}

/// The result of an [`Importer::get_or_import_symbol`] call.
Expand Down
42 changes: 42 additions & 0 deletions crates/ruff/src/rules/flake8_type_checking/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,48 @@ mod tests {
"#,
"import_from_type_checking_block"
)]
#[test_case(
r#"
from __future__ import annotations
from typing import TYPE_CHECKING
from pandas import (
DataFrame, # DataFrame
Series, # Series
)
def f(x: DataFrame, y: Series):
pass
"#,
"multiple_members"
)]
#[test_case(
r#"
from __future__ import annotations
from typing import TYPE_CHECKING
import os, sys
def f(x: os, y: sys):
pass
"#,
"multiple_modules_same_type"
)]
#[test_case(
r#"
from __future__ import annotations
from typing import TYPE_CHECKING
import os, pandas
def f(x: os, y: pandas):
pass
"#,
"multiple_modules_different_types"
)]
fn contents(contents: &str, snapshot: &str) {
let diagnostics = test_snippet(
contents,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
use anyhow::Result;
use ruff_text_size::TextRange;
use rustc_hash::FxHashMap;

use ruff_diagnostics::{AutofixKind, Diagnostic, Fix, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_semantic::binding::Binding;
use ruff_python_semantic::node::NodeId;
use ruff_python_semantic::reference::ReferenceId;
use ruff_python_semantic::scope::Scope;

use crate::autofix;
use crate::checkers::ast::Checker;
use crate::importer::StmtImport;
use crate::registry::AsRule;
use crate::codes::Rule;
use crate::importer::StmtImports;

/// ## What it does
/// Checks for runtime imports defined in a type-checking block.
Expand Down Expand Up @@ -61,72 +67,172 @@ impl Violation for RuntimeImportInTypeCheckingBlock {
/// TCH004
pub(crate) fn runtime_import_in_type_checking_block(
checker: &Checker,
binding: &Binding,
scope: &Scope,
diagnostics: &mut Vec<Diagnostic>,
) {
let Some(qualified_name) = binding.qualified_name() else {
return;
};
// Collect all runtime imports by statement.
let mut errors_by_statement: FxHashMap<NodeId, Vec<Import>> = FxHashMap::default();
let mut ignores_by_statement: FxHashMap<NodeId, Vec<Import>> = FxHashMap::default();

let Some(reference_id) = binding.references.first() else {
return;
};
for binding_id in scope.binding_ids() {
let binding = &checker.semantic_model().bindings[binding_id];

if binding.context.is_typing()
&& binding.references().any(|reference_id| {
checker
.semantic_model()
.references
.resolve(reference_id)
.context()
.is_runtime()
})
let Some(qualified_name) = binding.qualified_name() else {
continue;
};

let Some(reference_id) = binding.references.first().copied() else {
continue;
};

if binding.context.is_typing()
&& binding.references().any(|reference_id| {
checker
.semantic_model()
.references
.resolve(reference_id)
.context()
.is_runtime()
})
{
let Some(stmt_id) = binding.source else {
continue;
};

let import = Import {
qualified_name,
reference_id,
trimmed_range: binding.trimmed_range(checker.semantic_model(), checker.locator),
parent_range: binding.parent_range(checker.semantic_model()),
};

if checker.rule_is_ignored(
Rule::RuntimeImportInTypeCheckingBlock,
import.trimmed_range.start(),
) || import.parent_range.map_or(false, |parent_range| {
checker
.rule_is_ignored(Rule::RuntimeImportInTypeCheckingBlock, parent_range.start())
}) {
ignores_by_statement
.entry(stmt_id)
.or_default()
.push(import);
} else {
errors_by_statement.entry(stmt_id).or_default().push(import);
}
}
}

// Generate a diagnostic for every import, but share a fix across all imports within the same
// statement (excluding those that are ignored).
for (stmt_id, imports) in errors_by_statement {
let fix = if checker.patch(Rule::RuntimeImportInTypeCheckingBlock) {
fix_imports(checker, stmt_id, &imports).ok()
} else {
None
};

for Import {
qualified_name,
trimmed_range,
parent_range,
..
} in imports
{
let mut diagnostic = Diagnostic::new(
RuntimeImportInTypeCheckingBlock {
qualified_name: qualified_name.to_string(),
},
trimmed_range,
);
if let Some(range) = parent_range {
diagnostic.set_parent(range.start());
}
if let Some(fix) = fix.as_ref() {
diagnostic.set_fix(fix.clone());
}
diagnostics.push(diagnostic);
}
}

// Separately, generate a diagnostic for every _ignored_ import, to ensure that the
// suppression comments aren't marked as unused.
for Import {
qualified_name,
trimmed_range,
parent_range,
..
} in ignores_by_statement.into_values().flatten()
{
let mut diagnostic = Diagnostic::new(
RuntimeImportInTypeCheckingBlock {
qualified_name: qualified_name.to_string(),
},
binding.trimmed_range(checker.semantic_model(), checker.locator),
trimmed_range,
);
if let Some(range) = binding.parent_range(checker.semantic_model()) {
if let Some(range) = parent_range {
diagnostic.set_parent(range.start());
}
diagnostics.push(diagnostic);
}
}

if checker.patch(diagnostic.kind.rule()) {
diagnostic.try_set_fix(|| {
// Step 1) Remove the import.
// SAFETY: All non-builtin bindings have a source.
let source = binding.source.unwrap();
let stmt = checker.semantic_model().stmts[source];
let parent = checker.semantic_model().stmts.parent(stmt);
let remove_import_edit = autofix::edits::remove_unused_imports(
std::iter::once(qualified_name),
stmt,
parent,
checker.locator,
checker.indexer,
checker.stylist,
)?;

// Step 2) Add the import to the top-level.
let reference = checker.semantic_model().references.resolve(*reference_id);
let add_import_edit = checker.importer.runtime_import_edit(
&StmtImport {
stmt,
qualified_name,
},
reference.range().start(),
)?;

Ok(
Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits())
.isolate(checker.isolation(parent)),
)
});
}
/// A runtime-required import with its surrounding context.
struct Import<'a> {
/// The qualified name of the import (e.g., `typing.List` for `from typing import List`).
qualified_name: &'a str,
/// The first reference to the imported symbol.
reference_id: ReferenceId,
/// The trimmed range of the import (e.g., `List` in `from typing import List`).
trimmed_range: TextRange,
/// The range of the import's parent statement.
parent_range: Option<TextRange>,
}

if checker.enabled(diagnostic.kind.rule()) {
diagnostics.push(diagnostic);
}
}
/// Generate a [`Fix`] to remove runtime imports from a type-checking block.
fn fix_imports(checker: &Checker, stmt_id: NodeId, imports: &[Import]) -> Result<Fix> {
let stmt = checker.semantic_model().stmts[stmt_id];
let parent = checker.semantic_model().stmts.parent(stmt);
let qualified_names: Vec<&str> = imports
.iter()
.map(|Import { qualified_name, .. }| *qualified_name)
.collect();

// Find the first reference across all imports.
let at = imports
.iter()
.map(|Import { reference_id, .. }| {
checker
.semantic_model()
.references
.resolve(*reference_id)
.range()
.start()
})
.min()
.expect("Expected at least one import");

// Step 1) Remove the import.
let remove_import_edit = autofix::edits::remove_unused_imports(
qualified_names.iter().copied(),
stmt,
parent,
checker.locator,
checker.indexer,
checker.stylist,
)?;

// Step 2) Add the import to the top-level.
let add_import_edit = checker.importer.runtime_import_edit(
&StmtImports {
stmt,
qualified_names,
},
at,
)?;

Ok(
Fix::suggested_edits(remove_import_edit, add_import_edit.into_edits())
.isolate(checker.isolation(parent)),
)
}
Loading

0 comments on commit 4b78141

Please sign in to comment.