From 6cfa8b493a0385d8cdb25dd47902252ae79a8770 Mon Sep 17 00:00:00 2001 From: Matthew Caseres Date: Sat, 11 May 2024 12:49:50 +0000 Subject: [PATCH 1/4] container no longer PyTorch specific, rename --- ...cTerm_ME_PyTorch.yml => BasicTerm_ME_python.yml} | 8 ++++---- .../.devcontainer/devcontainer.json | 0 .../BasicTerm_ME/Projection/__init__.py | 0 .../BasicTerm_ME/__init__.py | 0 .../BasicTerm_ME/_data/data.pickle | Bin .../BasicTerm_ME/_system.json | 0 .../BasicTerm_ME/disc_rate_ann.xlsx | Bin .../BasicTerm_ME/model_point_table.xlsx | Bin .../BasicTerm_ME/mort_table.xlsx | Bin .../BasicTerm_ME/premium_table.xlsx | Bin .../Dockerfile | 1 + .../main.py | 0 .../notes.md | 0 .../requirements.txt | 0 14 files changed, 5 insertions(+), 4 deletions(-) rename .github/workflows/{BasicTerm_ME_PyTorch.yml => BasicTerm_ME_python.yml} (77%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/.devcontainer/devcontainer.json (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/Projection/__init__.py (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/__init__.py (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/_data/data.pickle (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/_system.json (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/disc_rate_ann.xlsx (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/model_point_table.xlsx (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/mort_table.xlsx (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/BasicTerm_ME/premium_table.xlsx (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/Dockerfile (98%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/main.py (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/notes.md (100%) rename containers/{BasicTerm_ME_heavylight => BasicTerm_ME_python}/requirements.txt (100%) diff --git a/.github/workflows/BasicTerm_ME_PyTorch.yml b/.github/workflows/BasicTerm_ME_python.yml similarity index 77% rename from .github/workflows/BasicTerm_ME_PyTorch.yml rename to .github/workflows/BasicTerm_ME_python.yml index 9ae2bfe..3a2cf49 100644 --- a/.github/workflows/BasicTerm_ME_PyTorch.yml +++ b/.github/workflows/BasicTerm_ME_python.yml @@ -1,10 +1,10 @@ -name: basicterm_me_pytorch +name: basicterm_me_python on: workflow_dispatch: push: paths: - - 'containers/BasicTerm_ME_heavylight/**' + - 'containers/BasicTerm_ME_python/**' jobs: docker: @@ -28,6 +28,6 @@ jobs: name: Build and push uses: docker/build-push-action@v5 with: - context: "{{defaultContext}}:containers/BasicTerm_ME_heavylight" + context: "{{defaultContext}}:containers/BasicTerm_ME_python" push: true - tags: actuarial/basicterm_me_pytorch:latest \ No newline at end of file + tags: actuarial/basicterm_me_python:latest \ No newline at end of file diff --git a/containers/BasicTerm_ME_heavylight/.devcontainer/devcontainer.json b/containers/BasicTerm_ME_python/.devcontainer/devcontainer.json similarity index 100% rename from containers/BasicTerm_ME_heavylight/.devcontainer/devcontainer.json rename to containers/BasicTerm_ME_python/.devcontainer/devcontainer.json diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/Projection/__init__.py b/containers/BasicTerm_ME_python/BasicTerm_ME/Projection/__init__.py similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/Projection/__init__.py rename to containers/BasicTerm_ME_python/BasicTerm_ME/Projection/__init__.py diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/__init__.py b/containers/BasicTerm_ME_python/BasicTerm_ME/__init__.py similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/__init__.py rename to containers/BasicTerm_ME_python/BasicTerm_ME/__init__.py diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/_data/data.pickle b/containers/BasicTerm_ME_python/BasicTerm_ME/_data/data.pickle similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/_data/data.pickle rename to containers/BasicTerm_ME_python/BasicTerm_ME/_data/data.pickle diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/_system.json b/containers/BasicTerm_ME_python/BasicTerm_ME/_system.json similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/_system.json rename to containers/BasicTerm_ME_python/BasicTerm_ME/_system.json diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/disc_rate_ann.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/disc_rate_ann.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/disc_rate_ann.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/disc_rate_ann.xlsx diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/model_point_table.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/model_point_table.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/model_point_table.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/model_point_table.xlsx diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/mort_table.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/mort_table.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/mort_table.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/mort_table.xlsx diff --git a/containers/BasicTerm_ME_heavylight/BasicTerm_ME/premium_table.xlsx b/containers/BasicTerm_ME_python/BasicTerm_ME/premium_table.xlsx similarity index 100% rename from containers/BasicTerm_ME_heavylight/BasicTerm_ME/premium_table.xlsx rename to containers/BasicTerm_ME_python/BasicTerm_ME/premium_table.xlsx diff --git a/containers/BasicTerm_ME_heavylight/Dockerfile b/containers/BasicTerm_ME_python/Dockerfile similarity index 98% rename from containers/BasicTerm_ME_heavylight/Dockerfile rename to containers/BasicTerm_ME_python/Dockerfile index 3f7836d..7a1c5b5 100644 --- a/containers/BasicTerm_ME_heavylight/Dockerfile +++ b/containers/BasicTerm_ME_python/Dockerfile @@ -7,6 +7,7 @@ WORKDIR /app RUN python -m pip install \ pandas \ openpyxl \ + jax \ heavylight==1.0.6 # Copy the rest of the application diff --git a/containers/BasicTerm_ME_heavylight/main.py b/containers/BasicTerm_ME_python/main.py similarity index 100% rename from containers/BasicTerm_ME_heavylight/main.py rename to containers/BasicTerm_ME_python/main.py diff --git a/containers/BasicTerm_ME_heavylight/notes.md b/containers/BasicTerm_ME_python/notes.md similarity index 100% rename from containers/BasicTerm_ME_heavylight/notes.md rename to containers/BasicTerm_ME_python/notes.md diff --git a/containers/BasicTerm_ME_heavylight/requirements.txt b/containers/BasicTerm_ME_python/requirements.txt similarity index 100% rename from containers/BasicTerm_ME_heavylight/requirements.txt rename to containers/BasicTerm_ME_python/requirements.txt From 6f46efdb1098910f014df751368144f04534c662 Mon Sep 17 00:00:00 2001 From: m Date: Sat, 11 May 2024 14:36:00 +0000 Subject: [PATCH 2/4] setup benchmarking code --- containers/BasicTerm_ME_python/Dockerfile | 4 +- containers/BasicTerm_ME_python/main.py | 200 +----------------- .../term_me_iterative_jax.py | 127 +++++++++++ .../term_me_recursive_pytorch.py | 192 +++++++++++++++++ 4 files changed, 331 insertions(+), 192 deletions(-) create mode 100644 containers/BasicTerm_ME_python/term_me_iterative_jax.py create mode 100644 containers/BasicTerm_ME_python/term_me_recursive_pytorch.py diff --git a/containers/BasicTerm_ME_python/Dockerfile b/containers/BasicTerm_ME_python/Dockerfile index 7a1c5b5..9f50551 100644 --- a/containers/BasicTerm_ME_python/Dockerfile +++ b/containers/BasicTerm_ME_python/Dockerfile @@ -4,10 +4,12 @@ FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime # Set the working directory in the container WORKDIR /app +RUN python -m pip install --upgrade pip +RUN python -m pip install --upgrade "jax[cuda12]" RUN python -m pip install \ pandas \ openpyxl \ - jax \ + equinox \ heavylight==1.0.6 # Copy the rest of the application diff --git a/containers/BasicTerm_ME_python/main.py b/containers/BasicTerm_ME_python/main.py index cc15963..ba13c04 100644 --- a/containers/BasicTerm_ME_python/main.py +++ b/containers/BasicTerm_ME_python/main.py @@ -1,204 +1,22 @@ -import torch -from collections import defaultdict -import pandas as pd -import numpy as np -import torch -from heavylight import LightModel, agg -import timeit import argparse -# ensure CUDA available -print(f"{torch.cuda.is_available()=}") -# set 64 bit precision -torch.set_default_dtype(torch.float64) -print(f"{torch.get_default_dtype()=}") - -disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0) -mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0) -model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0) -premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1]) - -class ModelPoints: - def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1): - self.table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True) - self.table.sort_values(by="policy_id", inplace=True) - self.premium_pp = torch.round(torch.tensor(np.tile(self.table["sum_assured"].to_numpy() * self.table["premium_rate"].to_numpy(), size_multiplier)),decimals=2) - self.duration_mth = torch.tensor(np.tile(self.table["duration_mth"].to_numpy(), size_multiplier)) - self.age_at_entry = torch.tensor(np.tile(self.table["age_at_entry"].to_numpy(), size_multiplier)) - self.sum_assured = torch.tensor(np.tile(self.table["sum_assured"].to_numpy(), size_multiplier)) - self.policy_count = torch.tensor(np.tile(self.table["policy_count"].to_numpy(), size_multiplier)) - self.policy_term = torch.tensor(np.tile(self.table["policy_term"].to_numpy(), size_multiplier)) - self.max_proj_len: int = int(torch.max(12 * self.policy_term - self.duration_mth) + 1) - -class Assumptions: - def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame): - self.disc_rate_ann = torch.tensor(disc_rate_ann["zero_spot"].to_numpy()) - self.mort_table = torch.tensor(mort_table.to_numpy()) - - def get_mortality(self, age, duration): - return self.mort_table[age-18, torch.clamp(duration, max=5)] - -agg_func = lambda x: float(torch.sum(x)) - -class TermME(LightModel): - def __init__(self, mp: ModelPoints, assume: Assumptions): - super().__init__(agg_function=None) - self.mp = mp - self.assume = assume - - def age(self, t): - return self.mp.age_at_entry + self.duration(t) - - def claim_pp(self, t): - return self.mp.sum_assured - - def claims(self, t): - return self.claim_pp(t) * self.pols_death(t) - - def commissions(self, t): - return (self.duration(t) == 0) * self.premiums(t) - - def disc_factors(self): - return torch.tensor(list((1 + self.disc_rate_mth()[t])**(-t) for t in range(self.mp.max_proj_len))) - - def discount(self, t: int): - return (1 + self.assume.disc_rate_ann[t//12]) ** (-t/12) - - def disc_rate_mth(self): - return torch.tensor(list((1 + self.assume.disc_rate_ann[t//12])**(1/12) - 1 for t in range(self.mp.max_proj_len))) - - def duration(self, t): - return self.duration_mth(t) // 12 - - def duration_mth(self, t): - if t == 0: - return self.mp.duration_mth - else: - return self.duration_mth(t-1) + 1 - - def expense_acq(self): - return 300 - - def expense_maint(self): - return 60 - - def expenses(self, t): - return self.expense_acq() * self.pols_new_biz(t) \ - + self.pols_if_at(t, "BEF_DECR") * self.expense_maint()/12 * self.inflation_factor(t) - - def inflation_factor(self, t): - return (1 + self.inflation_rate())**(t/12) - - def inflation_rate(self): - return 0.01 - - def lapse_rate(self, t): - return torch.clamp(0.1 - 0.02 * self.duration(t), min=0.02) - - def loading_prem(self): - return 0.5 - - def mort_rate(self, t): - return self.assume.get_mortality(self.age(t), self.duration(t)) - - def mort_rate_mth(self, t): - return 1-(1- self.mort_rate(t))**(1/12) - - def net_cf(self, t): - return self.premiums(t) - self.claims(t) - self.expenses(t) - self.commissions(t) - - def pols_death(self, t): - return self.pols_if_at(t, "BEF_DECR") * self.mort_rate_mth(t) - - @agg(agg_func) - def discounted_net_cf(self, t): - return torch.sum(self.net_cf(t)) * self.discount(t) - - def pols_if_at(self, t, timing): - if timing == "BEF_MAT": - if t == 0: - return self.pols_if_init() - else: - return self.pols_if_at(t-1, "BEF_DECR") - self.pols_lapse(t-1) - self.pols_death(t-1) - elif timing == "BEF_NB": - return self.pols_if_at(t, "BEF_MAT") - self.pols_maturity(t) - elif timing == "BEF_DECR": - return self.pols_if_at(t, "BEF_NB") + self.pols_new_biz(t) - else: - raise ValueError("invalid timing") - - def pols_if_init(self): - return torch.where(self.duration_mth(0) > 0, self.mp.policy_count, 0) - - def pols_lapse(self, t): - return (self.pols_if_at(t, "BEF_DECR") - self.pols_death(t)) * (1-(1 - self.lapse_rate(t))**(1/12)) - - def pols_maturity(self, t): - return (self.duration_mth(t) == self.mp.policy_term * 12) * self.pols_if_at(t, "BEF_MAT") - - def pols_new_biz(self, t): - return torch.where(self.duration_mth(t) == 0, self.mp.policy_count, 0) - - def premiums(self, t): - return self.mp.premium_pp * self.pols_if_at(t, "BEF_DECR") - - -def run_recursive_model(model: TermME): - model.cache_graph._caches = defaultdict(dict) - model.cache_graph._caches_agg = defaultdict(dict) - model.RunModel(model.mp.max_proj_len) - return float(sum(model.cache_agg['discounted_net_cf'].values())) - - -def time_recursive_GPU(model: TermME): - model.OptimizeMemoryAndReset() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - result = run_recursive_model(model) - end.record() - torch.cuda.synchronize() - return result, start.elapsed_time(end) / 1000 - -def time_recursive_CPU(model: TermME): - model.OptimizeMemoryAndReset() - start = timeit.default_timer() - result = run_recursive_model(model) - end = timeit.default_timer() - return result, end - start - def main(): parser = argparse.ArgumentParser(description="Term ME model runner") - parser.add_argument("--disable_cuda", action="store_true", help="Disable CUDA usage") parser.add_argument("--multiplier", type=int, default=100, help="Multiplier for model points") + # add an argument that must be either "torch" or "jax" + parser.add_argument("--model", type=str, default="jax_iterative", choices=["torch_recursive", "jax_iterative"], help="Model to run") args = parser.parse_args() - disable_cuda = args.disable_cuda multiplier = args.multiplier - if not disable_cuda and torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') - print(f"{device=}") - - with device: - mp = ModelPoints(model_point_table, premium_table) - mp_multiplied = ModelPoints(model_point_table, premium_table, multiplier) - assume = Assumptions(disc_rate_ann, mort_table) - model = TermME(mp, assume) - - if device.type == 'cuda': - time_recursive = time_recursive_GPU + if args.model == "torch_recursive": + from term_me_recursive_pytorch import time_recursive_PyTorch # having both imports at top level gave a jax error? + time_recursive_PyTorch(multiplier) + elif args.model == "jax_iterative": + from term_me_iterative_jax import time_iterative_jax + time_iterative_jax(multiplier) else: - time_recursive = time_recursive_CPU - run_recursive_model(model) # warm up, generate dependency graph - model.mp = mp_multiplied - result, time_in_seconds = time_recursive(model) - # report results - print(f"number modelpoints={len(model_point_table) * multiplier:,}") - print(f"{result=:,}") - print(f"{time_in_seconds=}") + raise ValueError("Invalid model") if __name__ == "__main__": main() \ No newline at end of file diff --git a/containers/BasicTerm_ME_python/term_me_iterative_jax.py b/containers/BasicTerm_ME_python/term_me_iterative_jax.py new file mode 100644 index 0000000..20b9dbc --- /dev/null +++ b/containers/BasicTerm_ME_python/term_me_iterative_jax.py @@ -0,0 +1,127 @@ +import jax +import pandas as pd +import numpy as np +import timeit +import jax.numpy as jnp +import equinox as eqx +jax.config.update("jax_enable_x64", True) + +disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0) +mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0) +model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0) +premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1]) + +class ModelPointsEqx(eqx.Module): + premium_pp: jnp.ndarray + duration_mth: jnp.ndarray + age_at_entry: jnp.ndarray + sum_assured: jnp.ndarray + policy_count: jnp.ndarray + policy_term: jnp.ndarray + max_proj_len: jnp.ndarray + + def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1): + table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True) + table.sort_values(by="policy_id", inplace=True) + self.premium_pp = jnp.round(jnp.array(np.tile(table["sum_assured"].to_numpy() * table["premium_rate"].to_numpy(), size_multiplier)),decimals=2) + self.duration_mth = jnp.array(jnp.tile(table["duration_mth"].to_numpy(), size_multiplier)) + self.age_at_entry = jnp.array(jnp.tile(table["age_at_entry"].to_numpy(), size_multiplier)) + self.sum_assured = jnp.array(jnp.tile(table["sum_assured"].to_numpy(), size_multiplier)) + self.policy_count = jnp.array(jnp.tile(table["policy_count"].to_numpy(), size_multiplier)) + self.policy_term = jnp.array(jnp.tile(table["policy_term"].to_numpy(), size_multiplier)) + self.max_proj_len = jnp.max(12 * self.policy_term - self.duration_mth) + 1 + +class AssumptionsEqx(eqx.Module): + disc_rate_ann: jnp.ndarray + mort_table: jnp.ndarray + expense_acq: jnp.ndarray + expense_maint: jnp.ndarray + + def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame): + self.disc_rate_ann = jnp.array(disc_rate_ann["zero_spot"].to_numpy()) + self.mort_table = jnp.array(mort_table.to_numpy()) + self.expense_acq = jnp.array(300) + self.expense_maint = jnp.array(60) + +class LoopState(eqx.Module): + t: jnp.ndarray + tot: jnp.ndarray + pols_lapse_prev: jnp.ndarray + pols_death_prev: jnp.ndarray + pols_if_at_BEF_DECR_prev: jnp.ndarray + +class TermME(eqx.Module): + mp: ModelPointsEqx + assume: AssumptionsEqx + init_ls: LoopState + + def __init__(self, mp: ModelPointsEqx, assume: AssumptionsEqx): + self.mp = mp + self.assume = assume + self.init_ls = LoopState( + t=jnp.array(0), + tot = jnp.array(0), + pols_lapse_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64), + pols_death_prev=jnp.zeros_like(self.mp.duration_mth, dtype=jnp.float64), + pols_if_at_BEF_DECR_prev=jnp.where(self.mp.duration_mth > 0, self.mp.policy_count, 0.) + ) + + def __call__(self): + def iterative_core(ls: LoopState, _): + duration_month_t = self.mp.duration_mth + ls.t + duration_t = duration_month_t // 12 + age_t = self.mp.age_at_entry + duration_t + pols_if_init = ls.pols_if_at_BEF_DECR_prev - ls.pols_lapse_prev - ls.pols_death_prev + pols_if_at_BEF_MAT = pols_if_init + pols_maturity = (duration_month_t == self.mp.policy_term * 12) * pols_if_at_BEF_MAT + pols_if_at_BEF_NB = pols_if_at_BEF_MAT - pols_maturity + pols_new_biz = jnp.where(duration_month_t == 0, self.mp.policy_count, 0) + pols_if_at_BEF_DECR = pols_if_at_BEF_NB + pols_new_biz + mort_rate = self.assume.mort_table[age_t-18, jnp.clip(duration_t, a_max=5)] + mort_rate_mth = 1 - (1 - mort_rate) ** (1/12) + pols_death = pols_if_at_BEF_DECR * mort_rate_mth + claims = self.mp.sum_assured * pols_death + premiums = self.mp.premium_pp * pols_if_at_BEF_DECR + commissions = (duration_t == 0) * premiums + discount = (1 + self.assume.disc_rate_ann[ls.t//12]) ** (-ls.t/12) + inflation_factor = (1 + 0.01) ** (ls.t/12) + expenses = self.assume.expense_acq * pols_new_biz + pols_if_at_BEF_DECR * self.assume.expense_maint/12 * inflation_factor + lapse_rate = jnp.clip(0.1 - 0.02 * duration_t, a_min=0.02) + net_cf = premiums - claims - expenses - commissions + discounted_net_cf = jnp.sum(net_cf) * discount + nxt_ls = LoopState( + t=ls.t+1, + tot = ls.tot + discounted_net_cf, + pols_lapse_prev=(pols_if_at_BEF_DECR - pols_death) * (1 - (1 - lapse_rate) ** (1/12)), + pols_death_prev=pols_death, + pols_if_at_BEF_DECR_prev=pols_if_at_BEF_DECR + ) + return nxt_ls, None + return jax.lax.scan(iterative_core, self.init_ls, xs=None, length=277)[0].tot + + +def run_jax_term_ME(term_me: TermME): + return term_me() + +run_jax_term_ME_opt = jax.jit(run_jax_term_ME) + +def time_jax_func(mp, assume, func): + term_me = TermME(mp, assume) + result = func(term_me).block_until_ready() + start = timeit.default_timer() + result = func(term_me).block_until_ready() + end = timeit.default_timer() + elapsed_time = end - start # Time in seconds + return float(result), elapsed_time + +def time_iterative_jax(multiplier: int): + mp = ModelPointsEqx(model_point_table, premium_table, size_multiplier=multiplier) + assume = AssumptionsEqx(disc_rate_ann, mort_table) + result, time_in_seconds = time_jax_func(mp, assume, run_jax_term_ME_opt) + print("JAX iterative model") + print(f"number modelpoints={len(mp.duration_mth):,}") + print(f"{result=:,}") + print(f"{time_in_seconds=}") + +if __name__ == "__main__": + time_iterative_jax(100) \ No newline at end of file diff --git a/containers/BasicTerm_ME_python/term_me_recursive_pytorch.py b/containers/BasicTerm_ME_python/term_me_recursive_pytorch.py new file mode 100644 index 0000000..e6f9861 --- /dev/null +++ b/containers/BasicTerm_ME_python/term_me_recursive_pytorch.py @@ -0,0 +1,192 @@ +import torch +from collections import defaultdict +import pandas as pd +import numpy as np +import torch +from heavylight import LightModel, agg +import timeit + +print(f"{torch.cuda.is_available()=}") +# set 64 bit precision +torch.set_default_dtype(torch.float64) +print(f"{torch.get_default_dtype()=}") + +disc_rate_ann = pd.read_excel("BasicTerm_ME/disc_rate_ann.xlsx", index_col=0) +mort_table = pd.read_excel("BasicTerm_ME/mort_table.xlsx", index_col=0) +model_point_table = pd.read_excel("BasicTerm_ME/model_point_table.xlsx", index_col=0) +premium_table = pd.read_excel("BasicTerm_ME/premium_table.xlsx", index_col=[0,1]) + +class ModelPoints: + def __init__(self, model_point_table: pd.DataFrame, premium_table: pd.DataFrame, size_multiplier: int = 1): + self.table = model_point_table.merge(premium_table, left_on=["age_at_entry", "policy_term"], right_index=True) + self.table.sort_values(by="policy_id", inplace=True) + self.premium_pp = torch.round(torch.tensor(np.tile(self.table["sum_assured"].to_numpy() * self.table["premium_rate"].to_numpy(), size_multiplier)),decimals=2) + self.duration_mth = torch.tensor(np.tile(self.table["duration_mth"].to_numpy(), size_multiplier)) + self.age_at_entry = torch.tensor(np.tile(self.table["age_at_entry"].to_numpy(), size_multiplier)) + self.sum_assured = torch.tensor(np.tile(self.table["sum_assured"].to_numpy(), size_multiplier)) + self.policy_count = torch.tensor(np.tile(self.table["policy_count"].to_numpy(), size_multiplier)) + self.policy_term = torch.tensor(np.tile(self.table["policy_term"].to_numpy(), size_multiplier)) + self.max_proj_len: int = int(torch.max(12 * self.policy_term - self.duration_mth) + 1) + +class Assumptions: + def __init__(self, disc_rate_ann: pd.DataFrame, mort_table: pd.DataFrame): + self.disc_rate_ann = torch.tensor(disc_rate_ann["zero_spot"].to_numpy()) + self.mort_table = torch.tensor(mort_table.to_numpy()) + + def get_mortality(self, age, duration): + return self.mort_table[age-18, torch.clamp(duration, max=5)] + +agg_func = lambda x: float(torch.sum(x)) + +class TermME(LightModel): + def __init__(self, mp: ModelPoints, assume: Assumptions): + super().__init__(agg_function=None) + self.mp = mp + self.assume = assume + + def age(self, t): + return self.mp.age_at_entry + self.duration(t) + + def claim_pp(self, t): + return self.mp.sum_assured + + def claims(self, t): + return self.claim_pp(t) * self.pols_death(t) + + def commissions(self, t): + return (self.duration(t) == 0) * self.premiums(t) + + def disc_factors(self): + return torch.tensor(list((1 + self.disc_rate_mth()[t])**(-t) for t in range(self.mp.max_proj_len))) + + def discount(self, t: int): + return (1 + self.assume.disc_rate_ann[t//12]) ** (-t/12) + + def disc_rate_mth(self): + return torch.tensor(list((1 + self.assume.disc_rate_ann[t//12])**(1/12) - 1 for t in range(self.mp.max_proj_len))) + + def duration(self, t): + return self.duration_mth(t) // 12 + + def duration_mth(self, t): + if t == 0: + return self.mp.duration_mth + else: + return self.duration_mth(t-1) + 1 + + def expense_acq(self): + return 300 + + def expense_maint(self): + return 60 + + def expenses(self, t): + return self.expense_acq() * self.pols_new_biz(t) \ + + self.pols_if_at(t, "BEF_DECR") * self.expense_maint()/12 * self.inflation_factor(t) + + def inflation_factor(self, t): + return (1 + self.inflation_rate())**(t/12) + + def inflation_rate(self): + return 0.01 + + def lapse_rate(self, t): + return torch.clamp(0.1 - 0.02 * self.duration(t), min=0.02) + + def loading_prem(self): + return 0.5 + + def mort_rate(self, t): + return self.assume.get_mortality(self.age(t), self.duration(t)) + + def mort_rate_mth(self, t): + return 1-(1- self.mort_rate(t))**(1/12) + + def net_cf(self, t): + return self.premiums(t) - self.claims(t) - self.expenses(t) - self.commissions(t) + + def pols_death(self, t): + return self.pols_if_at(t, "BEF_DECR") * self.mort_rate_mth(t) + + @agg(agg_func) + def discounted_net_cf(self, t): + return torch.sum(self.net_cf(t)) * self.discount(t) + + def pols_if_at(self, t, timing): + if timing == "BEF_MAT": + if t == 0: + return self.pols_if_init() + else: + return self.pols_if_at(t-1, "BEF_DECR") - self.pols_lapse(t-1) - self.pols_death(t-1) + elif timing == "BEF_NB": + return self.pols_if_at(t, "BEF_MAT") - self.pols_maturity(t) + elif timing == "BEF_DECR": + return self.pols_if_at(t, "BEF_NB") + self.pols_new_biz(t) + else: + raise ValueError("invalid timing") + + def pols_if_init(self): + return torch.where(self.duration_mth(0) > 0, self.mp.policy_count, 0) + + def pols_lapse(self, t): + return (self.pols_if_at(t, "BEF_DECR") - self.pols_death(t)) * (1-(1 - self.lapse_rate(t))**(1/12)) + + def pols_maturity(self, t): + return (self.duration_mth(t) == self.mp.policy_term * 12) * self.pols_if_at(t, "BEF_MAT") + + def pols_new_biz(self, t): + return torch.where(self.duration_mth(t) == 0, self.mp.policy_count, 0) + + def premiums(self, t): + return self.mp.premium_pp * self.pols_if_at(t, "BEF_DECR") + + +def run_recursive_model(model: TermME): + model.cache_graph._caches = defaultdict(dict) + model.cache_graph._caches_agg = defaultdict(dict) + model.RunModel(model.mp.max_proj_len) + return float(sum(model.cache_agg['discounted_net_cf'].values())) + + +def time_recursive_GPU(model: TermME): + model.OptimizeMemoryAndReset() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + result = run_recursive_model(model) + end.record() + torch.cuda.synchronize() + return result, start.elapsed_time(end) / 1000 + +def time_recursive_CPU(model: TermME): + model.OptimizeMemoryAndReset() + start = timeit.default_timer() + result = run_recursive_model(model) + end = timeit.default_timer() + return result, end - start + +def time_recursive_PyTorch(multiplier: int): + if torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') + print(f"{device=}") + + with device: + mp = ModelPoints(model_point_table, premium_table) + mp_multiplied = ModelPoints(model_point_table, premium_table, multiplier) + assume = Assumptions(disc_rate_ann, mort_table) + model = TermME(mp, assume) + + if device.type == 'cuda': + time_recursive = time_recursive_GPU + else: + time_recursive = time_recursive_CPU + run_recursive_model(model) # warm up, generate dependency graph + model.mp = mp_multiplied + result, time_in_seconds = time_recursive(model) + # report results + print("PyTorch recursive model") + print(f"number modelpoints={len(mp_multiplied.duration_mth):,}") + print(f"{result=:,}") + print(f"{time_in_seconds=}") \ No newline at end of file From 2f2f7c3b0eb40304cfd2386a3aa6373733d3098a Mon Sep 17 00:00:00 2001 From: m Date: Sat, 11 May 2024 14:47:29 +0000 Subject: [PATCH 3/4] try to fix CI docker build --- containers/BasicTerm_ME_python/Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/containers/BasicTerm_ME_python/Dockerfile b/containers/BasicTerm_ME_python/Dockerfile index 9f50551..682e966 100644 --- a/containers/BasicTerm_ME_python/Dockerfile +++ b/containers/BasicTerm_ME_python/Dockerfile @@ -4,9 +4,9 @@ FROM pytorch/pytorch:2.2.2-cuda11.8-cudnn8-runtime # Set the working directory in the container WORKDIR /app -RUN python -m pip install --upgrade pip -RUN python -m pip install --upgrade "jax[cuda12]" -RUN python -m pip install \ +RUN pip install --upgrade pip +RUN pip install --upgrade "jax[cuda12]" +RUN pip install \ pandas \ openpyxl \ equinox \ From f9b6510c38e855e00bb6a8ae7b51423c940d1b39 Mon Sep 17 00:00:00 2001 From: m Date: Sat, 11 May 2024 16:36:29 +0000 Subject: [PATCH 4/4] compiled results in README --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b59d40d..b89fbfd 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,10 @@ We currently only have 1 benchmark, we are working to expand the benchmarks. Ope Time measurement is currently the best of 3 runs. -| benchmark | recursive | container |A100-SXM4-40GB | H100-SXM5-80GB | +| benchmark | classification | container |A100-SXM4-40GB | H100-SXM5-80GB | |---------------|-|-|----------------|----------------| -| BasicTerm_ME 100 Million | Yes | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_pytorch/general) | 15.8194s | 7.1820s | +| BasicTerm_ME 100 Million | recursive PyTorch | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_python/general) | 15.8284s | 7.205s | +| BasicTerm_ME 100 Million | compiled iterative JAX | [link](https://hub.docker.com/repository/docker/actuarial/basicterm_me_python/general) | 3.448s | 1.551s | ### Notes