-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #66 from actuarialopensource/compiled-model
Compiled model
- Loading branch information
Showing
17 changed files
with
165 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import argparse | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="Term ME model runner") | ||
parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points") | ||
# add an argument that must be either "torch" or "jax" | ||
parser.add_argument("--model", type=str, default="jax_iterative", choices=["torch_recursive", "jax_iterative"], help="Model to run") | ||
args = parser.parse_args() | ||
|
||
multiplier = args.multiplier | ||
|
||
if args.model == "torch_recursive": | ||
from term_me_recursive_pytorch import time_recursive_PyTorch # having both imports at top level gave a jax error? | ||
time_recursive_PyTorch(multiplier) | ||
elif args.model == "jax_iterative": | ||
from term_me_iterative_jax import time_iterative_jax | ||
time_iterative_jax(multiplier) | ||
else: | ||
raise ValueError("Invalid model") | ||
|
||
if __name__ == "__main__": | ||
main() |
File renamed without changes.
File renamed without changes.
127 changes: 127 additions & 0 deletions
127
containers/BasicTerm_ME_python/term_me_iterative_jax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import jax | ||
import pandas as pd | ||
import numpy as np | ||
import timeit | ||
import jax.numpy as jnp | ||
import equinox as eqx | ||
jax.config.update("jax_enable_x64", True) | ||
|
||
disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0) | ||
mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0) | ||
model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0) | ||
premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1]) | ||
|
||
class ModelPointsEqx(eqx.Module): | ||
premium_pp: jnp.ndarray | ||
duration_mth: jnp.ndarray | ||
age_at_entry: jnp.ndarray | ||
sum_assured: jnp.ndarray | ||
policy_count: jnp.ndarray | ||
policy_term: jnp.ndarray | ||
max_proj_len: jnp.ndarray | ||
|
||
def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1): | ||
table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True) | ||
table.sort_values(by="policy_id", inplace=True) | ||
self.premium_pp = jnp.round(jnp.array(np.tile(table["sum_assured"].to_numpy() * table["premium_rate"].to_numpy(), size_multiplier)),decimals=2) | ||
self.duration_mth = jnp.array(jnp.tile(table["duration_mth"].to_numpy(), size_multiplier)) | ||
self.age_at_entry = jnp.array(jnp.tile(table["age_at_entry"].to_numpy(), size_multiplier)) | ||
self.sum_assured = jnp.array(jnp.tile(table["sum_assured"].to_numpy(), size_multiplier)) | ||
self.policy_count = jnp.array(jnp.tile(table["policy_count"].to_numpy(), size_multiplier)) | ||
self.policy_term = jnp.array(jnp.tile(table["policy_term"].to_numpy(), size_multiplier)) | ||
self.max_proj_len = jnp.max(12 * self.policy_term - self.duration_mth) + 1 | ||
|
||
class AssumptionsEqx(eqx.Module): | ||
disc_rate_ann: jnp.ndarray | ||
mort_table: jnp.ndarray | ||
expense_acq: jnp.ndarray | ||
expense_maint: jnp.ndarray | ||
|
||
def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame): | ||
self.disc_rate_ann = jnp.array(disc_rate_ann["zero_spot"].to_numpy()) | ||
self.mort_table = jnp.array(mort_table.to_numpy()) | ||
self.expense_acq = jnp.array(300) | ||
self.expense_maint = jnp.array(60) | ||
|
||
class LoopState(eqx.Module): | ||
t: jnp.ndarray | ||
tot: jnp.ndarray | ||
pols_lapse_prev: jnp.ndarray | ||
pols_death_prev: jnp.ndarray | ||
pols_if_at_BEF_DECR_prev: jnp.ndarray | ||
|
||
class TermME(eqx.Module): | ||
mp: ModelPointsEqx | ||
assume: AssumptionsEqx | ||
init_ls: LoopState | ||
|
||
def __init__(self, mp: ModelPointsEqx, assume: AssumptionsEqx): | ||
self.mp = mp | ||
self.assume = assume | ||
self.init_ls = LoopState( | ||
t=jnp.array(0), | ||
tot = jnp.array(0), | ||
pols_lapse_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64), | ||
pols_death_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64), | ||
pols_if_at_BEF_DECR_prev=jnp.where(self.mp.duration_mth > 0, self.mp.policy_count, 0.) | ||
) | ||
|
||
def __call__(self): | ||
def iterative_core(ls: LoopState, _): | ||
duration_month_t = self.mp.duration_mth + ls.t | ||
duration_t = duration_month_t // 12 | ||
age_t = self.mp.age_at_entry + duration_t | ||
pols_if_init = ls.pols_if_at_BEF_DECR_prev - ls.pols_lapse_prev - ls.pols_death_prev | ||
pols_if_at_BEF_MAT = pols_if_init | ||
pols_maturity = (duration_month_t == self.mp.policy_term * 12) * pols_if_at_BEF_MAT | ||
pols_if_at_BEF_NB = pols_if_at_BEF_MAT - pols_maturity | ||
pols_new_biz = jnp.where(duration_month_t == 0, self.mp.policy_count, 0) | ||
pols_if_at_BEF_DECR = pols_if_at_BEF_NB + pols_new_biz | ||
mort_rate = self.assume.mort_table[age_t-18, jnp.clip(duration_t, a_max=5)] | ||
mort_rate_mth = 1 - (1 - mort_rate) ** (1/12) | ||
pols_death = pols_if_at_BEF_DECR * mort_rate_mth | ||
claims = self.mp.sum_assured * pols_death | ||
premiums = self.mp.premium_pp * pols_if_at_BEF_DECR | ||
commissions = (duration_t == 0) * premiums | ||
discount = (1 + self.assume.disc_rate_ann[ls.t//12]) ** (-ls.t/12) | ||
inflation_factor = (1 + 0.01) ** (ls.t/12) | ||
expenses = self.assume.expense_acq * pols_new_biz + pols_if_at_BEF_DECR * self.assume.expense_maint/12 * inflation_factor | ||
lapse_rate = jnp.clip(0.1 - 0.02 * duration_t, a_min=0.02) | ||
net_cf = premiums - claims - expenses - commissions | ||
discounted_net_cf = jnp.sum(net_cf) * discount | ||
nxt_ls = LoopState( | ||
t=ls.t+1, | ||
tot = ls.tot + discounted_net_cf, | ||
pols_lapse_prev=(pols_if_at_BEF_DECR - pols_death) * (1 - (1 - lapse_rate) ** (1/12)), | ||
pols_death_prev=pols_death, | ||
pols_if_at_BEF_DECR_prev=pols_if_at_BEF_DECR | ||
) | ||
return nxt_ls, None | ||
return jax.lax.scan(iterative_core, self.init_ls, xs=None, length=277)[0].tot | ||
|
||
|
||
def run_jax_term_ME(term_me: TermME): | ||
return term_me() | ||
|
||
run_jax_term_ME_opt = jax.jit(run_jax_term_ME) | ||
|
||
def time_jax_func(mp, assume, func): | ||
term_me = TermME(mp, assume) | ||
result = func(term_me).block_until_ready() | ||
start = timeit.default_timer() | ||
result = func(term_me).block_until_ready() | ||
end = timeit.default_timer() | ||
elapsed_time = end - start # Time in seconds | ||
return float(result), elapsed_time | ||
|
||
def time_iterative_jax(multiplier: int): | ||
mp = ModelPointsEqx(model_point_table, premium_table, size_multiplier=multiplier) | ||
assume = AssumptionsEqx(disc_rate_ann, mort_table) | ||
result, time_in_seconds = time_jax_func(mp, assume, run_jax_term_ME_opt) | ||
print("JAX iterative model") | ||
print(f"number modelpoints={len(mp.duration_mth):,}") | ||
print(f"{result=:,}") | ||
print(f"{time_in_seconds=}") | ||
|
||
if __name__ == "__main__": | ||
time_iterative_jax(100) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters