diff --git a/pyproject.toml b/pyproject.toml index a3fa94ea5c..b33ee6f921 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ requires-python = ">=3.8" dependencies = [ # intentionally loose. perhaps these should be vendored to not collide with user code? "attrs>=20.1,<24", + "cryptography>=46.0.3", "fastapi>=0.100,<0.119.0", "pydantic>=1.9,<3", "PyYAML", diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 6824cec5e2..26d95882ab 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -4,6 +4,9 @@ from .base_predictor import BasePredictor from .mimetypes_ext import install_mime_extensions +from .secret import ( + load_secret, +) from .server.scope import current_scope from .types import ( AsyncConcatenateIterator, @@ -34,5 +37,6 @@ "File", "Input", "Path", + "load_secret", "Secret", ] diff --git a/python/cog/secret.py b/python/cog/secret.py new file mode 100644 index 0000000000..e6bd91722d --- /dev/null +++ b/python/cog/secret.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import base64 +import os +from pathlib import Path + +import requests +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from dotenv import dotenv_values + +__all__ = [ + "load_secret", + "default_secret_provider", +] + + +def load_secret(name: str, secret_provider: SecretProvider | None) -> str: + if not secret_provider: + secret_provider = default_secret_provider + return secret_provider.get_secret(name) + + +class SecretProvider: + def __init__( + self, + cog_env_location: str = ".cog/.env", + cog_public_key_env_var: str = "COG_PUBLIC_KEY_LOCATION", + ) -> None: + self.env = {} + self.no_public_key = False + self.key = rsa.generate_private_key( + backend=default_backend(), + public_exponent=65537, + key_size=2048, + ) + self.secret_url: str | None = None + public_pem = self.key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + public_key_path_raw = os.getenv(cog_public_key_env_var) + if not public_key_path_raw: + self.no_public_key = True + return + public_key_path = Path(public_key_path_raw) + public_key_path.parent.mkdir(mode=0o700, exist_ok=True) + public_key_path.touch() + public_key_path.write_bytes(public_pem) + if not os.path.isfile(cog_env_location): + return + self.env = dotenv_values(cog_env_location) + + def get_secret(self, secret_name: str) -> str: + # Try to get the secret from the remote. Fall back to the local + # env file (local development only) + try: + if not self.secret_url: + raise ValueError("No secret URL passed") + if self.no_public_key: + raise ValueError("No public key for encryption") + raw_secret = os.getenv(secret_name) + if not raw_secret: + raise ValueError("No matching secret") + response = requests.post( + f"{self.secret_url}", + json={ + "value": raw_secret, + }, + ) + response.raise_for_status() + + plaintext_bytes = self.key.decrypt( + base64.b64decode(response.text), + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + + return plaintext_bytes.decode("utf-8") + except Exception: + return self.env.get(secret_name) or "" + + +default_secret_provider = SecretProvider() diff --git a/python/cog/server/http.py b/python/cog/server/http.py index 478a5e3a72..42642487a3 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -39,6 +39,7 @@ from ..json import upload_files from ..logging import setup_logging from ..mode import Mode +from ..secret import default_secret_provider from ..types import PYDANTIC_V2 try: @@ -126,12 +127,15 @@ def create_app( # pylint: disable=too-many-arguments,too-many-locals,too-many-s shutdown_event: Optional[threading.Event], # pylint: disable=redefined-outer-name app_threads: Optional[int] = None, upload_url: Optional[str] = None, + secrets_url: Optional[str] = None, mode: Mode = Mode.PREDICT, is_build: bool = False, await_explicit_shutdown: bool = False, # pylint: disable=redefined-outer-name ) -> MyFastAPI: started_at = datetime.now(tz=timezone.utc) + default_secret_provider.secret_url = secrets_url + @asynccontextmanager async def lifespan(app: MyFastAPI) -> AsyncGenerator[None, None]: # Startup code (was previously in @app.on_event("startup")) diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index f0ca578c53..de956c25ac 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -201,7 +201,10 @@ def result(self) -> T: class SetupTask(Task[SetupResult]): - def __init__(self, _clock: Optional[Callable[[], datetime]] = None) -> None: + def __init__( + self, + _clock: Optional[Callable[[], datetime]] = None, + ) -> None: log.info("starting setup") self._clock = _clock if self._clock is None: diff --git a/test-integration/test_integration/test_build.py b/test-integration/test_integration/test_build.py index e01fb6bb99..79fa336b8d 100644 --- a/test-integration/test_integration/test_build.py +++ b/test-integration/test_integration/test_build.py @@ -373,7 +373,7 @@ def test_pip_freeze(docker_image, cog_binary): ) assert ( pip_freeze - == "anyio==4.4.0\nattrs==23.2.0\ncertifi==2024.8.30\ncharset-normalizer==3.3.2\nclick==8.1.7\nexceptiongroup==1.2.2\nh11==0.14.0\nhttptools==0.6.1\nidna==3.8\npydantic==1.10.18\npython-dotenv==1.0.1\nPyYAML==6.0.2\nrequests==2.32.3\nsniffio==1.3.1\nstructlog==24.4.0\ntyping_extensions==4.12.2\nurllib3==2.2.2\nuvicorn==0.30.6\nuvloop==0.20.0\nwatchfiles==0.24.0\nwebsockets==13.0.1\n" + == "anyio==4.4.0\nattrs==23.2.0\ncertifi==2024.8.30\ncffi==2.0.0\ncharset-normalizer==3.3.2\nclick==8.1.7\ncryptography==46.0.3\nexceptiongroup==1.2.2\nh11==0.14.0\nhttptools==0.6.1\nidna==3.8\npycparser==2.23\npydantic==1.10.18\npython-dotenv==1.0.1\nPyYAML==6.0.2\nrequests==2.32.3\nsniffio==1.3.1\nstructlog==24.4.0\ntyping_extensions==4.15.0\nurllib3==2.2.2\nuvicorn==0.30.6\nuvloop==0.20.0\nwatchfiles==0.24.0\nwebsockets==13.0.1\n" )