Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions runpod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,46 @@
from .serverless.modules.rp_logger import RunPodLogger
from .version import __version__

__all__ = [
# API functions
"create_container_registry_auth",
"create_endpoint",
"create_pod",
"create_template",
"delete_container_registry_auth",
"get_endpoints",
"get_gpu",
"get_gpus",
"get_pod",
"get_pods",
"get_user",
"resume_pod",
"stop_pod",
"terminate_pod",
"update_container_registry_auth",
"update_endpoint_template",
"update_user_settings",
# Config functions
"check_credentials",
"get_credentials",
"set_credentials",
# Endpoint classes
"AsyncioEndpoint",
"AsyncioJob",
"Endpoint",
# Serverless module
"serverless",
# Logger class
"RunPodLogger",
# Version
"__version__",
# Module variables
"SSH_KEY_PATH",
"profile",
"api_key",
"endpoint_url_base"
]

# ------------------------------- Config Paths ------------------------------- #
SSH_KEY_PATH = os.path.expanduser("~/.runpod/ssh")

Expand Down
7 changes: 7 additions & 0 deletions runpod/endpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@
from .asyncio.asyncio_runner import Endpoint as AsyncioEndpoint
from .asyncio.asyncio_runner import Job as AsyncioJob
from .runner import Endpoint, Job

__all__ = [
"AsyncioEndpoint",
"AsyncioJob",
"Endpoint",
"Job"
]
5 changes: 5 additions & 0 deletions runpod/endpoint/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Asyncio endpoint for runpod."""

from .asyncio_runner import Endpoint, Job

__all__ = [
"Endpoint",
"Job"
]
7 changes: 7 additions & 0 deletions runpod/serverless/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from . import worker
from .modules import rp_fastapi
from .modules.rp_logger import RunPodLogger
from .modules.rp_progress import progress_update

__all__ = [
"start",
"progress_update",
"runpod_version"
]

log = RunPodLogger()

Expand Down
6 changes: 6 additions & 0 deletions runpod/serverless/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@

from .rp_download import download_files_from_urls
from .rp_upload import upload_file_to_bucket, upload_in_memory_object

__all__ = [
"download_files_from_urls",
"upload_file_to_bucket",
"upload_in_memory_object"
]
75 changes: 75 additions & 0 deletions tests/test_endpoint/test_asyncio_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Tests for runpod.endpoint.asyncio.__init__ module exports."""

import inspect
import runpod.endpoint.asyncio


class TestEndpointAsyncioInit:
"""Test runpod.endpoint.asyncio module __all__ exports."""

def test_all_defined(self):
"""Test that __all__ is defined in the module."""
assert hasattr(runpod.endpoint.asyncio, '__all__')
assert isinstance(runpod.endpoint.asyncio.__all__, list)
assert len(runpod.endpoint.asyncio.__all__) > 0

def test_all_symbols_importable(self):
"""Test that all symbols in __all__ are actually importable."""
for symbol in runpod.endpoint.asyncio.__all__:
assert hasattr(runpod.endpoint.asyncio, symbol), f"Symbol '{symbol}' in __all__ but not found in module"

def test_expected_public_symbols(self):
"""Test that expected public symbols are in __all__."""
expected_symbols = {
'Endpoint',
'Job'
}
actual_symbols = set(runpod.endpoint.asyncio.__all__)
assert expected_symbols == actual_symbols, f"Expected {expected_symbols}, got {actual_symbols}"

def test_endpoint_classes_accessible(self):
"""Test that endpoint classes are accessible and are classes."""
endpoint_classes = ['Endpoint', 'Job']

for class_name in endpoint_classes:
assert class_name in runpod.endpoint.asyncio.__all__
assert hasattr(runpod.endpoint.asyncio, class_name)
assert inspect.isclass(getattr(runpod.endpoint.asyncio, class_name))

def test_asyncio_classes_are_different_from_parent_module(self):
"""Test that asyncio classes are different from the main endpoint classes."""
# Import the parent module classes for comparison
import runpod.endpoint

# The asyncio classes should be the same as AsyncioEndpoint/AsyncioJob from parent
assert runpod.endpoint.asyncio.Endpoint == runpod.endpoint.AsyncioEndpoint
assert runpod.endpoint.asyncio.Job == runpod.endpoint.AsyncioJob

# But different from the sync versions
assert runpod.endpoint.asyncio.Endpoint != runpod.endpoint.Endpoint
assert runpod.endpoint.asyncio.Job != runpod.endpoint.Job

