diff --git a/crates/ruff/src/checkers/ast/analyze/statement.rs b/crates/ruff/src/checkers/ast/analyze/statement.rs index 81a36c87d6c76..0571dec1f6d76 100644 --- a/crates/ruff/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff/src/checkers/ast/analyze/statement.rs @@ -560,18 +560,26 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { } } if checker.enabled(Rule::BannedApi) { - flake8_tidy_imports::rules::name_or_parent_is_banned( + flake8_tidy_imports::rules::banned_api( checker, - &alias.name, - alias, + &flake8_tidy_imports::matchers::NameMatchPolicy::MatchNameOrParent( + flake8_tidy_imports::matchers::MatchNameOrParent { + module: &alias.name, + }, + ), + &alias, ); } if checker.enabled(Rule::BannedModuleLevelImports) { - flake8_tidy_imports::rules::name_or_parent_is_banned_at_module_level( + flake8_tidy_imports::rules::banned_module_level_imports( checker, - &alias.name, - alias.range(), + &flake8_tidy_imports::matchers::NameMatchPolicy::MatchNameOrParent( + flake8_tidy_imports::matchers::MatchNameOrParent { + module: &alias.name, + }, + ), + &alias, ); } @@ -729,16 +737,27 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if let Some(module) = helpers::resolve_imported_module_path(level, module, checker.module_path) { - flake8_tidy_imports::rules::name_or_parent_is_banned(checker, &module, stmt); + flake8_tidy_imports::rules::banned_api( + checker, + &flake8_tidy_imports::matchers::NameMatchPolicy::MatchNameOrParent( + flake8_tidy_imports::matchers::MatchNameOrParent { module: &module }, + ), + &stmt, + ); for alias in names { if &alias.name == "*" { continue; } - flake8_tidy_imports::rules::name_is_banned( + flake8_tidy_imports::rules::banned_api( checker, - format!("{module}.{}", alias.name), - alias, + &flake8_tidy_imports::matchers::NameMatchPolicy::MatchName( + flake8_tidy_imports::matchers::MatchName { + module: &module, + member: &alias.name, + }, + ), + &alias, ); } } @@ -747,20 +766,27 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { if let Some(module) = helpers::resolve_imported_module_path(level, module, checker.module_path) { - flake8_tidy_imports::rules::name_or_parent_is_banned_at_module_level( + flake8_tidy_imports::rules::banned_module_level_imports( checker, - &module, - stmt.range(), + &flake8_tidy_imports::matchers::NameMatchPolicy::MatchNameOrParent( + flake8_tidy_imports::matchers::MatchNameOrParent { module: &module }, + ), + &stmt, ); for alias in names { if &alias.name == "*" { continue; } - flake8_tidy_imports::rules::name_is_banned_at_module_level( + flake8_tidy_imports::rules::banned_module_level_imports( checker, - &format!("{module}.{}", alias.name), - alias.range(), + &flake8_tidy_imports::matchers::NameMatchPolicy::MatchName( + flake8_tidy_imports::matchers::MatchName { + module: &module, + member: &alias.name, + }, + ), + &alias, ); } } diff --git a/crates/ruff/src/rules/flake8_tidy_imports/matchers.rs b/crates/ruff/src/rules/flake8_tidy_imports/matchers.rs new file mode 100644 index 0000000000000..a991280003299 --- /dev/null +++ b/crates/ruff/src/rules/flake8_tidy_imports/matchers.rs @@ -0,0 +1,75 @@ +/// Match an imported member against the ban policy. For example, given `from foo import bar`, +/// `foo` is the module and `bar` is the member. Performs an exact match. +#[derive(Debug)] +pub(crate) struct MatchName<'a> { + pub(crate) module: &'a str, + pub(crate) member: &'a str, +} + +impl MatchName<'_> { + fn is_match(&self, banned_module: &str) -> bool { + // Ex) Match banned `foo.bar` to import `foo.bar`, without allocating, assuming that + // `module` is `foo`, `member` is `bar`, and `banned_module` is `foo.bar`. + banned_module + .strip_prefix(self.module) + .and_then(|banned_module| banned_module.strip_prefix('.')) + .and_then(|banned_module| banned_module.strip_prefix(self.member)) + .is_some_and(str::is_empty) + } +} + +/// Match an imported module against the ban policy. For example, given `import foo.bar`, +/// `foo.bar` is the module. Matches against the module name or any of its parents. +#[derive(Debug)] +pub(crate) struct MatchNameOrParent<'a> { + pub(crate) module: &'a str, +} + +impl MatchNameOrParent<'_> { + fn is_match(&self, banned_module: &str) -> bool { + // Ex) Match banned `foo` to import `foo`. + if self.module == banned_module { + return true; + } + + // Ex) Match banned `foo` to import `foo.bar`. + if self + .module + .strip_prefix(banned_module) + .is_some_and(|suffix| suffix.starts_with('.')) + { + return true; + } + + false + } +} + +#[derive(Debug)] +pub(crate) enum NameMatchPolicy<'a> { + /// Only match an exact module name (e.g., given `import foo.bar`, only match `foo.bar`). + MatchName(MatchName<'a>), + /// Match an exact module name or any of its parents (e.g., given `import foo.bar`, match + /// `foo.bar` or `foo`). + MatchNameOrParent(MatchNameOrParent<'a>), +} + +impl NameMatchPolicy<'_> { + pub(crate) fn find<'a>(&self, banned_modules: impl Iterator) -> Option { + for banned_module in banned_modules { + match self { + NameMatchPolicy::MatchName(matcher) => { + if matcher.is_match(banned_module) { + return Some(banned_module.to_string()); + } + } + NameMatchPolicy::MatchNameOrParent(matcher) => { + if matcher.is_match(banned_module) { + return Some(banned_module.to_string()); + } + } + } + } + None + } +} diff --git a/crates/ruff/src/rules/flake8_tidy_imports/mod.rs b/crates/ruff/src/rules/flake8_tidy_imports/mod.rs index d136abbb1dd2e..302640f4f8df1 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/mod.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/mod.rs @@ -1,4 +1,5 @@ //! Rules from [flake8-tidy-imports](https://pypi.org/project/flake8-tidy-imports/). +pub(crate) mod matchers; pub mod options; pub(crate) mod rules; pub mod settings; diff --git a/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_api.rs b/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_api.rs index e214802ded4f4..a2e8e8ad255b0 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_api.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_api.rs @@ -5,6 +5,7 @@ use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::call_path::from_qualified_name; use crate::checkers::ast::Checker; +use crate::rules::flake8_tidy_imports::matchers::NameMatchPolicy; /// ## What it does /// Checks for banned imports. @@ -38,45 +39,17 @@ impl Violation for BannedApi { } /// TID251 -pub(crate) fn name_is_banned(checker: &mut Checker, name: String, located: &T) -where - T: Ranged, -{ +pub(crate) fn banned_api(checker: &mut Checker, policy: &NameMatchPolicy, node: &T) { let banned_api = &checker.settings.flake8_tidy_imports.banned_api; - if let Some(ban) = banned_api.get(&name) { - checker.diagnostics.push(Diagnostic::new( - BannedApi { - name, - message: ban.msg.to_string(), - }, - located.range(), - )); - } -} - -/// TID251 -pub(crate) fn name_or_parent_is_banned(checker: &mut Checker, name: &str, located: &T) -where - T: Ranged, -{ - let banned_api = &checker.settings.flake8_tidy_imports.banned_api; - let mut name = name; - loop { - if let Some(ban) = banned_api.get(name) { + if let Some(banned_module) = policy.find(banned_api.keys().map(AsRef::as_ref)) { + if let Some(reason) = banned_api.get(&banned_module) { checker.diagnostics.push(Diagnostic::new( BannedApi { - name: name.to_string(), - message: ban.msg.to_string(), + name: banned_module, + message: reason.msg.to_string(), }, - located.range(), + node.range(), )); - return; - } - match name.rfind('.') { - Some(idx) => { - name = &name[..idx]; - } - None => return, } } } diff --git a/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs b/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs index 0efa1a3164231..283b8f90a4abe 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs @@ -1,8 +1,9 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_text_size::TextRange; +use ruff_python_ast::Ranged; use crate::checkers::ast::Checker; +use crate::rules::flake8_tidy_imports::matchers::NameMatchPolicy; /// ## What it does /// Checks for module-level imports that should instead be imported lazily @@ -52,60 +53,28 @@ impl Violation for BannedModuleLevelImports { } /// TID253 -pub(crate) fn name_is_banned_at_module_level( +pub(crate) fn banned_module_level_imports( checker: &mut Checker, - name: &str, - text_range: TextRange, -) { - banned_at_module_level_with_policy(checker, name, text_range, &NameMatchPolicy::ExactOnly); -} - -/// TID253 -pub(crate) fn name_or_parent_is_banned_at_module_level( - checker: &mut Checker, - name: &str, - text_range: TextRange, -) { - banned_at_module_level_with_policy(checker, name, text_range, &NameMatchPolicy::ExactOrParents); -} - -#[derive(Debug)] -enum NameMatchPolicy { - /// Only match an exact module name (e.g., given `import foo.bar`, only match `foo.bar`). - ExactOnly, - /// Match an exact module name or any of its parents (e.g., given `import foo.bar`, match - /// `foo.bar` or `foo`). - ExactOrParents, -} - -fn banned_at_module_level_with_policy( - checker: &mut Checker, - name: &str, - text_range: TextRange, policy: &NameMatchPolicy, + node: &T, ) { if !checker.semantic().at_top_level() { return; } - let banned_module_level_imports = &checker - .settings - .flake8_tidy_imports - .banned_module_level_imports; - for banned_module_name in banned_module_level_imports { - let name_is_banned = match policy { - NameMatchPolicy::ExactOnly => name == banned_module_name, - NameMatchPolicy::ExactOrParents => { - name == banned_module_name || name.starts_with(&format!("{banned_module_name}.")) - } - }; - if name_is_banned { - checker.diagnostics.push(Diagnostic::new( - BannedModuleLevelImports { - name: banned_module_name.to_string(), - }, - text_range, - )); - return; - } + + if let Some(banned_module) = policy.find( + checker + .settings + .flake8_tidy_imports + .banned_module_level_imports + .iter() + .map(AsRef::as_ref), + ) { + checker.diagnostics.push(Diagnostic::new( + BannedModuleLevelImports { + name: banned_module, + }, + node.range(), + )); } }