Skip to content

Commit

Permalink
Merge pull request #66 from actuarialopensource/compiled-model
Browse files Browse the repository at this point in the history
Compiled model
  • Loading branch information
MatthewCaseres authored May 11, 2024
2 parents 25a42fb + f9b6510 commit 43a2a31
Show file tree
Hide file tree
Showing 17 changed files with 165 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
name: basicterm_me_pytorch
name: basicterm_me_python

on:
workflow_dispatch:
push:
paths:
- 'containers/BasicTerm_ME_heavylight/**'
- 'containers/BasicTerm_ME_python/**'

jobs:
docker:
Expand All @@ -28,6 +28,6 @@ jobs:
name: Build and push
uses: docker/build-push-action@v5
with:
context: "{{defaultContext}}:containers/BasicTerm_ME_heavylight"
context: "{{defaultContext}}:containers/BasicTerm_ME_python"
push: true
tags: actuarial/basicterm_me_pytorch:latest
tags: actuarial/basicterm_me_python:latest
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ We currently only have 1 benchmark, we are working to expand the benchmarks. Ope

Time measurement is currently the best of 3 runs.

| benchmark | recursive | container |A100-SXM4-40GB | H100-SXM5-80GB |
| benchmark | classification | container |A100-SXM4-40GB | H100-SXM5-80GB |
|---------------|-|-|----------------|----------------|
| BasicTerm_ME 100 Million | Yes | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_pytorch/general) | 15.8194s | 7.1820s |
| BasicTerm_ME 100 Million | recursive PyTorch | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_python/general) | 15.8284s | 7.205s |
| BasicTerm_ME 100 Million | compiled iterative JAX | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_python/general) | 3.448s | 1.551s |


### Notes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime
# Set the working directory in the container
WORKDIR /app

RUN python -m pip install \
RUN pip install --upgrade pip
RUN pip install --upgrade "jax[cuda12]"
RUN pip install \
pandas \
openpyxl \
equinox \
heavylight==1.0.6

# Copy the rest of the application
Expand Down
22 changes: 22 additions & 0 deletions containers/BasicTerm_ME_python/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import argparse

def main():
parser = argparse.ArgumentParser(description="Term ME model runner")
parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points")
# add an argument that must be either "torch" or "jax"
parser.add_argument("--model", type=str, default="jax_iterative", choices=["torch_recursive", "jax_iterative"], help="Model to run")
args = parser.parse_args()

multiplier = args.multiplier

if args.model == "torch_recursive":
from term_me_recursive_pytorch import time_recursive_PyTorch # having both imports at top level gave a jax error?
time_recursive_PyTorch(multiplier)
elif args.model == "jax_iterative":
from term_me_iterative_jax import time_iterative_jax
time_iterative_jax(multiplier)
else:
raise ValueError("Invalid model")

if __name__ == "__main__":
main()
File renamed without changes.
127 changes: 127 additions & 0 deletions containers/BasicTerm_ME_python/term_me_iterative_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import jax
import pandas as pd
import numpy as np
import timeit
import jax.numpy as jnp
import equinox as eqx
jax.config.update("jax_enable_x64", True)

disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0)
mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0)
model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0)
premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1])

class ModelPointsEqx(eqx.Module):
premium_pp: jnp.ndarray
duration_mth: jnp.ndarray
age_at_entry: jnp.ndarray
sum_assured: jnp.ndarray
policy_count: jnp.ndarray
policy_term: jnp.ndarray
max_proj_len: jnp.ndarray

def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1):
table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True)
table.sort_values(by="policy_id", inplace=True)
self.premium_pp = jnp.round(jnp.array(np.tile(table["sum_assured"].to_numpy() * table["premium_rate"].to_numpy(), size_multiplier)),decimals=2)
self.duration_mth = jnp.array(jnp.tile(table["duration_mth"].to_numpy(), size_multiplier))
self.age_at_entry = jnp.array(jnp.tile(table["age_at_entry"].to_numpy(), size_multiplier))
self.sum_assured = jnp.array(jnp.tile(table["sum_assured"].to_numpy(), size_multiplier))
self.policy_count = jnp.array(jnp.tile(table["policy_count"].to_numpy(), size_multiplier))
self.policy_term = jnp.array(jnp.tile(table["policy_term"].to_numpy(), size_multiplier))
self.max_proj_len = jnp.max(12 * self.policy_term - self.duration_mth) + 1

class AssumptionsEqx(eqx.Module):
disc_rate_ann: jnp.ndarray
mort_table: jnp.ndarray
expense_acq: jnp.ndarray
expense_maint: jnp.ndarray

def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame):
self.disc_rate_ann = jnp.array(disc_rate_ann["zero_spot"].to_numpy())
self.mort_table = jnp.array(mort_table.to_numpy())
self.expense_acq = jnp.array(300)
self.expense_maint = jnp.array(60)

class LoopState(eqx.Module):
t: jnp.ndarray
tot: jnp.ndarray
pols_lapse_prev: jnp.ndarray
pols_death_prev: jnp.ndarray
pols_if_at_BEF_DECR_prev: jnp.ndarray

