Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion applications/equitypricemodel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
]

[dependency-groups]
dev = ["boto3-stubs[s3]>=1.38.0"]
dev = ["boto3-stubs[s3,ssm]>=1.38.0"]

[tool.uv.sources]
internal = { workspace = true }
Expand Down
48 changes: 46 additions & 2 deletions applications/equitypricemodel/src/equitypricemodel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
logger = structlog.get_logger()

DATAMANAGER_BASE_URL = os.getenv("FUND_DATAMANAGER_BASE_URL", "http://datamanager:8080")
MODEL_VERSION_SSM_PARAMETER = "/fund/equitypricemodel/model_version"


def find_latest_artifact_key(
Expand Down Expand Up @@ -156,6 +157,49 @@ def _safe_tar_filter(
temp_path.unlink(missing_ok=True)


def _resolve_artifact_key(
    s3_client: "S3Client",
    bucket: str,
    artifact_path: str,
) -> str:
    """Resolve the S3 key of the model tarball to download.

    Preference order:
      1. A version pinned in the SSM parameter ``MODEL_VERSION_SSM_PARAMETER``
         (ignored when the stored value is empty or the literal ``"latest"``).
      2. Fallback: the most recent artifact found under ``artifact_path``.

    Args:
        s3_client: Client used to list artifacts when no version is pinned.
        bucket: S3 bucket holding the model artifacts.
        artifact_path: Key prefix under which versioned artifacts live.

    Returns:
        Full S3 key of the ``model.tar.gz`` archive to download.
    """
    normalized_artifact_prefix = artifact_path.rstrip("/") + "/"
    model_version = ""
    try:
        ssm_client = boto3.client("ssm")
        response = ssm_client.get_parameter(Name=MODEL_VERSION_SSM_PARAMETER)
        # Strip slashes so a value like "/v3/" still yields a clean key segment.
        model_version = response["Parameter"]["Value"].strip().strip("/")
        if model_version and model_version != "latest":
            logger.info(
                "Resolved artifact key from pinned model version",
                version=model_version,
            )
        else:
            # "latest" (or an empty value) means: fall through to S3 discovery.
            model_version = ""
    except ClientError as error:
        error_code = error.response["Error"]["Code"]
        if error_code == "ParameterNotFound":
            # A missing parameter is expected in environments without pinning;
            # log at info level and fall back to the newest artifact.
            logger.info(
                "SSM parameter not found, falling back to latest model artifact",
                parameter=MODEL_VERSION_SSM_PARAMETER,
            )
        else:
            # Any other AWS error is logged (with traceback) but still falls
            # back to the latest artifact rather than failing startup.
            logger.exception(
                "SSM parameter read failed",
                parameter=MODEL_VERSION_SSM_PARAMETER,
                error_code=error_code,
            )

    if model_version:
        return f"{normalized_artifact_prefix}{model_version}/output/model.tar.gz"

    return find_latest_artifact_key(
        s3_client=s3_client,
        bucket=bucket,
        prefix=normalized_artifact_prefix,
    )


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Load model artifacts from S3 at startup."""
Expand All @@ -175,10 +219,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if artifact_path.endswith(".tar.gz"):
artifact_key = artifact_path
else:
artifact_key = find_latest_artifact_key(
artifact_key = _resolve_artifact_key(
s3_client=s3_client,
bucket=bucket,
prefix=artifact_path,
artifact_path=artifact_path,
)

download_and_extract_artifacts(
Expand Down
26 changes: 10 additions & 16 deletions applications/equitypricemodel/src/equitypricemodel/tide_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,9 @@ def _get_prediction_data(

def get_dimensions(self) -> dict[str, int]:
return {
"encoder_categorical_features": len(self.categorical_columns),
"encoder_continuous_features": len(self.continuous_columns),
"decoder_categorical_features": len(self.categorical_columns),
"decoder_continuous_features": 0, # not using decoder_continuous_features for now # noqa: E501
"past_continuous_features": len(self.continuous_columns),
"past_categorical_features": len(self.categorical_columns),
"static_categorical_features": len(self.static_categorical_columns),
"static_continuous_features": 0, # not using static_continuous_features for now # noqa: E501
}

def get_batches( # noqa: C901
Expand Down Expand Up @@ -417,10 +414,10 @@ def get_batches( # noqa: C901
# Use numpy slicing (much faster than DataFrame slicing)
for i in window_indices:
sample = {
"encoder_categorical": cat_array[i : i + input_length].copy(),
"encoder_continuous": cont_array[i : i + input_length].copy(),
"decoder_categorical": cat_array[
i + input_length : i + input_length + output_length
"past_continuous": cont_array[i : i + input_length].copy(),
# Calendar features are known for both lookback and forecast windows
"past_categorical": cat_array[
i : i + input_length + output_length
].copy(),
Comment thread
forstmeier marked this conversation as resolved.
Comment thread
forstmeier marked this conversation as resolved.
"static_categorical": static_array.copy(),
}
Expand All @@ -441,14 +438,11 @@ def get_batches( # noqa: C901
batch_samples = samples[i : i + batch_size]

batch = {
"encoder_categorical_features": Tensor(
np.stack([s["encoder_categorical"] for s in batch_samples])
),
"encoder_continuous_features": Tensor(
np.stack([s["encoder_continuous"] for s in batch_samples])
"past_continuous_features": Tensor(
np.stack([s["past_continuous"] for s in batch_samples])
),
"decoder_categorical_features": Tensor(
np.stack([s["decoder_categorical"] for s in batch_samples])
"past_categorical_features": Tensor(
np.stack([s["past_categorical"] for s in batch_samples])
),
"static_categorical_features": Tensor(
np.stack([s["static_categorical"] for s in batch_samples])
Expand Down
131 changes: 89 additions & 42 deletions applications/equitypricemodel/src/equitypricemodel/tide_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import random
from pathlib import Path
from typing import cast

import numpy as np
Expand Down Expand Up @@ -197,10 +198,11 @@ def forward(self, x: Tensor) -> Tensor:
return predictions_first.stack(*predictions_rest, dim=1)

def _validate_batch(self, batch: dict[str, Tensor], _batch_idx: int) -> dict:
"""Check a batch for NaN/Inf values and return statistics."""
"""Check a batch for NaN/Inf values, shape consistency, rank, and dtype."""
numpy_cache = {key: tensor.numpy() for key, tensor in batch.items()}

issues = {}
for key, tensor in batch.items():
data = tensor.numpy()
for key, data in numpy_cache.items():
nan_count = int(np.isnan(data).sum())
inf_count = int(np.isinf(data).sum())
if nan_count > 0 or inf_count > 0:
Expand All @@ -210,6 +212,34 @@ def _validate_batch(self, batch: dict[str, Tensor], _batch_idx: int) -> dict:
"total_elements": data.size,
"nan_pct": f"{(nan_count / data.size) * 100:.2f}%",
}

feature_keys = [k for k in batch if k != "targets"]

batch_sizes = {k: numpy_cache[k].shape[0] for k in feature_keys}
if len(set(batch_sizes.values())) > 1:
issues["shape_mismatch"] = {"batch_sizes": batch_sizes}

for key in feature_keys:
data = numpy_cache[key]
if data.ndim != 3: # noqa: PLR2004
issues.setdefault("rank_errors", {})[key] = {
"ndim": data.ndim,
"expected": 3,
}

for key in feature_keys:
data = numpy_cache[key]
if "continuous" in key and data.dtype != np.float32:
issues.setdefault("dtype_errors", {})[key] = {
"dtype": str(data.dtype),
"expected": "float32",
}
if "categorical" in key and not np.issubdtype(data.dtype, np.integer):
issues.setdefault("dtype_errors", {})[key] = {
"dtype": str(data.dtype),
"expected": "integer",
}

return issues

def validate_training_data(
Expand Down Expand Up @@ -261,6 +291,7 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901
validation_sample_size: int = 10,
early_stopping_patience: int | None = 3,
early_stopping_min_delta: float = 0.001,
checkpoint_directory: str | None = None,
Comment thread
forstmeier marked this conversation as resolved.
) -> list:
"""Train the TiDE model using quantile loss.

Expand All @@ -275,10 +306,13 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901
Set to False to skip validation if data quality is already guaranteed.
validation_sample_size: Number of batches to sample during validation.
Larger values provide more thorough validation but take longer.
The first and last batches are always checked. Default is 10.
Default is 10.
early_stopping_patience: Stop if no improvement for N epochs
(None to disable)
early_stopping_min_delta: Minimum improvement to reset patience counter
checkpoint_directory: Directory to save best-loss checkpoints during
training. If provided, the best checkpoint is automatically restored
after training completes. Defaults to None (no checkpointing).

Performance Notes:
- Data validation runs once before training starts
Expand All @@ -293,6 +327,10 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901
message = "validation_sample_size must be positive"
raise ValueError(message)

if not train_batches:
message = "train_batches must not be empty"
raise ValueError(message)

if validate_data:
is_valid = self.validate_training_data(
train_batches,
Expand All @@ -311,7 +349,9 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901
total_batches = len(train_batches)

best_loss = float("inf")
best_saved_loss = float("inf")
epochs_without_improvement = 0
best_checkpoint_saved = False

try:
for epoch in range(epochs):
Expand Down Expand Up @@ -376,33 +416,55 @@ def train( # noqa: PLR0913, PLR0912, PLR0915, C901

losses.append(epoch_loss)

if early_stopping_patience is not None:
if epoch_loss < best_loss - early_stopping_min_delta:
best_loss = epoch_loss
epochs_without_improvement = 0
logger.info(
"New best loss",
best_loss=f"{best_loss:.4f}",
)
else:
epochs_without_improvement += 1
if epoch_loss < best_loss - early_stopping_min_delta:
best_loss = epoch_loss
epochs_without_improvement = 0
logger.info(
"New best loss",
best_loss=f"{best_loss:.4f}",
)
else:
epochs_without_improvement += 1
if early_stopping_patience is not None:
Comment thread
forstmeier marked this conversation as resolved.
logger.info(
"No improvement",
epochs_without_improvement=epochs_without_improvement,
patience=early_stopping_patience,
)

if epochs_without_improvement >= early_stopping_patience:
logger.info(
"Early stopping triggered",
epoch=epoch + 1,
best_loss=f"{best_loss:.4f}",
epochs_without_improvement=epochs_without_improvement,
)
break
if checkpoint_directory is not None and epoch_loss < best_saved_loss:
best_saved_loss = epoch_loss
Path(checkpoint_directory).mkdir(parents=True, exist_ok=True)
safe_save(
get_state_dict(self),
str(Path(checkpoint_directory) / "tide_states.safetensor"),
)
best_checkpoint_saved = True

if (
early_stopping_patience is not None
and epochs_without_improvement >= early_stopping_patience
):
logger.info(
"Early stopping triggered",
epoch=epoch + 1,
best_loss=f"{best_loss:.4f}",
epochs_without_improvement=epochs_without_improvement,
)
break
finally:
Tensor.training = prev_training

if checkpoint_directory is not None and best_checkpoint_saved:
load_state_dict(
self,
safe_load(str(Path(checkpoint_directory) / "tide_states.safetensor")),
)
logger.info(
"Restored best checkpoint weights",
checkpoint_directory=checkpoint_directory,
)

return losses

def validate(self, validation_batches: list) -> float:
Expand Down Expand Up @@ -484,20 +546,11 @@ def _combine_input_features(
self,
inputs: dict[str, Tensor],
) -> tuple[Tensor, Tensor | None, int]:
batch_size = inputs["encoder_continuous_features"].shape[0]
batch_size = inputs["past_continuous_features"].shape[0]

encoder_cont_flat = inputs["encoder_continuous_features"].reshape(
batch_size, -1
)
encoder_cat_flat = (
inputs["encoder_categorical_features"]
.reshape(batch_size, -1)
.cast("float32")
)
decoder_cat_flat = (
inputs["decoder_categorical_features"]
.reshape(batch_size, -1)
.cast("float32")
past_cont_flat = inputs["past_continuous_features"].reshape(batch_size, -1)
past_cat_flat = (
inputs["past_categorical_features"].reshape(batch_size, -1).cast("float32")
)
static_cat_flat = (
inputs["static_categorical_features"]
Expand All @@ -506,13 +559,7 @@ def _combine_input_features(
)

return (
Tensor.cat(
encoder_cont_flat,
encoder_cat_flat,
decoder_cat_flat,
static_cat_flat,
dim=1,
),
Tensor.cat(past_cont_flat, past_cat_flat, static_cat_flat, dim=1),
inputs.get("targets"),
int(batch_size),
)
21 changes: 8 additions & 13 deletions applications/equitypricemodel/src/equitypricemodel/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import structlog
from equitypricemodel.tide_data import Data
from equitypricemodel.tide_model import Model
from tinygrad import Device
from tinygrad.device import Device

# Configure structlog for CloudWatch-friendly output
structlog.configure(
Expand Down Expand Up @@ -102,29 +102,23 @@

sample_batch = train_batches[0]

batch_size = sample_batch["encoder_continuous_features"].shape[0]
batch_size = sample_batch["past_continuous_features"].shape[0]
logger.info("batch_size_determined", batch_size=batch_size)

# calculate each component's flattened size - days * features (e.g. 35 * 7)
encoder_continuous_size = (
sample_batch["encoder_continuous_features"].reshape(batch_size, -1).shape[1]
past_continuous_size = (
sample_batch["past_continuous_features"].reshape(batch_size, -1).shape[1]
)
encoder_categorical_size = (
sample_batch["encoder_categorical_features"].reshape(batch_size, -1).shape[1]
)
decoder_categorical_size = (
sample_batch["decoder_categorical_features"].reshape(batch_size, -1).shape[1]
past_categorical_size = (
sample_batch["past_categorical_features"].reshape(batch_size, -1).shape[1]
)
static_categorical_size = (
sample_batch["static_categorical_features"].reshape(batch_size, -1).shape[1]
)

input_size = cast(
"int",
encoder_continuous_size
+ encoder_categorical_size
+ decoder_categorical_size
+ static_categorical_size,
past_continuous_size + past_categorical_size + static_categorical_size,
)

logger.info("input_size_calculated", input_size=input_size)
Expand All @@ -146,6 +140,7 @@
train_batches=train_batches,
epochs=int(configuration["epoch_count"]),
learning_rate=float(configuration["learning_rate"]),
checkpoint_directory=model_output_path,
)

logger.info(
Expand Down
Loading
Loading