diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index 065f6a8a301d..0d9846f63b61 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -175,6 +175,34 @@ def format_actor_name(actor_name, *modifiers): return name +CLASS_WRAPPER_METADATA_ATTRS = ( + "__name__", + "__qualname__", + "__module__", + "__doc__", + "__annotations__", +) + + +def copy_class_metadata(wrapper_cls, target_cls) -> None: + """Copy common class-level metadata onto a wrapper class.""" + for attr in CLASS_WRAPPER_METADATA_ATTRS: + if attr == "__annotations__": + target_annotations = getattr(target_cls, "__annotations__", None) + if target_annotations: + merged_annotations = dict( + wrapper_cls.__dict__.get("__annotations__", {}) + ) + for key, value in target_annotations.items(): + merged_annotations.setdefault(key, value) + wrapper_cls.__annotations__ = merged_annotations + continue + + if hasattr(target_cls, attr): + setattr(wrapper_cls, attr, getattr(target_cls, attr)) + wrapper_cls.__wrapped__ = target_cls + + def ensure_serialization_context(): """Ensure the serialization addons on registered, even when Ray has not been started.""" diff --git a/python/ray/serve/api.py b/python/ray/serve/api.py index 2a771cd825f4..267e29d0ef70 100644 --- a/python/ray/serve/api.py +++ b/python/ray/serve/api.py @@ -33,6 +33,7 @@ from ray.serve._private.utils import ( DEFAULT, Default, + copy_class_metadata, ensure_serialization_context, extract_self_if_method_call, validate_route_prefix, @@ -309,7 +310,7 @@ async def __del__(self): else: cls.__del__(self) - ASGIIngressWrapper.__name__ = cls.__name__ + copy_class_metadata(ASGIIngressWrapper, cls) return ASGIIngressWrapper diff --git a/python/ray/serve/task_consumer.py b/python/ray/serve/task_consumer.py index b172eec24b0d..e057eaf966a0 100644 --- a/python/ray/serve/task_consumer.py +++ b/python/ray/serve/task_consumer.py @@ -9,6 +9,7 @@ SERVE_LOGGER_NAME, ) from ray.serve._private.task_consumer import TaskConsumerWrapper +from ray.serve._private.utils import copy_class_metadata from ray.serve.schema import ( TaskProcessorAdapter, TaskProcessorConfig, @@ -158,8 +159,7 @@ def __del__(self): if hasattr(target_cls, "__del__"): target_cls.__del__(self) - # Preserve the original class name - _TaskConsumerWrapper.__name__ = target_cls.__name__ + copy_class_metadata(_TaskConsumerWrapper, target_cls) return _TaskConsumerWrapper diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index efdf805c41cf..e8e9598d6f7f 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -73,6 +73,24 @@ async def __call__(self): return {"count": self.count} +def test_ingress_wrapper_preserves_metadata(): + app = FastAPI() + + class OriginalIngress: + """Sample ingress class.""" + + value: int + + wrapped_cls = serve.ingress(app)(OriginalIngress) + + assert wrapped_cls.__name__ == OriginalIngress.__name__ + assert wrapped_cls.__qualname__ == OriginalIngress.__qualname__ + assert wrapped_cls.__module__ == OriginalIngress.__module__ + assert wrapped_cls.__doc__ == OriginalIngress.__doc__ + assert wrapped_cls.__annotations__ == OriginalIngress.__annotations__ + assert getattr(wrapped_cls, "__wrapped__", None) is OriginalIngress + + class FakeRequestRouter(RequestRouter): async def choose_replicas( self, diff --git a/python/ray/serve/tests/unit/test_task_consumer.py b/python/ray/serve/tests/unit/test_task_consumer.py index 107a592ed071..2068eb19924c 100644 --- a/python/ray/serve/tests/unit/test_task_consumer.py +++ b/python/ray/serve/tests/unit/test_task_consumer.py @@ -300,5 +300,25 @@ def my_task(self): assert MyTaskConsumer.name == "MyTaskConsumer" +def test_task_consumer_preserves_metadata(config): + class OriginalConsumer: + """Docstring for a task consumer.""" + + value: int + + wrapped_cls = task_consumer(task_processor_config=config)(OriginalConsumer) + + assert wrapped_cls.__name__ == OriginalConsumer.__name__ + assert wrapped_cls.__qualname__ == OriginalConsumer.__qualname__ + assert wrapped_cls.__module__ == OriginalConsumer.__module__ + assert wrapped_cls.__doc__ == OriginalConsumer.__doc__ + assert ( + wrapped_cls.__annotations__["value"] + == OriginalConsumer.__annotations__["value"] + ) + assert wrapped_cls.__annotations__["_adapter"] is TaskProcessorAdapter + assert getattr(wrapped_cls, "__wrapped__", None) is OriginalConsumer + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__]))