class TermME(eqx.Module):
mp: ModelPointsEqx
assume: AssumptionsEqx
init_ls: LoopState

def __init__(self, mp: ModelPointsEqx, assume: AssumptionsEqx):
self.mp = mp
self.assume = assume
self.init_ls = LoopState(
t=jnp.array(0),
tot = jnp.array(0),
pols_lapse_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64),
pols_death_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64),
pols_if_at_BEF_DECR_prev=jnp.where(self.mp.duration_mth > 0, self.mp.policy_count, 0.)
)

def __call__(self):
def iterative_core(ls: LoopState, _):
duration_month_t = self.mp.duration_mth + ls.t
duration_t = duration_month_t // 12
age_t = self.mp.age_at_entry + duration_t
pols_if_init = ls.pols_if_at_BEF_DECR_prev - ls.pols_lapse_prev - ls.pols_death_prev
pols_if_at_BEF_MAT = pols_if_init
pols_maturity = (duration_month_t == self.mp.policy_term * 12) * pols_if_at_BEF_MAT
pols_if_at_BEF_NB = pols_if_at_BEF_MAT - pols_maturity
pols_new_biz = jnp.where(duration_month_t == 0, self.mp.policy_count, 0)
pols_if_at_BEF_DECR = pols_if_at_BEF_NB + pols_new_biz
mort_rate = self.assume.mort_table[age_t-18, jnp.clip(duration_t, a_max=5)]
mort_rate_mth = 1 - (1 - mort_rate) ** (1/12)
pols_death = pols_if_at_BEF_DECR * mort_rate_mth
claims = self.mp.sum_assured * pols_death
premiums = self.mp.premium_pp * pols_if_at_BEF_DECR
commissions = (duration_t == 0) * premiums
discount = (1 + self.assume.disc_rate_ann[ls.t//12]) ** (-ls.t/12)
inflation_factor = (1 + 0.01) ** (ls.t/12)
expenses = self.assume.expense_acq * pols_new_biz + pols_if_at_BEF_DECR * self.assume.expense_maint/12 * inflation_factor
lapse_rate = jnp.clip(0.1 - 0.02 * duration_t, a_min=0.02)
net_cf = premiums - claims - expenses - commissions
discounted_net_cf = jnp.sum(net_cf) * discount
nxt_ls = LoopState(
t=ls.t+1,
tot = ls.tot + discounted_net_cf,
pols_lapse_prev=(pols_if_at_BEF_DECR - pols_death) * (1 - (1 - lapse_rate) ** (1/12)),
pols_death_prev=pols_death,
pols_if_at_BEF_DECR_prev=pols_if_at_BEF_DECR
)
return nxt_ls, None
return jax.lax.scan(iterative_core, self.init_ls, xs=None, length=277)[0].tot


def run_jax_term_ME(term_me: TermME):
return term_me()

run_jax_term_ME_opt = jax.jit(run_jax_term_ME)

def time_jax_func(mp, assume, func):
term_me = TermME(mp, assume)
result = func(term_me).block_until_ready()
start = timeit.default_timer()
result = func(term_me).block_until_ready()
end = timeit.default_timer()
elapsed_time = end - start # Time in seconds
return float(result), elapsed_time

def time_iterative_jax(multiplier: int):
mp = ModelPointsEqx(model_point_table, premium_table, size_multiplier=multiplier)
assume = AssumptionsEqx(disc_rate_ann, mort_table)
result, time_in_seconds = time_jax_func(mp, assume, run_jax_term_ME_opt)
print("JAX iterative model")
print(f"number modelpoints={len(mp.duration_mth):,}")
print(f"{result=:,}")
print(f"{time_in_seconds=}")

if __name__ == "__main__":
time_iterative_jax(100)
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import torch
from heavylight import LightModel, agg
import timeit
import argparse

# ensure CUDA available
print(f"{torch.cuda.is_available()=}")
# set 64 bit precision
torch.set_default_dtype(torch.float64)
Expand Down Expand Up @@ -167,16 +165,8 @@ def time_recursive_CPU(model: TermME):
end = timeit.default_timer()
return result, end - start

def main():
parser = argparse.ArgumentParser(description="Term ME model runner")
parser.add_argument("--disable_cuda", action="store_true", help="Disable CUDA usage")
parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points")
args = parser.parse_args()

disable_cuda = args.disable_cuda
multiplier = args.multiplier

if not disable_cuda and torch.cuda.is_available():
def time_recursive_PyTorch(multiplier: int):
if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
Expand All @@ -196,9 +186,7 @@ def main():
model.mp = mp_multiplied
result, time_in_seconds = time_recursive(model)
# report results
print(f"number modelpoints={len(model_point_table) * multiplier:,}")
print("PyTorch recursive model")
print(f"number modelpoints={len(mp_multiplied.duration_mth):,}")
print(f"{result=:,}")
print(f"{time_in_seconds=}")

if __name__ == "__main__":
main()
print(f"{time_in_seconds=}")

0 comments on commit 43a2a31

Please sign in to comment.