Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ Questions
Is it possible to set the backend from the CLI?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Not at the moment. `Pull Requests <https://github.com/scikit-hep/pyhf/compare>`__ are welcome.

See also:
- :issue:`266`
Yes.
Use the :code:`--backend` option for :code:`pyhf cls` to specify a tensor backend.
The default backend is NumPy.
For more information see :code:`pyhf cls --help`.

Troubleshooting
---------------
Expand Down
32 changes: 29 additions & 3 deletions src/pyhf/cli/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..utils import hypotest, EqDelimStringParamType
from ..workspace import Workspace
from .. import tensorlib, set_backend, optimize
from .. import tensor, get_backend, set_backend, optimize

logging.basicConfig()
log = logging.getLogger(__name__)
Expand All @@ -27,10 +27,24 @@ def cli():
@click.option('-p', '--patch', multiple=True)
@click.option('--testpoi', default=1.0)
@click.option('--teststat', type=click.Choice(['q', 'qtilde']), default='qtilde')
@click.option(
'--backend',
type=click.Choice(['numpy', 'pytorch', 'tensorflow', 'np', 'torch', 'tf']),
help='The tensor backend used for the calculation.',
default='numpy',
)
@click.option('--optimizer')
@click.option('--optconf', type=EqDelimStringParamType(), multiple=True)
def cls(
workspace, output_file, measurement, patch, testpoi, teststat, optimizer, optconf
workspace,
output_file,
measurement,
patch,
testpoi,
teststat,
backend,
optimizer,
optconf,
):
with click.open_file(workspace, 'r') as specstream:
spec = json.load(specstream)
Expand All @@ -49,6 +63,15 @@ def cls(
},
)

# set the backend if not NumPy
if backend in ['pytorch', 'torch']:
set_backend(tensor.pytorch_backend())
elif backend in ['tensorflow', 'tf']:
from tensorflow.compat.v1 import Session

set_backend(tensor.tensorflow_backend(session=Session()))
tensorlib, _ = get_backend()

optconf = {k: v for item in optconf for k, v in item.items()}

# set the new optimizer
Expand All @@ -59,7 +82,10 @@ def cls(
result = hypotest(
testpoi, ws.data(model), model, qtilde=is_qtilde, return_expected_set=True
)
result = {'CLs_obs': result[0].tolist()[0], 'CLs_exp': result[-1].ravel().tolist()}
result = {
'CLs_obs': tensorlib.tolist(result[0])[0],
'CLs_exp': tensorlib.tolist(tensorlib.reshape(result[-1], [-1])),
}

if output_file is None:
click.echo(json.dumps(result, indent=4, sort_keys=True))
Expand Down
12 changes: 10 additions & 2 deletions src/pyhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,22 @@ def pvals_from_teststat(sqrtqmu_v, sqrtqmuA_v, qtilde=False):
nullval = sqrtqmu_v
altval = -(sqrtqmuA_v - sqrtqmu_v)
else: # qtilde
if sqrtqmu_v < sqrtqmuA_v:

def _true_case():
nullval = sqrtqmu_v
altval = -(sqrtqmuA_v - sqrtqmu_v)
else:
return nullval, altval

def _false_case():
qmu = tensorlib.power(sqrtqmu_v, 2)
qmu_A = tensorlib.power(sqrtqmuA_v, 2)
nullval = (qmu + qmu_A) / (2 * sqrtqmuA_v)
altval = (qmu - qmu_A) / (2 * sqrtqmuA_v)
return nullval, altval

nullval, altval = tensorlib.conditional(
(sqrtqmu_v < sqrtqmuA_v)[0], _true_case, _false_case
)
CLsb = 1 - tensorlib.normal_cdf(nullval)
CLb = 1 - tensorlib.normal_cdf(altval)
CLs = CLsb / CLb
Expand Down
20 changes: 20 additions & 0 deletions tests/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,26 @@ def test_import_prepHistFactory_and_cls(tmpdir, script_runner):
assert 'CLs_exp' in d


@pytest.mark.parametrize(
"backend", ["numpy", "tensorflow", "pytorch"],
)
def test_cls_backend_option(tmpdir, script_runner, backend):
temp = tmpdir.join("parsed_output.json")
command = 'pyhf xml2json validation/xmlimport_input/config/example.xml --basedir validation/xmlimport_input/ --output-file {0:s}'.format(
temp.strpath
)
ret = script_runner.run(*shlex.split(command))

command = 'pyhf cls --backend {0:s} {1:s}'.format(backend, temp.strpath)
ret = script_runner.run(*shlex.split(command))

assert ret.success
d = json.loads(ret.stdout)
assert d
assert 'CLs_obs' in d
assert 'CLs_exp' in d


def test_import_and_export(tmpdir, script_runner):
temp = tmpdir.join("parsed_output.json")
command = 'pyhf xml2json validation/xmlimport_input/config/example.xml --basedir validation/xmlimport_input/ --output-file {0:s}'.format(
Expand Down