diff --git a/mathics/eval/tracing.py b/mathics/eval/tracing.py index d10eb4186..7f7210d0d 100644 --- a/mathics/eval/tracing.py +++ b/mathics/eval/tracing.py @@ -114,7 +114,11 @@ def trace_evaluate(func: Callable) -> Callable: def wrapper(expr, evaluation) -> Any: from mathics.core.symbols import SymbolConstant - skip_call = False + # trace_evaluate_action allows for trace_evaluate_on_call() + # and trace_evaluate_return() to set the value of the + # expression instead of calling the function or replacing + # of return value of a Mathics3 function call. + trace_evaluate_action: Optional[Any] = None result = None was_boxing = evaluation.is_boxing if ( @@ -125,19 +129,33 @@ def wrapper(expr, evaluation) -> Any: # We may use boxing in print_evaluate_fn(). So turn off # boxing temporarily. evaluation.is_boxing = True - skip_call = trace_evaluate_on_call(expr, evaluation, "Evaluating", func) + trace_evaluate_action = trace_evaluate_on_call( + expr, evaluation, "Evaluating", func + ) evaluation.is_boxing = was_boxing - if not skip_call: + if trace_evaluate_action is None: result = func(expr, evaluation) if trace_evaluate_on_return is not None and not was_boxing: - trace_evaluate_on_return( + trace_evaluate_action = trace_evaluate_on_return( expr=result, evaluation=evaluation, status="Returning", fn=expr, orig_expr=expr, ) + if trace_evaluate_action is not None: + result = ( + (trace_evaluate_action, False) + if func.__name__ == "rewrite_apply_eval_step" + else trace_evaluate_action + ) evaluation.is_boxing = was_boxing + else: + result = ( + (trace_evaluate_action, False) + if func.__name__ == "rewrite_apply_eval_step" + else trace_evaluate_action + ) return result return wrapper diff --git a/test/builtin/test_trace.py b/test/builtin/test_trace.py index cc310c641..8a74e0b61 100644 --- a/test/builtin/test_trace.py +++ b/test/builtin/test_trace.py @@ -4,7 +4,7 @@ """ from inspect import isfunction, ismethod from test.helper import evaluate, session -from typing import Any, Callable +from typing import Any, Callable, Optional import pytest @@ -26,7 +26,7 @@ def test_TraceEvaluation(): def counting_print_evaluate( expr, evaluation: Evaluation, status: str, fn: Callable, orig_expr=None - ) -> bool: + ) -> Optional[Any]: """ A replacement for mathics.eval.tracing.print_evaluate() that counts the number of evaluation calls. @@ -36,7 +36,7 @@ def counting_print_evaluate( assert status in ("Evaluating", "Returning") if "cython" not in version_info: assert isfunction(fn), "Expecting 4th argument to be a function" - return False + return None try: # Set a small recursion limit, @@ -83,7 +83,7 @@ def empty_queue(): global event_queue event_queue = [] - def call_event_func(event: TraceEvent, fn: Callable, *args) -> bool: + def call_event_func(event: TraceEvent, fn: Callable, *args) -> Optional[Any]: """ Capture filtered calls in event_queue. """ @@ -92,7 +92,7 @@ def call_event_func(event: TraceEvent, fn: Callable, *args) -> bool: else: name = str(fn) event_queue.append(f"{event.name} call : {name}{args[:3]}") - return False + return None def return_event_func(event: TraceEvent, result: Any) -> Any: """