Skip to content

Commit f8033ee

Browse files
authored
Iteration limit in smart sampling to fix behavior for step functions (#928)
Closes #923 The smart sampling algorithm was not tested on step functions and was broken when applied to such functions. The fix is to stop iterating when an interval becomes too narrow. In addition, further stopping criteria were added, based on number of iterations and time spend.
1 parent 47bd39c commit f8033ee

File tree

3 files changed

+98
-15
lines changed

3 files changed

+98
-15
lines changed

src/iminuit/util.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,18 +1542,32 @@ def _histogram_segments(mask, xe, masked):
15421542
return segments
15431543

15441544

1545-
def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3):
1545+
def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3, maxiter=20, maxtime=10):
1546+
t0 = monotonic()
15461547
x = np.linspace(xmin, xmax, start)
15471548
ynew = f(x)
15481549
ymin = np.min(ynew)
15491550
ymax = np.max(ynew)
15501551
y = {xi: yi for (xi, yi) in zip(x, ynew)}
15511552
a = x[:-1]
15521553
b = x[1:]
1554+
niter = 0
15531555
while len(a):
1554-
if len(y) > 10000:
1555-
warnings.warn("Too many points", RuntimeWarning) # pragma: no cover
1556-
break # pragma: no cover
1556+
niter += 1
1557+
if niter > maxiter:
1558+
msg = (
1559+
f"Iteration limit {maxiter} in smart sampling reached, "
1560+
f"produced {len(y)} points"
1561+
)
1562+
warnings.warn(msg, RuntimeWarning)
1563+
break
1564+
if monotonic() - t0 > maxtime:
1565+
msg = (
1566+
f"Time limit {maxtime} in smart sampling reached, "
1567+
f"produced {len(y)} points"
1568+
)
1569+
warnings.warn(msg, RuntimeWarning)
1570+
break
15571571
xnew = 0.5 * (a + b)
15581572
ynew = f(xnew)
15591573
ymin = min(ymin, np.min(ynew))
@@ -1565,10 +1579,11 @@ def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3):
15651579
+ np.fromiter((y[bi] for bi in b), float)
15661580
)
15671581
dy = np.abs(ynew - yint)
1582+
dx = np.abs(b - a)
15681583

1569-
mask = dy > tol * (ymax - ymin)
1570-
1571-
# intervals which do not pass interpolation test
1584+
# in next iteration, handle intervals which do not
1585+
# pass interpolation test and are not too narrow
1586+
mask = (dy > tol * (ymax - ymin)) & (dx > tol * abs(xmax - xmin))
15721587
a = a[mask]
15731588
b = b[mask]
15741589
xnew = xnew[mask]

tests/test_issue.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from iminuit import Minuit
2-
from iminuit.util import IMinuitWarning
3-
import pickle
4-
import pytest
51
import numpy as np
62

73

84
def test_issue_424():
5+
from iminuit import Minuit
6+
97
def fcn(x, y, z):
108
return (x - 1) ** 2 + (y - 4) ** 2 / 2 + (z - 9) ** 2 / 3
119

@@ -20,6 +18,10 @@ def fcn(x, y, z):
2018

2119

2220
def test_issue_544():
21+
import pytest
22+
from iminuit import Minuit
23+
from iminuit.util import IMinuitWarning
24+
2325
def fcn(x, y):
2426
return x**2 + y**2
2527

@@ -30,6 +32,8 @@ def fcn(x, y):
3032

3133

3234
def test_issue_648():
35+
from iminuit import Minuit
36+
3337
class F:
3438
first = True
3539

@@ -45,6 +49,8 @@ def __call__(self, a, b):
4549

4650

4751
def test_issue_643():
52+
from iminuit import Minuit
53+
4854
def fcn(x, y, z):
4955
return (x - 2) ** 2 + (y - 3) ** 2 + (z - 4) ** 2
5056

@@ -64,6 +70,8 @@ def fcn(x, y, z):
6470

6571

6672
def test_issue_669():
73+
from iminuit import Minuit
74+
6775
def fcn(x, y):
6876
return x**2 + (y / 2) ** 2
6977

@@ -84,15 +92,21 @@ def fcn(x, y):
8492
assert match
8593

8694

95+
# cannot define this inside function, pickle will not allow it
8796
def fcn(par):
8897
return np.sum(par**2)
8998

9099

100+
# cannot define this inside function, pickle will not allow it
91101
def grad(par):
92102
return 2 * par
93103

94104

95105
def test_issue_687():
106+
import pickle
107+
import numpy as np
108+
from iminuit import Minuit
109+
96110
start = np.zeros(3)
97111
m = Minuit(fcn, start)
98112

@@ -107,10 +121,13 @@ def test_issue_687():
107121

108122

109123
def test_issue_694():
110-
stats = pytest.importorskip("scipy.stats")
111-
124+
import pytest
125+
import numpy as np
126+
from iminuit import Minuit
112127
from iminuit.cost import ExtendedUnbinnedNLL
113128

129+
stats = pytest.importorskip("scipy.stats")
130+
114131
xmus = 1.0
115132
xmub = 5.0
116133
xsigma = 1.0
@@ -142,3 +159,32 @@ def model(x, sig_n, sig_mu, sig_sigma, bkg_n, bkg_tau):
142159
break
143160
else:
144161
assert False
162+
163+
164+
def test_issue_923():
165+
from iminuit import Minuit
166+
from iminuit.cost import LeastSquares
167+
import numpy as np
168+
import pytest
169+
170+
# implicitly needed by visualize
171+
pytest.importorskip("matplotlib")
172+
173+
def model(x, c1):
174+
c2 = 100
175+
res = np.zeros(len(x))
176+
mask = x < 47
177+
res[mask] = c1
178+
res[~mask] = c2
179+
return res
180+
181+
xtest = np.linspace(0, 74)
182+
ytest = xtest * 0 + 1
183+
ytesterr = ytest
184+
185+
least_squares = LeastSquares(xtest, ytest, ytesterr, model)
186+
187+
m = Minuit(least_squares, c1=1)
188+
m.migrad()
189+
# this used to trigger an endless (?) loop
190+
m.visualize()

tests/test_util.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,8 +711,30 @@ def test_smart_sampling_1(fn_expected):
711711

712712

713713
def test_smart_sampling_2():
714-
with pytest.warns(RuntimeWarning):
715-
util._smart_sampling(np.log, 1e-10, 1, tol=1e-10)
714+
# should not raise a warning
715+
x, y = util._smart_sampling(np.log, 1e-10, 1, tol=1e-5)
716+
assert 0 < len(x) < 1000
717+
718+
719+
def test_smart_sampling_3():
720+
def step(x):
721+
return np.where(x > 0.5, 0, 1)
722+
723+
with pytest.warns(RuntimeWarning, match="Iteration limit"):
724+
x, y = util._smart_sampling(step, 0, 1, tol=0)
725+
assert 0 < len(x) < 80
726+
727+
728+
def test_smart_sampling_4():
729+
from time import sleep
730+
731+
def step(x):
732+
sleep(0.1)
733+
return np.where(x > 0.5, 0, 1)
734+
735+
with pytest.warns(RuntimeWarning, match="Time limit"):
736+
x, y = util._smart_sampling(step, 0, 1, maxtime=0)
737+
assert 0 < len(x) < 10
716738

717739

718740
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)