Skip to content

Commit c247bf4

Browse files
authored
RooFitter: Add protections, type hints, docstrings (#1021)
1 parent 8d4fbb9 commit c247bf4

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

machine_learning_hep/fitting/roofitter.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
## along with this program. if not, see <https://www.gnu.org/licenses/>. ##
1313
#############################################################################
1414

15+
"""Definition of the RooFitter class and helper functions"""
16+
1517
from math import sqrt
1618

1719
import ROOT
@@ -22,25 +24,34 @@
2224
# pylint: disable=too-few-public-methods, too-many-statements
2325
# (temporary until we add more functionality)
2426
class RooFitter:
27+
"""Fitter using Roofit for combined fits of invariant-mass distributions"""
28+
2529
def __init__(self):
2630
ROOT.gErrorIgnoreLevel = ROOT.kError
2731
ROOT.RooMsgService.instance().setSilentMode(True)
2832
ROOT.RooMsgService.instance().setGlobalKillBelow(ROOT.RooFit.WARNING)
2933
ROOT.RooMsgService.instance().setGlobalKillBelow(ROOT.RooFit.ERROR)
3034

31-
def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
35+
def fit_mass_new(
36+
self, hist, pdfnames: dict, fit_spec: dict, level: str, roows: ROOT.RooWorkspace = None, plot: bool = False
37+
):
38+
"""New fit method"""
3239
if hist.GetEntries() == 0:
3340
raise UserWarning("Cannot fit histogram with no entries")
3441
ws = roows or ROOT.RooWorkspace("ws")
3542
var_m = fit_spec.get("var", "m")
3643

37-
n_signal = RooRealVar("n_signal", "Number of signal events", 1e7, 0, 1.e10)
44+
n_signal = RooRealVar("n_signal", "Number of signal events", 1e7, 0, 1e10)
3845
n_background = RooRealVar("n_background", "Number of background events", 1e7, 0, 1e10)
3946

47+
model = None
4048
for comp, spec in fit_spec.get("components", {}).items():
4149
fn = ws.factory(spec["fn"])
4250
if comp == "model":
4351
model = fn
52+
if model is None:
53+
raise ValueError("model not set")
54+
4455
m = ws.var(var_m)
4556

4657
if level == "data" and USE_EXTMODEL:
@@ -54,8 +65,6 @@ def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
5465
"model", "Total model", RooArgList(signal_pdf, background_pdf), RooArgList(n_signal, n_background)
5566
)
5667

57-
# if range_m := fit_spec.get('range'):
58-
# m.setRange(range_m[0], range_m[1])
5968
dh = ROOT.RooDataHist("dh", "dh", [m], Import=hist)
6069
if range_m := fit_spec.get("range"):
6170
m.setRange("fit", *range_m)
@@ -64,7 +73,9 @@ def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
6473
if level == 'data' and USE_EXTMODEL:
6574
for v in ws.allVars():
6675
v.setConstant(True)
67-
res = extmodel.fitTo(dh, Range=(range_m[0], range_m[1]), Save=True, PrintLevel=-1, Strategy=1, MaxCalls=5000)
76+
res = extmodel.fitTo(
77+
dh, Range=(range_m[0], range_m[1]), Save=True, PrintLevel=-1, Strategy=1, MaxCalls=5000
78+
)
6879
else:
6980
res = model.fitTo(dh, Save=True, PrintLevel=-1, Strategy=1, MaxCalls=5000)
7081
if level == 'data' and USE_EXTMODEL:
@@ -83,7 +94,8 @@ def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
8394
model.paramOn(frame, Layout=(0.65, 1.0, 0.9))
8495
frame.getAttText().SetTextFont(42)
8596
frame.getAttText().SetTextSize(0.001)
86-
frame.SetAxisRange(range_m[0], range_m[1], "X")
97+
if range_m:
98+
frame.SetAxisRange(range_m[0], range_m[1], "X")
8799
frame.SetAxisRange(0.0, frame.GetMaximum() + (frame.GetMaximum() * 0.3), "Y")
88100

