Skip to content

Commit c52f7e0

Browse files
authored
ENH: compute decay rates from sub-intensities (#393)
* DOC: do not reverse resonance names * DOC: rotate decay rate matrix visualization * DX: test values of decay rate matrix * ENH: compute sum of rate matrix with numpy.tril
1 parent 1d9fd48 commit c52f7e0

File tree

2 files changed

+41
-40
lines changed

2 files changed

+41
-40
lines changed

.cspell/python-extra.txt

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ svgutils
2626
sympify
2727
tensorwaves
2828
tolist
29+
tril
2930
viewboxing
3031
viridis
3132
vmax

docs/intensity.ipynb

+40-40
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,20 @@
533533
},
534534
"outputs": [],
535535
"source": [
536-
"def compute_decay_rates(func, integration_sample: DataSample):\n",
536+
"def compute_decay_rates(rates: np.ndarray) -> np.ndarray:\n",
537+
" m, n = rates.shape\n",
538+
" assert m == n\n",
539+
" d = rates.diagonal()\n",
540+
" D = d * np.identity(n)\n",
541+
" X = d[None] + d[None].T\n",
542+
" return rates - X + 2 * D\n",
543+
"\n",
544+
"\n",
545+
"def compute_sub_intensities(func, integration_sample: DataSample):\n",
537546
" decay_rates = np.zeros(shape=(n_resonances, n_resonances))\n",
538547
" combinations = list(product(enumerate(resonances), enumerate(resonances)))\n",
539548
" progress_bar = tqdm(\n",
540-
" desc=\"Calculating rate matrix\",\n",
549+
" desc=\"Calculating sub-intensities\",\n",
541550
" disable=NO_LOG,\n",
542551
" total=(len(combinations) + n_resonances) // 2,\n",
543552
" )\n",
@@ -548,10 +557,7 @@
548557
" progress_bar.postfix = f\"{resonance1.name} × {resonance2.name}\"\n",
549558
" res1 = resonance1.latex\n",
550559
" res2 = resonance2.latex\n",
551-
" if res1 == res2:\n",
552-
" I_sub = sub_intensity(func, integration_sample, non_zero_couplings=[res1])\n",
553-
" else:\n",
554-
" I_sub = interference_intensity(func, integration_sample, [res1], [res2])\n",
560+
" I_sub = sub_intensity(func, integration_sample, non_zero_couplings=[res1, res2])\n",
555561
" decay_rates[i, j] = I_sub / I_tot\n",
556562
" if i != j:\n",
557563
" decay_rates[j, i] = decay_rates[i, j]\n",
@@ -568,7 +574,6 @@
568574
"resonances = sorted(\n",
569575
" (chain.resonance for chain in model.decay.chains),\n",
570576
" key=sort_resonances,\n",
571-
" reverse=True,\n",
572577
")\n",
573578
"n_resonances = len(resonances)"
574579
]
@@ -597,70 +602,60 @@
597602
" fig, ax = plt.subplots(figsize=(9, 9))\n",
598603
" fig.patch.set_color(\"none\")\n",
599604
" ax.set_title(title)\n",
600-
" ax.matshow(jnp.rot90(decay_rates).T, cmap=plt.cm.coolwarm, vmin=-vmax, vmax=+vmax)\n",
605+
" ax.matshow(decay_rates, cmap=plt.cm.coolwarm, vmin=-vmax, vmax=+vmax)\n",
601606
"\n",
602607
" resonance_latex = [f\"${p.latex}$\" for p in resonances]\n",
603608
" ax.set_xticks(range(n_resonances))\n",
604-
" ax.set_xticklabels(reversed(resonance_latex), rotation=90)\n",
609+
" ax.set_xticklabels(resonance_latex, rotation=90)\n",
610+
" ax.xaxis.tick_bottom()\n",
605611
" ax.set_yticks(range(n_resonances))\n",
606612
" ax.set_yticklabels(resonance_latex)\n",
607613
" for i in range(n_resonances):\n",
608614
" for j in range(n_resonances):\n",
609-
" if j < i:\n",
615+
" if i > j:\n",
610616
" continue\n",
611617
" rate = decay_rates[i, j]\n",
612-
" ax.text(\n",
613-
" n_resonances - j - 1,\n",
614-
" i,\n",
615-
" f\"{100 * rate:.2f}\",\n",
616-
" va=\"center\",\n",
617-
" ha=\"center\",\n",
618-
" )\n",
618+
" ax.text(i, j, f\"{100 * rate:.2f}\", ha=\"center\", va=\"center\")\n",
619619
" fig.tight_layout()\n",
620620
" return fig\n",
621621
"\n",
622622
"\n",
623-
"decay_rates = compute_decay_rates(intensity_func, integration_sample)\n",
623+
"sub_intensities = compute_sub_intensities(intensity_func, integration_sample)\n",
624+
"decay_rates = compute_decay_rates(sub_intensities)\n",
624625
"fig = visualize_decay_rates(decay_rates)\n",
625626
"output_path = \"_images/rate-matrix.svg\"\n",
626627
"fig.savefig(output_path, bbox_inches=\"tight\")\n",
627628
"reduce_svg_size(output_path)\n",
628629
"plt.show(fig)"
629630
]
630631
},
631-
{
632-
"cell_type": "markdown",
633-
"metadata": {},
634-
"source": [
635-
":::{only} latex\n",
636-
"{{ FIG_RATE_MATRIX }}\n",
637-
":::"
638-
]
639-
},
640632
{
641633
"cell_type": "code",
642634
"execution_count": null,
643635
"metadata": {
644636
"jupyter": {
645637
"source_hidden": true
646638
},
647-
"mystnb": {
648-
"code_prompt_show": "Function for computing the total over all rates"
649-
},
650639
"tags": [
651-
"hide-cell"
640+
"hide-input"
652641
]
653642
},
654643
"outputs": [],
655644
"source": [
656-
"def compute_sum_over_decay_rates(decay_rate_matrix) -> float:\n",
657-
" decay_rate_sum = 0.0\n",
658-
" for i in range(len(resonances)):\n",
659-
" for j in range(len(resonances)):\n",
660-
" if j < i:\n",
661-
" continue\n",
662-
" decay_rate_sum += decay_rate_matrix[i, j]\n",
663-
" return decay_rate_sum"
645+
"np.testing.assert_array_almost_equal(\n",
646+
" 100 * decay_rates[-1],\n",
647+
" [4.78, 0.16, -1.68, 0.04, 0.03, -7.82, 5.09, 1.96, -1.15, -1.95, 0.04, 14.7],\n",
648+
" decimal=2,\n",
649+
")"
650+
]
651+
},
652+
{
653+
"cell_type": "markdown",
654+
"metadata": {},
655+
"source": [
656+
":::{only} latex\n",
657+
"{{ FIG_RATE_MATRIX }}\n",
658+
":::"
664659
]
665660
},
666661
{
@@ -669,6 +664,10 @@
669664
"metadata": {},
670665
"outputs": [],
671666
"source": [
667+
"def compute_sum_over_decay_rates(rate_matrix: np.ndarray) -> float:\n",
668+
" return rate_matrix.diagonal().sum() + np.tril(rate_matrix, k=-1).sum()\n",
669+
"\n",
670+
"\n",
672671
"np.testing.assert_almost_equal(compute_sum_over_decay_rates(decay_rates), 1.0)"
673672
]
674673
},
@@ -811,7 +810,8 @@
811810
" y_range=sub_region_y_range,\n",
812811
")\n",
813812
"sub_sample = transformer(sub_sample)\n",
814-
"sub_decay_rates = compute_decay_rates(intensity_func, sub_sample)\n",
813+
"sub_decay_intensities = compute_sub_intensities(intensity_func, sub_sample)\n",
814+
"sub_decay_rates = compute_decay_rates(sub_decay_intensities)\n",
815815
"fig = visualize_decay_rates(sub_decay_rates, title=\"Rate matrix over sub-region\")\n",
816816
"output_path = \"_images/rate-matrix-sub-region.svg\"\n",
817817
"fig.savefig(output_path, bbox_inches=\"tight\")\n",

0 commit comments

Comments
 (0)