Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions runpod/api/ctl_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,24 +302,13 @@ def create_endpoint(
workers_min: int = 0,
workers_max: int = 3,
flashboot=False,
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
gpu_count: int = 1,
):
"""
Create an endpoint

:param name: the name of the endpoint
:param template_id: the id of the template to use for the endpoint
:param gpu_ids: the ids of the GPUs to use for the endpoint
:param network_volume_id: the id of the network volume to use for the endpoint
:param locations: the locations to use for the endpoint
:param idle_timeout: the idle timeout for the endpoint
:param scaler_type: the scaler type for the endpoint
:param scaler_value: the scaler value for the endpoint
:param workers_min: the minimum number of workers for the endpoint
:param workers_max: the maximum number of workers for the endpoint

:example:

>>> endpoint_id = runpod.create_endpoint("test", "template_id")
:param allowed_cuda_versions: Comma-separated string of allowed CUDA versions (e.g., "12.4,12.5").
"""
raw_response = run_graphql_query(
endpoint_mutations.generate_endpoint_mutation(
Expand All @@ -334,6 +323,8 @@ def create_endpoint(
workers_min,
workers_max,
flashboot,
allowed_cuda_versions,
gpu_count
)
)

Expand Down
11 changes: 11 additions & 0 deletions runpod/api/mutations/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def generate_endpoint_mutation(
workers_min: int = 0,
workers_max: int = 3,
flashboot=False,
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
gpu_count: int = None,
):
"""Generate a string for a GraphQL mutation to create a new endpoint."""
input_fields = []
Expand Down Expand Up @@ -44,6 +46,12 @@ def generate_endpoint_mutation(
input_fields.append(f"workersMin: {workers_min}")
input_fields.append(f"workersMax: {workers_max}")

if allowed_cuda_versions is not None:
input_fields.append(f'allowedCudaVersions: "{allowed_cuda_versions}"')

if gpu_count is not None:
input_fields.append(f"gpuCount: {gpu_count}")

# Format the input fields into a string
input_fields_string = ", ".join(input_fields)

Expand All @@ -65,11 +73,14 @@ def generate_endpoint_mutation(
scalerValue
workersMin
workersMax
allowedCudaVersions
gpuCount
}}
}}
"""



def update_endpoint_template_mutation(endpoint_id: str, template_id: str):
"""Generate a string for a GraphQL mutation to update an existing endpoint's template."""
input_fields = []
Expand Down
Loading