diff --git a/tests/test_rich_utils.py b/tests/test_rich_utils.py index fbcfdde4b8..d31dbafb5c 100644 --- a/tests/test_rich_utils.py +++ b/tests/test_rich_utils.py @@ -1,3 +1,5 @@ +import sys + import typer import typer.completion from typer.testing import CliRunner @@ -79,3 +81,21 @@ def main( assert "Hello Rick" in result.stdout assert "First: option_1_default" in result.stdout assert "Second: Morty" in result.stdout + + +def test_rich_markup_import_regression(): + # Remove rich.markup if it was imported by other tests + if "rich" in sys.modules: + rich_module = sys.modules["rich"] + if hasattr(rich_module, "markup"): + delattr(rich_module, "markup") + + app = typer.Typer(rich_markup_mode=None) + + @app.command() + def main(bar: str): + pass # pragma: no cover + + result = runner.invoke(app, ["--help"]) + assert "Usage" in result.stdout + assert "BAR" in result.stdout diff --git a/typer/core.py b/typer/core.py index 54e295d639..b197382029 100644 --- a/typer/core.py +++ b/typer/core.py @@ -372,7 +372,9 @@ def get_help_record(self, ctx: click.Context) -> Optional[Tuple[str, str]]: extra_str = f"[{extra_str}]" if rich is not None: # This is needed for when we want to export to HTML - extra_str = rich.markup.escape(extra_str).strip() + from . import rich_utils + + extra_str = rich_utils.escape_before_html_export(extra_str) help = f"{help} {extra_str}" if help else f"{extra_str}" return name, help @@ -583,7 +585,9 @@ def _write_opts(opts: Sequence[str]) -> str: extra_str = f"[{extra_str}]" if rich is not None: # This is needed for when we want to export to HTML - extra_str = rich.markup.escape(extra_str).strip() + from . import rich_utils + + extra_str = rich_utils.escape_before_html_export(extra_str) help = f"{help} {extra_str}" if help else f"{extra_str}" diff --git a/typer/main.py b/typer/main.py index 44a9725590..7ec9b8a0db 100644 --- a/typer/main.py +++ b/typer/main.py @@ -75,28 +75,19 @@ def except_hook( return typer_path = os.path.dirname(__file__) click_path = os.path.dirname(click.__file__) - supress_internal_dir_names = [typer_path, click_path] + internal_dir_names = [typer_path, click_path] exc = exc_value if rich: - from rich.traceback import Traceback - from . import rich_utils - rich_tb = Traceback.from_exception( - type(exc), - exc, - exc.__traceback__, - show_locals=exception_config.pretty_exceptions_show_locals, - suppress=supress_internal_dir_names, - width=rich_utils.MAX_WIDTH, - ) + rich_tb = rich_utils.get_traceback(exc, exception_config, internal_dir_names) console_stderr = rich_utils._get_rich_console(stderr=True) console_stderr.print(rich_tb) return tb_exc = traceback.TracebackException.from_exception(exc) stack: List[FrameSummary] = [] for frame in tb_exc.stack: - if any(frame.filename.startswith(path) for path in supress_internal_dir_names): + if any(frame.filename.startswith(path) for path in internal_dir_names): if not exception_config.pretty_exceptions_short: # Hide the line for internal libraries, Typer and Click stack.append( diff --git a/typer/rich_utils.py b/typer/rich_utils.py index 404e97503b..d4c3676aea 100644 --- a/typer/rich_utils.py +++ b/typer/rich_utils.py @@ -16,11 +16,14 @@ from rich.emoji import Emoji from rich.highlighter import RegexHighlighter from rich.markdown import Markdown +from rich.markup import escape from rich.padding import Padding from rich.panel import Panel from rich.table import Table from rich.text import Text from rich.theme import Theme +from rich.traceback import Traceback +from typer.models import DeveloperExceptionConfig if sys.version_info >= (3, 9): from typing import Literal @@ -727,6 +730,11 @@ def rich_abort_error() -> None: console.print(ABORTED_TEXT, style=STYLE_ABORTED) +def escape_before_html_export(input_text: str) -> str: + """Ensure that the input string can be used for HTML export.""" + return escape(input_text).strip() + + def rich_to_html(input_text: str) -> str: """Print the HTML version of a rich-formatted input string. @@ -744,3 +752,19 @@ def rich_render_text(text: str) -> str: """Remove rich tags and render a pure text representation""" console = _get_rich_console() return "".join(segment.text for segment in console.render(text)).rstrip("\n") + + +def get_traceback( + exc: BaseException, + exception_config: DeveloperExceptionConfig, + internal_dir_names: List[str], +) -> Traceback: + rich_tb = Traceback.from_exception( + type(exc), + exc, + exc.__traceback__, + show_locals=exception_config.pretty_exceptions_show_locals, + suppress=internal_dir_names, + width=MAX_WIDTH, + ) + return rich_tb