Adding high level plotting API (#128)
* adding high level plotting api
* adding intermediate implementation
* adding user ignore to git ignore
* removing unit test diffs
* revert custom changes
* small fixes
* adding simple discriminant functions
* fixing circular import
* updating to a new version of tagger class
* updating tagger class and its tests
* adding a more general discriminant calculation function
* updating tagger plotting
* update requirements in package install
* fixing folder for plotting
* fixing darglint
* darglint fix
* Update puma/hlplots/results.py Co-authored-by: Alexander Froch <[email protected]>
* fix linting
* Update puma/hlplots/results.py Co-authored-by: Alexander Froch <[email protected]>
* improve formatting
* Update puma/hlplots/tagger.py Co-authored-by: Joschka Birk <[email protected]>
* adding catch for inf values in discriminant calculation
* adding small epsilon to discriminant calculation
* removing inheritance from Tagger and Results classes
* fixing typos
* fixing doc string
* adding improvements to hlevel api
* small gitignore improvement
* fixing hard coded values
* fixing pylint issues
* fix darglint issue
* adding changelog
* added warning in code when using 2 data frames
* adding high level API to docs
* Update examples/high_level_plots.py Co-authored-by: Joschka Birk <[email protected]>

Co-authored-by: Alexander Froch <[email protected]>
Co-authored-by: Joschka Birk <[email protected]>
1 parent c4b4cac, commit 4a6613f
Showing 20 changed files with 1,397 additions and 13 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# High level API | ||
|
||
To set up the inputs for the plots, have a look [here](./index.md). | ||
|
||
The following examples use the dummy data described [here](./dummy_data.md). | ||
|
||
The previous examples show how to produce individual plots, which often requires | ||
a fair amount of code for ROC curves and similar plots. | ||
|
||
The high level API bundles several of these steps and is designed to quickly produce | ||
_b_- and _c_-jet performance plots. | ||
|
||
|
||
## Initialising the taggers | ||
|
||
```py | ||
§§§examples/high_level_plots.py:1:55§§§ | ||
``` | ||
WARNING: when using 2 different data frames, a single `tagger_args` is not enough. You need | ||
one per data frame, each defining the flavour classes and the performance variable. | ||
|
||
|
||
## Discriminant plots | ||
To plot the discriminant, call a single function and everything else is handled automatically, | ||
here for the _b_-jet discriminant: | ||
```py | ||
§§§examples/high_level_plots.py:56:58§§§ | ||
``` | ||
|
||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_disc_b.png width=500> | ||
|
||
and similarly for the _c_-jet discriminant: | ||
```py | ||
§§§examples/high_level_plots.py:59§§§ | ||
``` | ||
|
||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_disc_c.png width=500> | ||
|
||
|
||
## ROC plots | ||
|
||
In the same manner you can plot ROC curves, here for the _b_-tagging performance: | ||
```py | ||
§§§examples/high_level_plots.py:62:64§§§ | ||
``` | ||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_roc_b.png width=500> | ||
|
||
and similarly for the _c_-tagging performance: | ||
```py | ||
§§§examples/high_level_plots.py:65:66§§§ | ||
``` | ||
|
||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_roc_c.png width=500> | ||
|
||
|
||
## Performance vs a variable | ||
Here we plot the performance as a function of the jet pT, using the same syntax as above: | ||
```py | ||
§§§examples/high_level_plots.py:69:82§§§ | ||
``` | ||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_dummy_tagger_pt_b_eff.png width=500> | ||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_dummy_tagger_pt_c_rej.png width=500> | ||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_dummy_tagger_pt_light_rej.png width=500> | ||
|
||
and similarly with the _b_-jet efficiency fixed per bin: | ||
```py | ||
§§§examples/high_level_plots.py:84:94§§§ | ||
``` | ||
|
||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_dummy_tagger_fixed_per_bin_pt_b_eff.png width=500> | ||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_dummy_tagger_fixed_per_bin_pt_c_rej.png width=500> | ||
<img src=https://github.com/umami-hep/puma/raw/examples-material/hlplots_dummy_tagger_fixed_per_bin_pt_light_rej.png width=500> |
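The difference between `fixed_eff_bin=False` and `fixed_eff_bin=True` in the last two snippets can be illustrated with plain NumPy. The sketch below is not puma's implementation: the toy discriminant, the helper `cut_at_efficiency`, and all variable names are assumptions made for illustration. It only shows the idea of one inclusive cut versus one cut per pT bin.

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy inputs (assumptions): pT in GeV and a discriminant that improves with pT
# for b-jets, while light jets sit at lower values.
n_jets = 20_000
pt_b = rng.uniform(20, 250, n_jets)
disc_b = rng.normal(1.0 + pt_b / 100, 1.0)
disc_light = rng.normal(-1.0, 1.5, n_jets)

bins = np.array([20, 30, 40, 60, 85, 110, 140, 175, 250])
working_point = 0.7


def cut_at_efficiency(signal_disc, efficiency):
    """Return the cut that keeps the requested fraction of signal jets above it."""
    return np.percentile(signal_disc, 100 * (1 - efficiency))


# Inclusive working point: one cut derived from all b-jets.
inclusive_cut = cut_at_efficiency(disc_b, working_point)
# Rejection is the inverse of the mistag efficiency at that cut.
light_rejection = 1 / np.mean(disc_light > inclusive_cut)

eff_inclusive, eff_per_bin = [], []
for low, high in zip(bins[:-1], bins[1:]):
    in_bin = (pt_b >= low) & (pt_b < high)
    # Same cut everywhere: the efficiency is free to vary from bin to bin.
    eff_inclusive.append(np.mean(disc_b[in_bin] > inclusive_cut))
    # Cut recomputed in each bin: the efficiency is pinned to the working point.
    bin_cut = cut_at_efficiency(disc_b[in_bin], working_point)
    eff_per_bin.append(np.mean(disc_b[in_bin] > bin_cut))

print("inclusive cut:", np.round(eff_inclusive, 2))
print("per-bin cut:  ", np.round(eff_per_bin, 2))
```

With the inclusive cut the per-bin efficiency follows the pT dependence of the toy discriminant, while the per-bin cut keeps it flat at the working point, which is what the `h_line=0.7` reference line in the example marks.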
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
"""Produce roc curves from tagger output and labels.""" | ||
# from pathlib import Path | ||
|
||
# import h5py | ||
import numpy as np | ||
|
||
from puma.hlplots import Results, Tagger | ||
from puma.utils import get_dummy_2_taggers, logger | ||
|
||
# The line below generates dummy data which is similar to a NN output | ||
df = get_dummy_2_taggers(add_pt=True) | ||
class_ids = [0, 4, 5] | ||
# Remove all jets which are not trained on | ||
df.query(f"HadronConeExclTruthLabelID in {class_ids}", inplace=True) | ||
df.query("pt < 250e3", inplace=True) | ||
|
||
logger.info("Start plotting") | ||
|
||
# WARNING: if you use 2 different data frames you need to specify `is_light`, | ||
# `is_c` and `is_b` for each data frame separately, so you cannot reuse these | ||
# args for every tagger. The same applies to `perf_var`. | ||
tagger_args = { | ||
"perf_var": df["pt"] / 1e3, | ||
"is_light": df["HadronConeExclTruthLabelID"] == 0, | ||
"is_c": df["HadronConeExclTruthLabelID"] == 4, | ||
"is_b": df["HadronConeExclTruthLabelID"] == 5, | ||
} | ||
|
||
|
||
dips = Tagger("dips", template=tagger_args) | ||
dips.label = "dummy DIPS ($f_{c}=0.005$)" | ||
dips.f_c = 0.005 | ||
dips.f_b = 0.04 | ||
dips.colour = "#AA3377" | ||
dips.extract_tagger_scores(df) | ||
|
||
rnnip = Tagger("rnnip", template=tagger_args) | ||
rnnip.label = "dummy RNNIP ($f_{c}=0.07$)" | ||
rnnip.f_c = 0.07 | ||
rnnip.f_b = 0.04 | ||
rnnip.colour = "#4477AA" | ||
rnnip.reference = True | ||
rnnip.extract_tagger_scores(df) | ||
|
||
|
||
results = Results() | ||
results.add(dips) | ||
results.add(rnnip) | ||
|
||
|
||
results.sig_eff = np.linspace(0.6, 0.95, 20) | ||
results.atlas_second_tag = ( | ||
"$\\sqrt{s}=13$ TeV, dummy jets \n$t\\bar{t}$, $20$ GeV $< p_{T} <250$ GeV" | ||
) | ||
|
||
# tagger discriminant plots | ||
logger.info("Plotting tagger discriminant plots.") | ||
results.plot_discs("hlplots_disc_b.png") | ||
results.plot_discs("hlplots_disc_c.png", signal_class="cjets") | ||
|
||
|
||
logger.info("Plotting ROC curves.") | ||
# ROC curves as a function of the b-jet efficiency | ||
results.plot_rocs("hlplots_roc_b.png") | ||
# ROC curves as a function of the c-jet efficiency | ||
results.plot_rocs("hlplots_roc_c.png", signal_class="cjets") | ||
|
||
|
||
logger.info("Plotting efficiency/rejection vs pT curves.") | ||
# eff/rej vs. variable plots | ||
results.atlas_second_tag = "$\\sqrt{s}=13$ TeV, dummy jets \n$t\\bar{t}$\n70% WP" | ||
# you can either specify a WP per tagger | ||
# dips.working_point = 0.7 | ||
# rnnip.working_point = 0.7 | ||
# or alternatively pass the argument `working_point` to the plot_var_perf function. | ||
# Specifying the `disc_cut` per tagger is also possible. | ||
results.plot_var_perf( | ||
plot_name="hlplots_dummy_tagger", | ||
working_point=0.7, | ||
bins=[20, 30, 40, 60, 85, 110, 140, 175, 250], | ||
fixed_eff_bin=False, | ||
) | ||
|
||
results.atlas_second_tag = ( | ||
"$\\sqrt{s}=13$ TeV, dummy jets \n$t\\bar{t}$\n70% WP per bin" | ||
) | ||
results.plot_var_perf( | ||
plot_name="hlplots_dummy_tagger_fixed_per_bin", | ||
bins=[20, 30, 40, 60, 85, 110, 140, 175, 250], | ||
fixed_eff_bin=True, | ||
working_point=0.7, | ||
h_line=0.7, | ||
disc_cut=None, | ||
) |
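The `f_c` and `f_b` attributes set on each tagger above weight the background classes in the tagging discriminant. The sketch below uses the conventional log-likelihood-ratio form from flavour tagging, with a small epsilon as mentioned in the commit message. The exact expressions inside `puma/hlplots` are not shown in this diff, so the function names, signatures, and the epsilon value here are assumptions.

```python
import numpy as np


def b_discriminant(p_light, p_c, p_b, f_c, epsilon=1e-10):
    """b-tagging discriminant: log(p_b / (f_c * p_c + (1 - f_c) * p_light)).

    The epsilon keeps the logarithm finite when a probability is exactly zero.
    """
    numerator = p_b + epsilon
    denominator = f_c * p_c + (1 - f_c) * p_light + epsilon
    return np.log(numerator / denominator)


def c_discriminant(p_light, p_c, p_b, f_b, epsilon=1e-10):
    """c-tagging discriminant: log(p_c / (f_b * p_b + (1 - f_b) * p_light))."""
    numerator = p_c + epsilon
    denominator = f_b * p_b + (1 - f_b) * p_light + epsilon
    return np.log(numerator / denominator)


# Two toy jets; the second has p_b = 1, which would give an infinite
# discriminant without the epsilon.
p_light = np.array([0.1, 0.0])
p_c = np.array([0.2, 0.0])
p_b = np.array([0.7, 1.0])

disc_b = b_discriminant(p_light, p_c, p_b, f_c=0.005)
disc_c = c_discriminant(p_light, p_c, p_b, f_b=0.04)
print(disc_b, disc_c)
```

A larger `f_c` makes the b-tagging discriminant more sensitive to the c-jet probability, which is why the two dummy taggers in the example carry different `f_c` values in their labels.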
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""High level plotting API within puma, to avoid code duplication.""" | ||
# flake8: noqa | ||
# pylint: skip-file | ||
|
||
from puma.hlplots.results import Results | ||
from puma.hlplots.tagger import Tagger |
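With `Tagger` exported from the package, the warning about two data frames in the example script comes down to building one template per data frame. A minimal sketch of that follows, using only pandas. The helper `build_tagger_args` and the toy data frames are hypothetical and not part of puma; only the dictionary keys and the `template=` keyword are taken from the example above.

```python
import pandas as pd


def build_tagger_args(df, label_col="HadronConeExclTruthLabelID", perf_col="pt"):
    """Build the flavour masks and performance variable for one data frame."""
    return {
        "perf_var": df[perf_col] / 1e3,  # MeV -> GeV, as in the example script
        "is_light": df[label_col] == 0,
        "is_c": df[label_col] == 4,
        "is_b": df[label_col] == 5,
    }


# Two toy data frames with different lengths (hypothetical data).
df_a = pd.DataFrame(
    {"HadronConeExclTruthLabelID": [0, 4, 5, 5], "pt": [25e3, 40e3, 60e3, 90e3]}
)
df_b = pd.DataFrame({"HadronConeExclTruthLabelID": [5, 0], "pt": [30e3, 120e3]})

args_a = build_tagger_args(df_a)
args_b = build_tagger_args(df_b)

# Each tagger would then receive the template built from its own data frame, e.g.
# Tagger("dips", template=args_a) and Tagger("rnnip", template=args_b).
```

Reusing `args_a` for a tagger whose scores come from `df_b` would pair masks and scores of different lengths, which is the situation the warning is meant to prevent.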