Skip to content

Commit 7af0f0b

Browse files
committed
fix point distribution for loss strategy
1 parent c5b5917 commit 7af0f0b

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def strategy(self, strategy):
111111
raise ValueError(
112112
'Only strategy="loss_improvements", strategy="loss", or'
113113
' strategy="npoints" is implemented.')
114+
self._points = {} # reset the cache
114115

115116
def _ask_and_tell_based_on_loss_improvements(self, n):
116117
chosen_points = []
@@ -125,12 +126,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
125126
self._points[index] = learner.ask(
126127
n=1, tell_pending=False)
127128
points, loss_improvements = self._points[index]
128-
npoints = npoints_per_learner[index] + learner.npoints
129+
npoints = (npoints_per_learner[index]
130+
+ learner.npoints
131+
+ len(learner.pending_points))
129132
priority = (loss_improvements[0], -npoints)
130133
improvements_per_learner.append(priority)
131134
points_per_learner.append((index, points[0]))
132135

133-
# Chose the optimal improvement.
136+
# Choose the optimal improvement.
134137
(index, point), (loss_improvement, _) = max(
135138
zip(points_per_learner, improvements_per_learner),
136139
key=itemgetter(1))
@@ -142,15 +145,23 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
142145
return chosen_points, chosen_loss_improvements
143146

144147
def _ask_and_tell_based_on_loss(self, n):
145-
points = []
146-
loss_improvements = []
148+
chosen_points = []
149+
chosen_loss_improvements = []
150+
npoints_per_learner = defaultdict(int)
151+
147152
for _ in range(n):
148153
losses = self._losses(real=False)
149-
max_ind = np.argmax(losses)
150-
xs, ls = self.learners[max_ind].ask(1)
151-
points.append((max_ind, xs[0]))
152-
loss_improvements.append(ls[0])
153-
return points, loss_improvements
154+
npoints = [-(l.npoints
155+
+ npoints_per_learner[i]
156+
+ len(l.pending_points))
157+
for i, l in enumerate(self.learners)]
158+
priority = zip(losses, npoints)
159+
index, (_, npoints) = max(enumerate(priority), key=itemgetter(1))
160+
npoints_per_learner[index] += 1
161+
points, loss_improvements = self.learners[index].ask(1)
162+
chosen_points.append((index, points[0]))
163+
chosen_loss_improvements.append(loss_improvements[0])
164+
return chosen_points, chosen_loss_improvements
154165

155166
def _ask_and_tell_based_on_npoints(self, n):
156167
points = []

adaptive/tests/test_balancing_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def test_distribute_first_points_over_learners():
3131
points, _ = learner.ask(100)
3232
i_learner, xs = zip(*points)
3333
# assert that are all learners in the suggested points
34-
assert len(set(i_learner)) == len(learners)
34+
assert len(set(i_learner)) == len(learners), strategy
3535

3636

3737
def test_strategies():

0 commit comments

Comments
 (0)