|
1 | 1 | use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
|
2 | 2 | use ruff_macros::{derive_message_formats, violation};
|
3 |
| -use ruff_python_ast::Expr; |
4 |
| -use ruff_python_semantic::Modules; |
| 3 | +use ruff_python_ast::name::{QualifiedName, QualifiedNameBuilder}; |
| 4 | +use ruff_python_ast::statement_visitor::StatementVisitor; |
| 5 | +use ruff_python_ast::visitor::Visitor; |
| 6 | +use ruff_python_ast::{self as ast, Expr}; |
| 7 | +use ruff_python_semantic::{Exceptions, Modules, SemanticModel}; |
5 | 8 | use ruff_text_size::Ranged;
|
6 | 9 |
|
7 | 10 | use crate::checkers::ast::Checker;
|
@@ -665,6 +668,10 @@ pub(crate) fn numpy_2_0_deprecation(checker: &mut Checker, expr: &Expr) {
|
665 | 668 | _ => return,
|
666 | 669 | };
|
667 | 670 |
|
| 671 | + if is_guarded_by_try_except(expr, &replacement, semantic) { |
| 672 | + return; |
| 673 | + } |
| 674 | + |
668 | 675 | let mut diagnostic = Diagnostic::new(
|
669 | 676 | Numpy2Deprecation {
|
670 | 677 | existing: replacement.existing.to_string(),
|
@@ -701,3 +708,233 @@ pub(crate) fn numpy_2_0_deprecation(checker: &mut Checker, expr: &Expr) {
|
701 | 708 | };
|
702 | 709 | checker.diagnostics.push(diagnostic);
|
703 | 710 | }
|
| 711 | + |
| 712 | +/// Ignore attempts to access a `numpy` member via its deprecated name |
| 713 | +/// if the access takes place in an `except` block that provides compatibility |
| 714 | +/// with older numpy versions. |
| 715 | +/// |
| 716 | +/// For attribute accesses (e.g. `np.ComplexWarning`), we only ignore the violation |
| 717 | +/// if it's inside an `except AttributeError` block, and the member is accessed |
| 718 | +/// through its non-deprecated name in the associated `try` block. |
| 719 | +/// |
| 720 | +/// For uses of the `numpy` member where it's simply an `ExprName` node, |
| 721 | +/// we check to see how the `numpy` member was bound. If it was bound via a |
| 722 | +/// `from numpy import foo` statement, we check to see if that import statement |
| 723 | +/// took place inside an `except ImportError` or `except ModuleNotFoundError` block. |
| 724 | +/// If so, and if the `numpy` member was imported through its non-deprecated name |
| 725 | +/// in the associated try block, we ignore the violation in the same way. |
| 726 | +/// |
| 727 | +/// Examples: |
| 728 | +/// |
| 729 | +/// ```py |
| 730 | +/// import numpy as np |
| 731 | +/// |
| 732 | +/// try: |
| 733 | +/// np.all([True, True]) |
| 734 | +/// except AttributeError: |
| 735 | +/// np.alltrue([True, True]) # Okay |
| 736 | +/// |
| 737 | +/// try: |
| 738 | +/// from numpy.exceptions import ComplexWarning |
| 739 | +/// except ImportError: |
| 740 | +/// from numpy import ComplexWarning |
| 741 | +/// |
| 742 | +/// x = ComplexWarning() # Okay |
| 743 | +/// ``` |
| 744 | +fn is_guarded_by_try_except( |
| 745 | + expr: &Expr, |
| 746 | + replacement: &Replacement, |
| 747 | + semantic: &SemanticModel, |
| 748 | +) -> bool { |
| 749 | + match expr { |
| 750 | + Expr::Attribute(_) => { |
| 751 | + if !semantic.in_exception_handler() { |
| 752 | + return false; |
| 753 | + } |
| 754 | + let Some(try_node) = semantic |
| 755 | + .current_statements() |
| 756 | + .find_map(|stmt| stmt.as_try_stmt()) |
| 757 | + else { |
| 758 | + return false; |
| 759 | + }; |
| 760 | + let suspended_exceptions = Exceptions::from_try_stmt(try_node, semantic); |
| 761 | + if !suspended_exceptions.contains(Exceptions::ATTRIBUTE_ERROR) { |
| 762 | + return false; |
| 763 | + } |
| 764 | + try_block_contains_undeprecated_attribute(try_node, &replacement.details, semantic) |
| 765 | + } |
| 766 | + Expr::Name(ast::ExprName { id, .. }) => { |
| 767 | + let Some(binding_id) = semantic.lookup_symbol(id.as_str()) else { |
| 768 | + return false; |
| 769 | + }; |
| 770 | + let binding = semantic.binding(binding_id); |
| 771 | + if !binding.is_external() { |
| 772 | + return false; |
| 773 | + } |
| 774 | + if !binding.in_exception_handler() { |
| 775 | + return false; |
| 776 | + } |
| 777 | + let Some(try_node) = binding.source.and_then(|import_id| { |
| 778 | + semantic |
| 779 | + .statements(import_id) |
| 780 | + .find_map(|stmt| stmt.as_try_stmt()) |
| 781 | + }) else { |
| 782 | + return false; |
| 783 | + }; |
| 784 | + let suspended_exceptions = Exceptions::from_try_stmt(try_node, semantic); |
| 785 | + if !suspended_exceptions |
| 786 | + .intersects(Exceptions::IMPORT_ERROR | Exceptions::MODULE_NOT_FOUND_ERROR) |
| 787 | + { |
| 788 | + return false; |
| 789 | + } |
| 790 | + try_block_contains_undeprecated_import(try_node, &replacement.details) |
| 791 | + } |
| 792 | + _ => false, |
| 793 | + } |
| 794 | +} |
| 795 | + |
| 796 | +/// Given an [`ast::StmtTry`] node, does the `try` branch of that node |
| 797 | +/// contain any [`ast::ExprAttribute`] nodes that indicate the numpy |
| 798 | +/// member is being accessed from the non-deprecated location? |
| 799 | +fn try_block_contains_undeprecated_attribute( |
| 800 | + try_node: &ast::StmtTry, |
| 801 | + replacement_details: &Details, |
| 802 | + semantic: &SemanticModel, |
| 803 | +) -> bool { |
| 804 | + let Details::AutoImport { |
| 805 | + path, |
| 806 | + name, |
| 807 | + compatibility: _, |
| 808 | + } = replacement_details |
| 809 | + else { |
| 810 | + return false; |
| 811 | + }; |
| 812 | + let undeprecated_qualified_name = { |
| 813 | + let mut builder = QualifiedNameBuilder::default(); |
| 814 | + for part in path.split('.') { |
| 815 | + builder.push(part); |
| 816 | + } |
| 817 | + builder.push(name); |
| 818 | + builder.build() |
| 819 | + }; |
| 820 | + let mut attribute_searcher = AttributeSearcher::new(undeprecated_qualified_name, semantic); |
| 821 | + attribute_searcher.visit_body(&try_node.body); |
| 822 | + attribute_searcher.found_attribute |
| 823 | +} |
| 824 | + |
| 825 | +/// AST visitor that searches an AST tree for [`ast::ExprAttribute`] nodes |
| 826 | +/// that match a certain [`QualifiedName`]. |
| 827 | +struct AttributeSearcher<'a> { |
| 828 | + attribute_to_find: QualifiedName<'a>, |
| 829 | + semantic: &'a SemanticModel<'a>, |
| 830 | + found_attribute: bool, |
| 831 | +} |
| 832 | + |
| 833 | +impl<'a> AttributeSearcher<'a> { |
| 834 | + fn new(attribute_to_find: QualifiedName<'a>, semantic: &'a SemanticModel<'a>) -> Self { |
| 835 | + Self { |
| 836 | + attribute_to_find, |
| 837 | + semantic, |
| 838 | + found_attribute: false, |
| 839 | + } |
| 840 | + } |
| 841 | +} |
| 842 | + |
| 843 | +impl Visitor<'_> for AttributeSearcher<'_> { |
| 844 | + fn visit_expr(&mut self, expr: &'_ Expr) { |
| 845 | + if self.found_attribute { |
| 846 | + return; |
| 847 | + } |
| 848 | + if expr.is_attribute_expr() |
| 849 | + && self |
| 850 | + .semantic |
| 851 | + .resolve_qualified_name(expr) |
| 852 | + .is_some_and(|qualified_name| qualified_name == self.attribute_to_find) |
| 853 | + { |
| 854 | + self.found_attribute = true; |
| 855 | + return; |
| 856 | + } |
| 857 | + ast::visitor::walk_expr(self, expr); |
| 858 | + } |
| 859 | + |
| 860 | + fn visit_stmt(&mut self, stmt: &ruff_python_ast::Stmt) { |
| 861 | + if !self.found_attribute { |
| 862 | + ast::visitor::walk_stmt(self, stmt); |
| 863 | + } |
| 864 | + } |
| 865 | + |
| 866 | + fn visit_body(&mut self, body: &[ruff_python_ast::Stmt]) { |
| 867 | + for stmt in body { |
| 868 | + self.visit_stmt(stmt); |
| 869 | + if self.found_attribute { |
| 870 | + return; |
| 871 | + } |
| 872 | + } |
| 873 | + } |
| 874 | +} |
| 875 | + |
| 876 | +/// Given an [`ast::StmtTry`] node, does the `try` branch of that node |
| 877 | +/// contain any [`ast::StmtImportFrom`] nodes that indicate the numpy |
| 878 | +/// member is being imported from the non-deprecated location? |
| 879 | +fn try_block_contains_undeprecated_import( |
| 880 | + try_node: &ast::StmtTry, |
| 881 | + replacement_details: &Details, |
| 882 | +) -> bool { |
| 883 | + let Details::AutoImport { |
| 884 | + path, |
| 885 | + name, |
| 886 | + compatibility: _, |
| 887 | + } = replacement_details |
| 888 | + else { |
| 889 | + return false; |
| 890 | + }; |
| 891 | + let mut import_searcher = ImportSearcher::new(path, name); |
| 892 | + import_searcher.visit_body(&try_node.body); |
| 893 | + import_searcher.found_import |
| 894 | +} |
| 895 | + |
| 896 | +/// AST visitor that searches an AST tree for [`ast::StmtImportFrom`] nodes |
| 897 | +/// that match a certain [`QualifiedName`]. |
| 898 | +struct ImportSearcher<'a> { |
| 899 | + module: &'a str, |
| 900 | + name: &'a str, |
| 901 | + found_import: bool, |
| 902 | +} |
| 903 | + |
| 904 | +impl<'a> ImportSearcher<'a> { |
| 905 | + fn new(module: &'a str, name: &'a str) -> Self { |
| 906 | + Self { |
| 907 | + module, |
| 908 | + name, |
| 909 | + found_import: false, |
| 910 | + } |
| 911 | + } |
| 912 | +} |
| 913 | + |
| 914 | +impl StatementVisitor<'_> for ImportSearcher<'_> { |
| 915 | + fn visit_stmt(&mut self, stmt: &ast::Stmt) { |
| 916 | + if self.found_import { |
| 917 | + return; |
| 918 | + } |
| 919 | + if let ast::Stmt::ImportFrom(ast::StmtImportFrom { module, names, .. }) = stmt { |
| 920 | + if module.as_ref().is_some_and(|module| module == self.module) |
| 921 | + && names |
| 922 | + .iter() |
| 923 | + .any(|ast::Alias { name, .. }| name == self.name) |
| 924 | + { |
| 925 | + self.found_import = true; |
| 926 | + return; |
| 927 | + } |
| 928 | + } |
| 929 | + ast::statement_visitor::walk_stmt(self, stmt); |
| 930 | + } |
| 931 | + |
| 932 | + fn visit_body(&mut self, body: &[ruff_python_ast::Stmt]) { |
| 933 | + for stmt in body { |
| 934 | + self.visit_stmt(stmt); |
| 935 | + if self.found_import { |
| 936 | + return; |
| 937 | + } |
| 938 | + } |
| 939 | + } |
| 940 | +} |
0 commit comments