Skip to content

Commit 6e172ef

Browse files
committed
fix the point distribution issue by max(loss, -npoints)
1 parent 526ad43 commit 6e172ef

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,33 @@ def strategy(self, strategy):
112112
' strategy="npoints" is implemented.')
113113

114114
def _ask_and_tell_based_on_loss_improvements(self, n):
115-
points = []
116-
loss_improvements = []
115+
chosen_points = []
116+
chosen_loss_improvements = []
117+
npoints_per_learner = defaultdict(int)
118+
117119
for _ in range(n):
118120
improvements_per_learner = []
119-
pairs = []
121+
points_per_learner = []
120122
for index, learner in enumerate(self.learners):
121123
if index not in self._points:
122124
self._points[index] = learner.ask(
123125
n=1, tell_pending=False)
124-
point, loss_improvement = self._points[index]
125-
improvements_per_learner.append(loss_improvement[0])
126-
pairs.append((index, point[0]))
127-
x, l = max(zip(pairs, improvements_per_learner),
128-
key=itemgetter(1))
129-
points.append(x)
130-
loss_improvements.append(l)
131-
self.tell_pending(x)
132-
133-
return points, loss_improvements
126+
points, loss_improvements = self._points[index]
127+
npoints = npoints_per_learner[index] + learner.npoints
128+
sort_tuple = (loss_improvements[0], -npoints)
129+
improvements_per_learner.append(sort_tuple)
130+
points_per_learner.append((index, points[0]))
131+
132+
# Chose the optimal improvement.
133+
(index, point), (loss_improvement, _) = max(
134+
zip(points_per_learner, improvements_per_learner),
135+
key=itemgetter(1))
136+
npoints_per_learner[index] += 1
137+
chosen_points.append((index, point))
138+
chosen_loss_improvements.append(loss_improvement)
139+
self.tell_pending((index, point))
140+
141+
return chosen_points, chosen_loss_improvements
134142

135143
def _ask_and_tell_based_on_loss(self, n):
136144
points = []
@@ -160,19 +168,11 @@ def _ask_and_tell_based_on_npoints(self, n):
160168

161169
def ask(self, n, tell_pending=True):
162170
"""Chose points for learners."""
163-
if any(l.npoints for l in self.learners):
164-
ask_and_tell = self._ask_and_tell
165-
else:
166-
# If there are no data points yet,
167-
# distribute the points over all learners.
168-
# See https://github.com/python-adaptive/adaptive/issues/159
169-
ask_and_tell = self._ask_and_tell_based_on_npoints
170-
171171
if not tell_pending:
172172
with restore(*self.learners):
173-
return ask_and_tell(n)
173+
return self._ask_and_tell(n)
174174
else:
175-
return ask_and_tell(n)
175+
return self._ask_and_tell(n)
176176

177177
def tell(self, x, y):
178178
index, x = x

0 commit comments

Comments
 (0)