Skip to content

Commit 0a52c70

Browse files
authored
fix: proper public exports by __all__ (#437)
1 parent 88d2e4a commit 0a52c70

File tree

10 files changed

+514
-0
lines changed

10 files changed

+514
-0
lines changed

runpod/__init__.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,46 @@
3232
from .serverless.modules.rp_logger import RunPodLogger
3333
from .version import __version__
3434

35+
__all__ = [
36+
# API functions
37+
"create_container_registry_auth",
38+
"create_endpoint",
39+
"create_pod",
40+
"create_template",
41+
"delete_container_registry_auth",
42+
"get_endpoints",
43+
"get_gpu",
44+
"get_gpus",
45+
"get_pod",
46+
"get_pods",
47+
"get_user",
48+
"resume_pod",
49+
"stop_pod",
50+
"terminate_pod",
51+
"update_container_registry_auth",
52+
"update_endpoint_template",
53+
"update_user_settings",
54+
# Config functions
55+
"check_credentials",
56+
"get_credentials",
57+
"set_credentials",
58+
# Endpoint classes
59+
"AsyncioEndpoint",
60+
"AsyncioJob",
61+
"Endpoint",
62+
# Serverless module
63+
"serverless",
64+
# Logger class
65+
"RunPodLogger",
66+
# Version
67+
"__version__",
68+
# Module variables
69+
"SSH_KEY_PATH",
70+
"profile",
71+
"api_key",
72+
"endpoint_url_base"
73+
]
74+
3575
# ------------------------------- Config Paths ------------------------------- #
3676
SSH_KEY_PATH = os.path.expanduser("~/.runpod/ssh")
3777

runpod/endpoint/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,10 @@
33
from .asyncio.asyncio_runner import Endpoint as AsyncioEndpoint
44
from .asyncio.asyncio_runner import Job as AsyncioJob
55
from .runner import Endpoint, Job
6+
7+
__all__ = [
8+
"AsyncioEndpoint",
9+
"AsyncioJob",
10+
"Endpoint",
11+
"Job"
12+
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
"""Asyncio endpoint for runpod."""
22

33
from .asyncio_runner import Endpoint, Job
4+
5+
__all__ = [
6+
"Endpoint",
7+
"Job"
8+
]

runpod/serverless/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
from . import worker
1717
from .modules import rp_fastapi
1818
from .modules.rp_logger import RunPodLogger
19+
from .modules.rp_progress import progress_update
20+
21+
__all__ = [
22+
"start",
23+
"progress_update",
24+
"runpod_version"
25+
]
1926

2027
log = RunPodLogger()
2128

runpod/serverless/utils/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@
22

33
from .rp_download import download_files_from_urls
44
from .rp_upload import upload_file_to_bucket, upload_in_memory_object
5+
6+
__all__ = [
7+
"download_files_from_urls",
8+
"upload_file_to_bucket",
9+
"upload_in_memory_object"
10+
]
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Tests for runpod.endpoint.asyncio.__init__ module exports."""
2+
3+
import inspect
4+
import runpod.endpoint.asyncio
5+
6+
7+
class TestEndpointAsyncioInit:
8+
"""Test runpod.endpoint.asyncio module __all__ exports."""
9+
10+
def test_all_defined(self):
11+
"""Test that __all__ is defined in the module."""
12+
assert hasattr(runpod.endpoint.asyncio, '__all__')
13+
assert isinstance(runpod.endpoint.asyncio.__all__, list)
14+
assert len(runpod.endpoint.asyncio.__all__) > 0
15+
16+
def test_all_symbols_importable(self):
17+
"""Test that all symbols in __all__ are actually importable."""
18+
for symbol in runpod.endpoint.asyncio.__all__:
19+
assert hasattr(runpod.endpoint.asyncio, symbol), f"Symbol '{symbol}' in __all__ but not found in module"
20+
21+
def test_expected_public_symbols(self):
22+
"""Test that expected public symbols are in __all__."""
23+
expected_symbols = {
24+
'Endpoint',
25+
'Job'
26+
}
27+
actual_symbols = set(runpod.endpoint.asyncio.__all__)
28+
assert expected_symbols == actual_symbols, f"Expected {expected_symbols}, got {actual_symbols}"
29+
30+
def test_endpoint_classes_accessible(self):
31+
"""Test that endpoint classes are accessible and are classes."""
32+
endpoint_classes = ['Endpoint', 'Job']
33+
34+
for class_name in endpoint_classes:
35+
assert class_name in runpod.endpoint.asyncio.__all__
36+
assert hasattr(runpod.endpoint.asyncio, class_name)
37+
assert inspect.isclass(getattr(runpod.endpoint.asyncio, class_name))
38+
39+
def test_asyncio_classes_are_different_from_parent_module(self):
40+
"""Test that asyncio classes are different from the main endpoint classes."""
41+
# Import the parent module classes for comparison
42+
import runpod.endpoint
43+
44+
# The asyncio classes should be the same as AsyncioEndpoint/AsyncioJob from parent
45+
assert runpod.endpoint.asyncio.Endpoint == runpod.endpoint.AsyncioEndpoint
46+
assert runpod.endpoint.asyncio.Job == runpod.endpoint.AsyncioJob
47+
48+
# But different from the sync versions
49+
assert runpod.endpoint.asyncio.Endpoint != runpod.endpoint.Endpoint
50+
assert runpod.endpoint.asyncio.Job != runpod.endpoint.Job
51+
52+
def test_no_duplicate_symbols_in_all(self):
53+
"""Test that __all__ contains no duplicate symbols."""
54+
all_symbols = runpod.endpoint.asyncio.__all__
55+
unique_symbols = set(all_symbols)
56+
assert len(all_symbols) == len(unique_symbols), f"Duplicates found in __all__: {[x for x in all_symbols if all_symbols.count(x) > 1]}"
57+
58+
def test_all_covers_public_api_only(self):
59+
"""Test that __all__ contains only the intended public API."""
60+
# Get all non-private attributes from the module
61+
module_attrs = {name for name in dir(runpod.endpoint.asyncio)
62+
if not name.startswith('_')}
63+
64+
# Filter out imported modules that shouldn't be public
65+
expected_private_attrs = set() # No private imports in this module
66+
67+
public_attrs = module_attrs - expected_private_attrs
68+
all_symbols = set(runpod.endpoint.asyncio.__all__)
69+
70+
# All symbols in __all__ should be actual public API
71+
assert all_symbols.issubset(public_attrs), f"__all__ contains non-public symbols: {all_symbols - public_attrs}"
72+
73+
# Expected public API should be exactly what's in __all__
74+
expected_public_api = {'Endpoint', 'Job'}
75+
assert all_symbols == expected_public_api, f"Expected {expected_public_api}, got {all_symbols}"

tests/test_endpoint/test_init.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Tests for runpod.endpoint.__init__ module exports."""
2+
3+
import inspect
4+
import runpod.endpoint
5+
6+
7+
class TestEndpointInit:
8+
"""Test runpod.endpoint module __all__ exports."""
9+
10+
def test_all_defined(self):
11+
"""Test that __all__ is defined in the module."""
12+
assert hasattr(runpod.endpoint, '__all__')
13+
assert isinstance(runpod.endpoint.__all__, list)
14+
assert len(runpod.endpoint.__all__) > 0
15+
16+
def test_all_symbols_importable(self):
17+
"""Test that all symbols in __all__ are actually importable."""
18+
for symbol in runpod.endpoint.__all__:
19+
assert hasattr(runpod.endpoint, symbol), f"Symbol '{symbol}' in __all__ but not found in module"
20+
21+
def test_expected_public_symbols(self):
22+
"""Test that expected public symbols are in __all__."""
23+
expected_symbols = {
24+
'AsyncioEndpoint',
25+
'AsyncioJob',
26+
'Endpoint',
27+
'Job'
28+
}
29+
actual_symbols = set(runpod.endpoint.__all__)
30+
assert expected_symbols == actual_symbols, f"Expected {expected_symbols}, got {actual_symbols}"
31+
32+
def test_endpoint_classes_accessible(self):
33+
"""Test that endpoint classes are accessible and are classes."""
34+
endpoint_classes = ['AsyncioEndpoint', 'AsyncioJob', 'Endpoint', 'Job']
35+
36+
for class_name in endpoint_classes:
37+
assert class_name in runpod.endpoint.__all__
38+
assert hasattr(runpod.endpoint, class_name)
39+
assert inspect.isclass(getattr(runpod.endpoint, class_name))
40+
41+
def test_asyncio_classes_distinct(self):
42+
"""Test that asyncio classes are distinct from sync classes."""
43+
assert runpod.endpoint.AsyncioEndpoint != runpod.endpoint.Endpoint
44+
assert runpod.endpoint.AsyncioJob != runpod.endpoint.Job
45+
46+
def test_no_duplicate_symbols_in_all(self):
47+
"""Test that __all__ contains no duplicate symbols."""
48+
all_symbols = runpod.endpoint.__all__
49+
unique_symbols = set(all_symbols)
50+
assert len(all_symbols) == len(unique_symbols), f"Duplicates found in __all__: {[x for x in all_symbols if all_symbols.count(x) > 1]}"
51+
52+
def test_all_covers_public_api_only(self):
53+
"""Test that __all__ contains only the intended public API."""
54+
# Get all non-private attributes from the module
55+
module_attrs = {name for name in dir(runpod.endpoint)
56+
if not name.startswith('_')}
57+
58+
# Filter out imported modules that shouldn't be public
59+
expected_private_attrs = set() # No private imports in this module
60+
61+
public_attrs = module_attrs - expected_private_attrs
62+
all_symbols = set(runpod.endpoint.__all__)
63+
64+
# All symbols in __all__ should be actual public API
65+
assert all_symbols.issubset(public_attrs), f"__all__ contains non-public symbols: {all_symbols - public_attrs}"
66+
67+
# Expected public API should be exactly what's in __all__
68+
expected_public_api = {'AsyncioEndpoint', 'AsyncioJob', 'Endpoint', 'Job'}
69+
assert all_symbols == expected_public_api, f"Expected {expected_public_api}, got {all_symbols}"

tests/test_init.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Tests for runpod.__init__ module exports."""
2+
3+
import inspect
4+
import runpod
5+
6+
7+
class TestRunpodInit:
8+
"""Test runpod module __all__ exports."""
9+
10+
def test_all_defined(self):
11+
"""Test that __all__ is defined in the module."""
12+
assert hasattr(runpod, '__all__')
13+
assert isinstance(runpod.__all__, list)
14+
assert len(runpod.__all__) > 0
15+
16+
def test_all_symbols_importable(self):
17+
"""Test that all symbols in __all__ are actually importable."""
18+
for symbol in runpod.__all__:
19+
assert hasattr(runpod, symbol), f"Symbol '{symbol}' in __all__ but not found in module"
20+
21+
def test_api_functions_accessible(self):
22+
"""Test that API functions are accessible and callable."""
23+
api_functions = [
24+
'create_container_registry_auth', 'create_endpoint', 'create_pod', 'create_template',
25+
'delete_container_registry_auth', 'get_endpoints', 'get_gpu', 'get_gpus',
26+
'get_pod', 'get_pods', 'get_user', 'resume_pod', 'stop_pod', 'terminate_pod',
27+
'update_container_registry_auth', 'update_endpoint_template', 'update_user_settings'
28+
]
29+
30+
for func_name in api_functions:
31+
assert func_name in runpod.__all__
32+
assert hasattr(runpod, func_name)
33+
assert callable(getattr(runpod, func_name))
34+
35+
def test_config_functions_accessible(self):
36+
"""Test that config functions are accessible and callable."""
37+
config_functions = ['check_credentials', 'get_credentials', 'set_credentials']
38+
39+
for func_name in config_functions:
40+
assert func_name in runpod.__all__
41+
assert hasattr(runpod, func_name)
42+
assert callable(getattr(runpod, func_name))
43+
44+
def test_endpoint_classes_accessible(self):
45+
"""Test that endpoint classes are accessible."""
46+
endpoint_classes = ['AsyncioEndpoint', 'AsyncioJob', 'Endpoint']
47+
48+
for class_name in endpoint_classes:
49+
assert class_name in runpod.__all__
50+
assert hasattr(runpod, class_name)
51+
assert inspect.isclass(getattr(runpod, class_name))
52+
53+
def test_serverless_module_accessible(self):
54+
"""Test that serverless module is accessible."""
55+
assert 'serverless' in runpod.__all__
56+
assert hasattr(runpod, 'serverless')
57+
assert inspect.ismodule(runpod.serverless)
58+
59+
def test_logger_class_accessible(self):
60+
"""Test that RunPodLogger class is accessible."""
61+
assert 'RunPodLogger' in runpod.__all__
62+
assert hasattr(runpod, 'RunPodLogger')
63+
assert inspect.isclass(runpod.RunPodLogger)
64+
65+
def test_version_accessible(self):
66+
"""Test that __version__ is accessible."""
67+
assert '__version__' in runpod.__all__
68+
assert hasattr(runpod, '__version__')
69+
assert isinstance(runpod.__version__, str)
70+
71+
def test_module_variables_accessible(self):
72+
"""Test that module variables are accessible."""
73+
module_vars = ['SSH_KEY_PATH', 'profile', 'api_key', 'endpoint_url_base']
74+
75+
for var_name in module_vars:
76+
assert var_name in runpod.__all__
77+
assert hasattr(runpod, var_name)
78+
79+
def test_private_imports_not_exported(self):
80+
"""Test that private imports are not in __all__."""
81+
private_symbols = {
82+
'logging', 'os', '_credentials'
83+
}
84+
all_symbols = set(runpod.__all__)
85+
86+
for private_symbol in private_symbols:
87+
assert private_symbol not in all_symbols, f"Private symbol '{private_symbol}' should not be in __all__"
88+
89+
def test_all_covers_expected_public_api(self):
90+
"""Test that __all__ contains the expected public API symbols."""
91+
expected_symbols = {
92+
# API functions
93+
'create_container_registry_auth', 'create_endpoint', 'create_pod', 'create_template',
94+
'delete_container_registry_auth', 'get_endpoints', 'get_gpu', 'get_gpus',
95+
'get_pod', 'get_pods', 'get_user', 'resume_pod', 'stop_pod', 'terminate_pod',
96+
'update_container_registry_auth', 'update_endpoint_template', 'update_user_settings',
97+
# Config functions
98+
'check_credentials', 'get_credentials', 'set_credentials',
99+
# Endpoint classes
100+
'AsyncioEndpoint', 'AsyncioJob', 'Endpoint',
101+
# Serverless module
102+
'serverless',
103+
# Logger class
104+
'RunPodLogger',
105+
# Version
106+
'__version__',
107+
# Module variables
108+
'SSH_KEY_PATH', 'profile', 'api_key', 'endpoint_url_base'
109+
}
110+
111+
actual_symbols = set(runpod.__all__)
112+
assert expected_symbols == actual_symbols, f"Expected {expected_symbols}, got {actual_symbols}"
113+
114+
def test_no_duplicate_symbols_in_all(self):
115+
"""Test that __all__ contains no duplicate symbols."""
116+
all_symbols = runpod.__all__
117+
unique_symbols = set(all_symbols)
118+
assert len(all_symbols) == len(unique_symbols), f"Duplicates found in __all__: {[x for x in all_symbols if all_symbols.count(x) > 1]}"

0 commit comments

Comments
 (0)