Skip to content

Commit

Permalink
[RF] Implement RooFit::DataError() pythonization in C++
Browse files Browse the repository at this point in the history
This reduces divergence between the Python and C++ interface.

Tested by the tutorials where this string-to-enum pythonization is used.
  • Loading branch information
guitargeek committed Dec 2, 2024
1 parent ae4b0b7 commit b36dc4f
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,44 +194,21 @@ def Link(*args, **kwargs):

@cpp_signature("RooFit::DataError(Int_t) ;")
def DataError(etype):
r"""Instead of passing an enum value to this function, you can pass a
string with the name of that enum value, for example:
~~~ {.py}
data.plotOn(frame, DataError="SumW2")
# instead of DataError=ROOT.RooAbsData.SumW2
~~~
If you want to use the `"None"` enum value to disable error plotting, you
r"""If you want to use the `"None"` enum value to disable error plotting, you
can also pass `None` directly instead of passing a string:
~~~ {.py}
data.plotOn(frame, DataError=None)
# instead of DataError="None"
~~~
"""
# Redefinition of `DataError` to also accept `str` or `NoneType` to get the
# corresponding enum values from RooAbsData.DataError.
from cppyy.gbl import RooFit

# One of the possible enum values is "None", and we want the user to be
# able to pass None also as a NoneType for convenience.
if etype is None:
etype = "None"

if isinstance(etype, str):
try:
import ROOT

etype = getattr(ROOT.RooAbsData.ErrorType, etype)
except AttributeError as error:
raise ValueError(
"Unsupported error type type passed to DataError()."
+ ' Supported decay types are : "Poisson", "SumW2", "Auto", "Expected", and None.'
)
except Exception as exception:
raise exception

return RooFit._DataError(etype)


Expand Down
2 changes: 2 additions & 0 deletions roofit/roofitcore/inc/RooAbsData.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class RooAbsData : public TNamed, public RooPrintable {
virtual double weightSquared() const = 0 ; // DERIVED

enum ErrorType { Poisson, SumW2, None, Auto, Expected } ;
static ErrorType errorTypeFromString(std::string const &name);

/// Return the symmetric error on the current weight.
/// See also weightError(double&,double&,ErrorType) const for asymmetric errors.
// \param[in] etype Type of error to compute. May throw if not supported.
Expand Down
1 change: 1 addition & 0 deletions roofit/roofitcore/inc/RooGlobalFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ RooCmdArg EventRange(Int_t nStart, Int_t nStop) ;
// RooChi2Var::ctor / RooNLLVar arguments
RooCmdArg Extended(bool flag=true) ;
RooCmdArg DataError(Int_t) ;
RooCmdArg DataError(std::string const&) ;
RooCmdArg NumCPU(Int_t nCPU, Int_t interleave=0) ;
RooCmdArg Parallelize(int nWorkers) ;
RooCmdArg ModularL(bool flag=false) ;
Expand Down
24 changes: 24 additions & 0 deletions roofit/roofitcore/src/RooAbsData.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ observable snapshots are stored in the dataset.

#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>


ClassImp(RooAbsData);
Expand Down Expand Up @@ -2640,3 +2643,24 @@ TH2F *RooAbsData::createHistogram(const RooAbsRealLValue &var1, const RooAbsReal

return histogram;
}

////////////////////////////////////////////////////////////////////////////////
/// Convert a string to the value of the RooAbsData::ErrorType enum with the
/// same name.
RooAbsData::ErrorType RooAbsData::errorTypeFromString(std::string const &name)
{
using Map = std::unordered_map<std::string, RooAbsData::ErrorType>;
static Map enumMap{{"Poisson", RooAbsData::Poisson},
{"SumW2", RooAbsData::SumW2},
{"None", RooAbsData::None},
{"Auto", RooAbsData::Auto},
{"Expected", RooAbsData::Expected}};
auto found = enumMap.find(name);
if (found == enumMap.end()) {
std::stringstream msg;
msg << "Unsupported error type type passed to DataError(). "
"Supported decay types are : \"Poisson\", \"SumW2\", \"Auto\", \"Expected\", and None.";
throw std::invalid_argument(msg.str());
}
return found->second;
}
7 changes: 6 additions & 1 deletion roofit/roofitcore/src/RooGlobalFunc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <RooGlobalFunc.h>

#include <RooAbsData.h>
#include <RooAbsPdf.h>
#include <RooCategory.h>
#include <RooDataHist.h>
Expand Down Expand Up @@ -469,7 +470,11 @@ RooCmdArg Extended(bool flag)
}
RooCmdArg DataError(Int_t etype)
{
return RooCmdArg("DataError", (Int_t)etype, 0, 0, 0, nullptr, nullptr, nullptr, nullptr);
return RooCmdArg("DataError", etype);
}
RooCmdArg DataError(std::string const &etype)
{
return DataError(RooAbsData::errorTypeFromString(etype));
}
RooCmdArg NumCPU(Int_t nCPU, Int_t interleave)
{
Expand Down

0 comments on commit b36dc4f

Please sign in to comment.