def test_no_duplicate_symbols_in_all(self):
"""Test that __all__ contains no duplicate symbols."""
all_symbols = runpod.endpoint.asyncio.__all__
unique_symbols = set(all_symbols)
assert len(all_symbols) == len(unique_symbols), f"Duplicates found in __all__: {[x for x in all_symbols if all_symbols.count(x) > 1]}"

def test_all_covers_public_api_only(self):
"""Test that __all__ contains only the intended public API."""
# Get all non-private attributes from the module
module_attrs = {name for name in dir(runpod.endpoint.asyncio)
if not name.startswith('_')}

# Filter out imported modules that shouldn't be public
expected_private_attrs = set() # No private imports in this module

public_attrs = module_attrs - expected_private_attrs
all_symbols = set(runpod.endpoint.asyncio.__all__)

# All symbols in __all__ should be actual public API
assert all_symbols.issubset(public_attrs), f"__all__ contains non-public symbols: {all_symbols - public_attrs}"

# Expected public API should be exactly what's in __all__
expected_public_api = {'Endpoint', 'Job'}
assert all_symbols == expected_public_api, f"Expected {expected_public_api}, got {all_symbols}"
69 changes: 69 additions & 0 deletions tests/test_endpoint/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Tests for runpod.endpoint.__init__ module exports."""

import inspect
import runpod.endpoint


class TestEndpointInit:
"""Test runpod.endpoint module __all__ exports."""

def test_all_defined(self):
"""Test that __all__ is defined in the module."""
assert hasattr(runpod.endpoint, '__all__')
assert isinstance(runpod.endpoint.__all__, list)
assert len(runpod.endpoint.__all__) > 0

def test_all_symbols_importable(self):
"""Test that all symbols in __all__ are actually importable."""
for symbol in runpod.endpoint.__all__:
assert hasattr(runpod.endpoint, symbol), f"Symbol '{symbol}' in __all__ but not found in module"

def test_expected_public_symbols(self):
"""Test that expected public symbols are in __all__."""
expected_symbols = {
'AsyncioEndpoint',
'AsyncioJob',
'Endpoint',
'Job'
}
actual_symbols = set(runpod.endpoint.__all__)
assert expected_symbols == actual_symbols, f"Expected {expected_symbols}, got {actual_symbols}"

def test_endpoint_classes_accessible(self):
"""Test that endpoint classes are accessible and are classes."""
endpoint_classes = ['AsyncioEndpoint', 'AsyncioJob', 'Endpoint', 'Job']

for class_name in endpoint_classes:
assert class_name in runpod.endpoint.__all__
assert hasattr(runpod.endpoint, class_name)
assert inspect.isclass(getattr(runpod.endpoint, class_name))

def test_asyncio_classes_distinct(self):
"""Test that asyncio classes are distinct from sync classes."""
assert runpod.endpoint.AsyncioEndpoint != runpod.endpoint.Endpoint
assert runpod.endpoint.AsyncioJob != runpod.endpoint.Job

def test_no_duplicate_symbols_in_all(self):
"""Test that __all__ contains no duplicate symbols."""
all_symbols = runpod.endpoint.__all__
unique_symbols = set(all_symbols)
assert len(all_symbols) == len(unique_symbols), f"Duplicates found in __all__: {[x for x in all_symbols if all_symbols.count(x) > 1]}"

def test_all_covers_public_api_only(self):
"""Test that __all__ contains only the intended public API."""
# Get all non-private attributes from the module
module_attrs = {name for name in dir(runpod.endpoint)
if not name.startswith('_')}

# Filter out imported modules that shouldn't be public
expected_private_attrs = set() # No private imports in this module

public_attrs = module_attrs - expected_private_attrs
all_symbols = set(runpod.endpoint.__all__)

# All symbols in __all__ should be actual public API
assert all_symbols.issubset(public_attrs), f"__all__ contains non-public symbols: {all_symbols - public_attrs}"

# Expected public API should be exactly what's in __all__
expected_public_api = {'AsyncioEndpoint', 'AsyncioJob', 'Endpoint', 'Job'}
assert all_symbols == expected_public_api, f"Expected {expected_public_api}, got {all_symbols}"
118 changes: 118 additions & 0 deletions tests/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Tests for runpod.__init__ module exports."""

import inspect
import runpod


class TestRunpodInit:
"""Test runpod module __all__ exports."""

