Skip to content
17 changes: 17 additions & 0 deletions python/ray/serve/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,23 @@ def format_actor_name(actor_name, *modifiers):
return name


CLASS_WRAPPER_METADATA_ATTRS = (
"__name__",
"__qualname__",
"__module__",
"__doc__",
"__annotations__",
)


def copy_class_metadata(wrapper_cls, target_cls) -> None:
"""Copy common class-level metadata onto a wrapper class."""
for attr in CLASS_WRAPPER_METADATA_ATTRS:
if hasattr(target_cls, attr):
setattr(wrapper_cls, attr, getattr(target_cls, attr))
wrapper_cls.__wrapped__ = target_cls


def ensure_serialization_context():
"""Ensure the serialization addons on registered, even when Ray has not
been started."""
Expand Down
3 changes: 2 additions & 1 deletion python/ray/serve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ray.serve._private.utils import (
DEFAULT,
Default,
copy_class_metadata,
ensure_serialization_context,
extract_self_if_method_call,
validate_route_prefix,
Expand Down Expand Up @@ -309,7 +310,7 @@ async def __del__(self):
else:
cls.__del__(self)

ASGIIngressWrapper.__name__ = cls.__name__
copy_class_metadata(ASGIIngressWrapper, cls)

return ASGIIngressWrapper

Expand Down
4 changes: 2 additions & 2 deletions python/ray/serve/task_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SERVE_LOGGER_NAME,
)
from ray.serve._private.task_consumer import TaskConsumerWrapper
from ray.serve._private.utils import copy_class_metadata
from ray.serve.schema import (
TaskProcessorAdapter,
TaskProcessorConfig,
Expand Down Expand Up @@ -158,8 +159,7 @@ def __del__(self):
if hasattr(target_cls, "__del__"):
target_cls.__del__(self)

# Preserve the original class name
_TaskConsumerWrapper.__name__ = target_cls.__name__
copy_class_metadata(_TaskConsumerWrapper, target_cls)

return _TaskConsumerWrapper

Expand Down
18 changes: 18 additions & 0 deletions python/ray/serve/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ async def __call__(self):
return {"count": self.count}


def test_ingress_wrapper_preserves_metadata():
app = FastAPI()

class OriginalIngress:
"""Sample ingress class."""

value: int

wrapped_cls = serve.ingress(app)(OriginalIngress)

assert wrapped_cls.__name__ == OriginalIngress.__name__
assert wrapped_cls.__qualname__ == OriginalIngress.__qualname__
assert wrapped_cls.__module__ == OriginalIngress.__module__
assert wrapped_cls.__doc__ == OriginalIngress.__doc__
assert wrapped_cls.__annotations__ == OriginalIngress.__annotations__
assert getattr(wrapped_cls, "__wrapped__", None) is OriginalIngress


class FakeRequestRouter(RequestRouter):
async def choose_replicas(
self,
Expand Down
16 changes: 16 additions & 0 deletions python/ray/serve/tests/unit/test_task_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,5 +300,21 @@ def my_task(self):
assert MyTaskConsumer.name == "MyTaskConsumer"


def test_task_consumer_preserves_metadata(config):
class OriginalConsumer:
"""Docstring for a task consumer."""

value: int

wrapped_cls = task_consumer(task_processor_config=config)(OriginalConsumer)

assert wrapped_cls.__name__ == OriginalConsumer.__name__
assert wrapped_cls.__qualname__ == OriginalConsumer.__qualname__
assert wrapped_cls.__module__ == OriginalConsumer.__module__
assert wrapped_cls.__doc__ == OriginalConsumer.__doc__
assert wrapped_cls.__annotations__ == OriginalConsumer.__annotations__
assert getattr(wrapped_cls, "__wrapped__", None) is OriginalConsumer


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))