89101
try:
@@ -99,8 +111,7 @@ def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
99111
)
100112
# model.SetName("bkg")
101113
model.plotOn(frame, ROOT.RooFit.Name("model"))
102-
# pylint: disable=bare-except
103-
except:
114+
except: # pylint: disable=bare-except # noqa: E722
104115
pass
105116
# for comp in fit_spec.get('components', {}):
106117
# if comp != 'model':
@@ -109,7 +120,7 @@ def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
109120
# c.Modified()
110121
# c.Update()
111122

112-
if level == "data" and USE_EXTMODEL:
123+
if level == "data" and USE_EXTMODEL and frame is not None:
113124
residuals = frame.residHist("data", "pdf_bkg")
114125
residual_frame = m.frame()
115126
residual_frame.addPlotable(residuals, "P")
@@ -123,19 +134,26 @@ def fit_mass_new(self, hist, pdfnames, fit_spec, level, roows=None, plot=False):
123134
ROOT.RooFit.Normalization(1.0, ROOT.RooAbsReal.RelativeExpected),
124135
)
125136

126-
residual_frame.SetAxisRange(range_m[0], range_m[1], "X")
137+
if range_m:
138+
residual_frame.SetAxisRange(range_m[0], range_m[1], "X")
127139
residual_frame.SetYTitle("Residuals")
128140

129141
return (res, ws, frame, residual_frame)
130142

131143
def fit_mass(self, hist, fit_spec, plot=False):
144+
"""Old fit method"""
132145
if hist.GetEntries() == 0:
133146
raise UserWarning("Cannot fit histogram with no entries")
134147
ws = ROOT.RooWorkspace("ws")
148+
149+
model = None
135150
for comp, spec in fit_spec.get("components", {}).items():
136151
ws.factory(spec["fn"])
137152
if comp == "sum":
138153
model = ws.pdf(comp)
154+
if model is None:
155+
raise ValueError("model not set")
156+
139157
m = ws.var("m")
140158
# m.setRange('full', 0., 3.)
141159
dh = ROOT.RooDataHist("dh", "dh", [m], Import=hist)
@@ -154,6 +172,7 @@ def fit_mass(self, hist, fit_spec, plot=False):
154172

155173

156174
def calc_signif(roows, res, pdfnames, param_names, mean_sgn, sigma_sgn):
175+
"""Calculate significance, signal, background, signal/background ratio."""
157176
if not USE_EXTMODEL:
158177
return (0., 0., 0., 0., 0., 0, 0, 0.)
159178
f_sig = roows.pdf(pdfnames["pdf_sig"])
@@ -217,6 +236,7 @@ def calc_signif(roows, res, pdfnames, param_names, mean_sgn, sigma_sgn):
217236

218237

219238
def create_text_info(x_1, y_1, x_2, y_2):
239+
"""Create an info box for fit plots and set its style."""
220240
text_info = TPaveText(x_1, y_1, x_2, y_2, "NDC")
221241
text_info.SetBorderSize(0)
222242
text_info.SetFillColor(0) # Transparent fill
@@ -230,6 +250,7 @@ def create_text_info(x_1, y_1, x_2, y_2):
230250

231251

232252
def add_text_info_fit(text_info, frame, roows, param_names):
253+
"""Add fit info on the info box."""
233254
chi2 = frame.chiSquare()
234255
mean_sgn = roows.var(param_names["gauss_mean"])
235256
sigma_sgn = roows.var(param_names["gauss_sigma"])
@@ -253,6 +274,7 @@ def add_text_info_fit(text_info, frame, roows, param_names):
253274

254275

255276
def add_text_info_perf(text_info, sig, sig_err, bkg, bkg_err, s_over_b, s_over_b_err, signif, signif_err):
277+
"""Add signal, background, signal/background and significance on the info box."""
256278
text_info.AddText(f"S(3#sigma) = {sig:.0f} #pm {sig_err:.0f}")
257279
text_info.AddText(f"B(3#sigma) = {bkg:.0f} #pm {bkg_err:.0f}")
258280
text_info.AddText(f"S/B(3#sigma) = {s_over_b:.3f} #pm {s_over_b_err:.3f}")

0 commit comments

Comments
 (0)