diff --git a/python/ray/serve/task_consumer.py b/python/ray/serve/task_consumer.py index b912b92b031c..b172eec24b0d 100644 --- a/python/ray/serve/task_consumer.py +++ b/python/ray/serve/task_consumer.py @@ -158,6 +158,9 @@ def __del__(self): if hasattr(target_cls, "__del__"): target_cls.__del__(self) + # Preserve the original class name + _TaskConsumerWrapper.__name__ = target_cls.__name__ + return _TaskConsumerWrapper return decorator diff --git a/python/ray/serve/tests/unit/test_task_consumer.py b/python/ray/serve/tests/unit/test_task_consumer.py index 2f1c72bdeb49..107a592ed071 100644 --- a/python/ray/serve/tests/unit/test_task_consumer.py +++ b/python/ray/serve/tests/unit/test_task_consumer.py @@ -5,6 +5,7 @@ import pytest +from ray.serve.api import deployment from ray.serve.schema import ( CeleryAdapterConfig, TaskProcessorAdapter, @@ -285,5 +286,19 @@ class MyConsumer: pass +def test_default_deployment_name_stays_same_with_task_consumer(config): + """Test that the default deployment name is the class name when using task_consumer with serve.deployment.""" + + @deployment + @task_consumer(task_processor_config=config) + class MyTaskConsumer: + @task_handler + def my_task(self): + pass + + # The deployment name should default to the class name + assert MyTaskConsumer.name == "MyTaskConsumer" + + if __name__ == "__main__": sys.exit(pytest.main(["-v", "-s", __file__]))