Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
643 changes: 315 additions & 328 deletions .flox/env/manifest.lock

Large diffs are not rendered by default.

14 changes: 8 additions & 6 deletions applications/equitypricemodel/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM python:3.12.10-slim AS builder
FROM python:3.12.10-slim AS builder

COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /bin/uv

WORKDIR /app

Expand All @@ -10,7 +10,7 @@ COPY applications/equitypricemodel/ applications/equitypricemodel/

COPY libraries/python/ libraries/python/

RUN uv sync --no-dev
RUN uv sync --no-dev

FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS trainer

Expand All @@ -32,13 +32,15 @@ ENV CUDA=1

WORKDIR /app

COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /bin/uv

COPY --from=builder /app /app

ENV PYTHONPATH=/app/applications/equitypricemodel/src
ENV TRAINING_DATA_PATH=/app/training-data/filtered_tide_training_data.parquet
ENV MODEL_OUTPUT_PATH=/app/model-artifacts

ENTRYPOINT ["uv", "run", "--package", "equitypricemodel", "python", "applications/equitypricemodel/src/equitypricemodel/trainer.py"]
ENTRYPOINT ["uv", "run", "--package", "equitypricemodel", "python", "-m", "equitypricemodel.trainer"]

FROM python:3.12.10-slim AS server

Expand All @@ -50,7 +52,7 @@ ENV PYTHONPATH=/app/applications/equitypricemodel/src

WORKDIR /app

COPY --from=ghcr.io/astral-sh/uv:latest /uv /bin/uv
COPY --from=ghcr.io/astral-sh/uv:0.7.2 /uv /bin/uv

COPY --from=builder /app /app

Expand Down
2 changes: 1 addition & 1 deletion applications/equitypricemodel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
]

[dependency-groups]
dev = ["boto3-stubs[s3]>=1.38.0"]
dev = ["boto3-stubs[s3,ssm]>=1.38.0"]

[tool.uv.sources]
internal = { workspace = true }
Expand Down
64 changes: 50 additions & 14 deletions applications/equitypricemodel/src/equitypricemodel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
logger = structlog.get_logger()

DATAMANAGER_BASE_URL = os.getenv("FUND_DATAMANAGER_BASE_URL", "http://datamanager:8080")
MODEL_VERSION_SSM_PARAMETER = "/fund/equitypricemodel/model_version"


def find_latest_artifact_key(
Expand Down Expand Up @@ -156,10 +157,49 @@ def _safe_tar_filter(
temp_path.unlink(missing_ok=True)


def _resolve_artifact_key(
s3_client: "S3Client",
bucket: str,
artifact_path: str,
) -> str:
"""Resolve the model artifact S3 key using SSM Parameter Store."""
try:
ssm_client = boto3.client("ssm")
response = ssm_client.get_parameter(
Name=MODEL_VERSION_SSM_PARAMETER,
WithDecryption=True,
)
model_version = response["Parameter"]["Value"]
except ClientError:
logger.exception("SSM parameter not available, using default artifact path")
model_version = "latest"
Comment thread
chrisaddy marked this conversation as resolved.

if model_version != "latest":
logger.info("Using model version from SSM", model_version=model_version)
if model_version.endswith(".tar.gz"):
return model_version
return f"{artifact_path.rstrip('/')}/{model_version}/output/model.tar.gz"

if artifact_path.endswith(".tar.gz"):
return artifact_path

return find_latest_artifact_key(
s3_client=s3_client,
bucket=bucket,
prefix=artifact_path,
)


def cleanup_model_directory(model_directory: str) -> None:
if model_directory != "." and Path(model_directory).exists():
import shutil # noqa: PLC0415

shutil.rmtree(model_directory, ignore_errors=True)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Load model artifacts from S3 at startup."""
import shutil # noqa: PLC0415

bucket = os.environ.get("AWS_S3_MODEL_ARTIFACTS_BUCKET_NAME")
artifact_path = os.environ.get("AWS_S3_MODEL_ARTIFACT_PATH", "artifacts/")
Expand All @@ -172,14 +212,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
extract_path = Path(model_directory)

try:
if artifact_path.endswith(".tar.gz"):
artifact_key = artifact_path
else:
artifact_key = find_latest_artifact_key(
s3_client=s3_client,
bucket=bucket,
prefix=artifact_path,
)
artifact_key = _resolve_artifact_key(
s3_client=s3_client,
bucket=bucket,
artifact_path=artifact_path,
)

download_and_extract_artifacts(
s3_client=s3_client,
Expand All @@ -188,21 +225,20 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
extract_path=extract_path,
)
except Exception:
logger.exception("failed_to_download_artifacts")
logger.exception("Failed to download artifacts")
raise

logger.info("loading_model", directory=model_directory)
logger.info("Loading model", directory=model_directory)
else:
logger.info("loading_model_from_local", directory=model_directory)
logger.info("Loading model from local", directory=model_directory)

app.state.model_directory = model_directory
app.state.tide_model = Model.load(directory_path=model_directory)
logger.info("model_loaded_successfully")

yield

if app.state.model_directory != "." and Path(app.state.model_directory).exists():
shutil.rmtree(app.state.model_directory, ignore_errors=True)
cleanup_model_directory(app.state.model_directory)


application = FastAPI(lifespan=lifespan)
Expand Down
Loading
Loading