def test_all_defined(self):
"""Test that __all__ is defined in the module."""
assert hasattr(runpod, '__all__')
assert isinstance(runpod.__all__, list)
assert len(runpod.__all__) > 0

def test_all_symbols_importable(self):
"""Test that all symbols in __all__ are actually importable."""
for symbol in runpod.__all__:
assert hasattr(runpod, symbol), f"Symbol '{symbol}' in __all__ but not found in module"

def test_api_functions_accessible(self):
"""Test that API functions are accessible and callable."""
api_functions = [
'create_container_registry_auth', 'create_endpoint', 'create_pod', 'create_template',
'delete_container_registry_auth', 'get_endpoints', 'get_gpu', 'get_gpus',
'get_pod', 'get_pods', 'get_user', 'resume_pod', 'stop_pod', 'terminate_pod',
'update_container_registry_auth', 'update_endpoint_template', 'update_user_settings'
]

for func_name in api_functions:
assert func_name in runpod.__all__
assert hasattr(runpod, func_name)
assert callable(getattr(runpod, func_name))

def test_config_functions_accessible(self):
"""Test that config functions are accessible and callable."""
config_functions = ['check_credentials', 'get_credentials', 'set_credentials']

for func_name in config_functions:
assert func_name in runpod.__all__
assert hasattr(runpod, func_name)
assert callable(getattr(runpod, func_name))

def test_endpoint_classes_accessible(self):
"""Test that endpoint classes are accessible."""
endpoint_classes = ['AsyncioEndpoint', 'AsyncioJob', 'Endpoint']

for class_name in endpoint_classes:
assert class_name in runpod.__all__
assert hasattr(runpod, class_name)
assert inspect.isclass(getattr(runpod, class_name))

def test_serverless_module_accessible(self):
"""Test that serverless module is accessible."""
assert 'serverless' in runpod.__all__
assert hasattr(runpod, 'serverless')
assert inspect.ismodule(runpod.serverless)

def test_logger_class_accessible(self):
"""Test that RunPodLogger class is accessible."""
assert 'RunPodLogger' in runpod.__all__
assert hasattr(runpod, 'RunPodLogger')
assert inspect.isclass(runpod.RunPodLogger)

def test_version_accessible(self):
"""Test that __version__ is accessible."""
assert '__version__' in runpod.__all__
assert hasattr(runpod, '__version__')
assert isinstance(runpod.__version__, str)

def test_module_variables_accessible(self):
"""Test that module variables are accessible."""
module_vars = ['SSH_KEY_PATH', 'profile', 'api_key', 'endpoint_url_base']

for var_name in module_vars:
assert var_name in runpod.__all__
assert hasattr(runpod, var_name)

def test_private_imports_not_exported(self):
"""Test that private imports are not in __all__."""
private_symbols = {
'logging', 'os', '_credentials'
}
all_symbols = set(runpod.__all__)

for private_symbol in private_symbols:
assert private_symbol not in all_symbols, f"Private symbol '{private_symbol}' should not be in __all__"

def test_all_covers_expected_public_api(self):
"""Test that __all__ contains the expected public API symbols."""
expected_symbols = {
# API functions
'create_container_registry_auth', 'create_endpoint', 'create_pod', 'create_template',
'delete_container_registry_auth', 'get_endpoints', 'get_gpu', 'get_gpus',
'get_pod', 'get_pods', 'get_user', 'resume_pod', 'stop_pod', 'terminate_pod',
'update_container_registry_auth', 'update_endpoint_template', 'update_user_settings',
# Config functions
'check_credentials', 'get_credentials', 'set_credentials',
# Endpoint classes
'AsyncioEndpoint', 'AsyncioJob', 'Endpoint',
# Serverless module
'serverless',
# Logger class
'RunPodLogger',
# Version
'__version__',
# Module variables
'SSH_KEY_PATH', 'profile', 'api_key', 'endpoint_url_base'
}

actual_symbols = set(runpod.__all__)
assert expected_symbols == actual_symbols, f"Expected {expected_symbols}, got {actual_symbols}"

def test_no_duplicate_symbols_in_all(self):
"""Test that __all__ contains no duplicate symbols."""
all_symbols = runpod.__all__
unique_symbols = set(all_symbols)
assert len(all_symbols) == len(unique_symbols), f"Duplicates found in __all__: {[x for x in all_symbols if all_symbols.count(x) > 1]}"
Loading
Loading