diff --git a/bump_pydantic/codemods/replace_imports.py b/bump_pydantic/codemods/replace_imports.py index 6030224..6f82a95 100644 --- a/bump_pydantic/codemods/replace_imports.py +++ b/bump_pydantic/codemods/replace_imports.py @@ -10,7 +10,9 @@ from __future__ import annotations +import sys from dataclasses import dataclass +from importlib.util import find_spec from typing import Sequence import libcst as cst @@ -35,6 +37,13 @@ } +def find_package_install(package_name: str) -> bool: + try: + return find_spec(package_name) is not None + except ModuleNotFoundError: + return False + + def resolve_module_parts(module_parts: list[str]) -> m.Attribute | m.Name: if len(module_parts) == 1: return m.Name(module_parts[0]) @@ -98,11 +107,20 @@ class ImportInfo: class ReplaceImportsCodemod(VisitorBasedCodemodCommand): @m.leave(IMPORT_MATCH) def leave_replace_import(self, _: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom: + to_do_warnings = set() for import_info in IMPORT_INFOS: if m.matches(updated_node, import_info.import_from): aliases: Sequence[cst.ImportAlias] = updated_node.names # type: ignore # If multiple objects are imported in a single import statement, # we need to remove only the one we're replacing. + package_not_installed = not find_package_install(import_info.to_import_str[0]) + if package_not_installed: + import_info_part = import_info.to_import_str[0].split(".")[0] + to_do_warning = f" #todo: please install {import_info_part}\n" + if to_do_warning not in to_do_warnings: + sys.stdout.write(to_do_warning) + sys.stdout.flush() + to_do_warnings.add(to_do_warning) AddImportsVisitor.add_needed_import(self.context, *import_info.to_import_str) if len(updated_node.names) > 1: # type: ignore names = [alias for alias in aliases if alias.name.value != import_info.to_import_str[-1]] diff --git a/tests/unit/test_replace_imports.py b/tests/unit/test_replace_imports.py index 1a6e16b..a793d45 100644 --- a/tests/unit/test_replace_imports.py +++ b/tests/unit/test_replace_imports.py @@ -1,12 +1,33 @@ +import importlib +import io +import sys +from contextlib import contextmanager + import pytest from libcst.codemod import CodemodTest from bump_pydantic.codemods.replace_imports import ReplaceImportsCodemod +def is_package_installed(package_name): + try: + importlib.import_module(package_name) + return True + except ImportError: + return False + + class TestReplaceImportsCommand(CodemodTest): TRANSFORM = ReplaceImportsCodemod + @contextmanager + def capture_stdout(self): + new_stdout = io.StringIO() + old_stdout = sys.stdout + sys.stdout = new_stdout + yield new_stdout + sys.stdout = old_stdout + def test_base_settings(self) -> None: before = """ from pydantic import BaseSettings @@ -16,11 +37,22 @@ def test_base_settings(self) -> None: """ self.assertCodemod(before, after) + with self.capture_stdout() as captured: + self.assertCodemod(before, after) + + if is_package_installed("pydantic_settings"): + assert captured.getvalue().strip() == "", "stdout is not empty as expected." + else: + expected_stdout = "#todo: please install pydantic_settings" + assert captured.getvalue().strip() == expected_stdout + def test_noop_base_settings(self) -> None: code = """ from potato import BaseSettings """ - self.assertCodemod(code, code) + with self.capture_stdout() as captured: + self.assertCodemod(code, code) + assert captured.getvalue().strip() == "", "stdout is not empty as expected." @pytest.mark.xfail(reason="To be implemented.") def test_base_settings_as(self) -> None: @@ -39,7 +71,15 @@ def test_color(self) -> None: after = """ from pydantic_extra_types.color import Color """ - self.assertCodemod(before, after) + + with self.capture_stdout() as captured: + self.assertCodemod(before, after) + + if is_package_installed("pydantic_extra_types"): + assert captured.getvalue().strip() == "", "stdout is not empty as expected." + else: + expected_stdout = "#todo: please install pydantic_extra_types" + assert captured.getvalue().strip() == expected_stdout def test_color_full(self) -> None: before = """ @@ -48,13 +88,23 @@ def test_color_full(self) -> None: after = """ from pydantic_extra_types.color import Color """ - self.assertCodemod(before, after) + with self.capture_stdout() as captured: + self.assertCodemod(before, after) + + if is_package_installed("pydantic_extra_types"): + assert captured.getvalue().strip() == "", "stdout is not empty as expected." + else: + expected_stdout = "#todo: please install pydantic_extra_types" + assert captured.getvalue().strip() == expected_stdout def test_noop_color(self) -> None: code = """ from potato import Color """ self.assertCodemod(code, code) + with self.capture_stdout() as captured: + self.assertCodemod(code, code) + assert captured.getvalue().strip() == "", "stdout is not empty as expected." def test_payment_card_number(self) -> None: before = """ @@ -63,7 +113,14 @@ def test_payment_card_number(self) -> None: after = """ from pydantic_extra_types.payment import PaymentCardNumber """ - self.assertCodemod(before, after) + with self.capture_stdout() as captured: + self.assertCodemod(before, after) + + if is_package_installed("pydantic_extra_types"): + assert captured.getvalue().strip() == "", "stdout is not empty as expected." + else: + expected_stdout = "#todo: please install pydantic_extra_types" + assert captured.getvalue().strip() == expected_stdout def test_payment_card_brand(self) -> None: before = """ @@ -72,19 +129,30 @@ def test_payment_card_brand(self) -> None: after = """ from pydantic_extra_types.payment import PaymentCardBrand """ - self.assertCodemod(before, after) + with self.capture_stdout() as captured: + self.assertCodemod(before, after) + + if is_package_installed("pydantic_extra_types"): + assert captured.getvalue().strip() == "", "stdout is not empty as expected." + else: + expected_stdout = "#todo: please install pydantic_extra_types" + assert captured.getvalue().strip() == expected_stdout def test_noop_payment_card_number(self) -> None: code = """ from potato import PaymentCardNumber """ - self.assertCodemod(code, code) + with self.capture_stdout() as captured: + self.assertCodemod(code, code) + assert captured.getvalue().strip() == "", "stdout is not empty as expected." def test_noop_payment_card_brand(self) -> None: code = """ from potato import PaymentCardBrand """ - self.assertCodemod(code, code) + with self.capture_stdout() as captured: + self.assertCodemod(code, code) + assert captured.getvalue().strip() == "", "stdout is not empty as expected." def test_both_payment(self) -> None: before = """ @@ -93,4 +161,11 @@ def test_both_payment(self) -> None: after = """ from pydantic_extra_types.payment import PaymentCardBrand, PaymentCardNumber """ - self.assertCodemod(before, after) + with self.capture_stdout() as captured: + self.assertCodemod(before, after) + + if is_package_installed("pydantic_extra_types"): + assert captured.getvalue().strip() == "", "stdout is not empty as expected." + else: + expected_stdout = "#todo: please install pydantic_extra_types" + assert captured.getvalue().strip() == expected_stdout