Skip to content

Commit b0f5dd4

Browse files
committed
use a list for ((learner_index, point), loss_improvement) tuples
1 parent d60a88e commit b0f5dd4

File tree

1 file changed

+17
-21
lines changed

1 file changed

+17
-21
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,8 @@ def strategy(self, strategy):
113113
' strategy="npoints" is implemented.')
114114

115115
def _ask_and_tell_based_on_loss_improvements(self, n):
116-
chosen_points = []
117-
chosen_loss_improvements = []
118-
npoints = [l.npoints + len(l.pending_points)
119-
for l in self.learners]
116+
selected = [] # tuples ((learner_index, point), loss_improvement)
117+
npoints = [l.npoints + len(l.pending_points) for l in self.learners]
120118
for _ in range(n):
121119
improvements_per_learner = []
122120
points_per_learner = []
@@ -136,17 +134,15 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
136134
zip(points_per_learner, improvements_per_learner),
137135
key=itemgetter(1))
138136
npoints[index] += 1
139-
chosen_points.append((index, point))
140-
chosen_loss_improvements.append(loss_improvement)
137+
selected.append(((index, point), loss_improvement))
141138
self.tell_pending((index, point))
142139

143-
return chosen_points, chosen_loss_improvements
140+
points, loss_improvements = map(list, zip(*selected))
141+
return points, loss_improvements
144142

145143
def _ask_and_tell_based_on_loss(self, n):
146-
chosen_points = []
147-
chosen_loss_improvements = []
148-
npoints = [l.npoints + len(l.pending_points)
149-
for l in self.learners]
144+
selected = [] # tuples ((learner_index, point), loss_improvement)
145+
npoints = [l.npoints + len(l.pending_points) for l in self.learners]
150146
for _ in range(n):
151147
losses = self._losses(real=False)
152148
priority = zip(losses, (-n for n in npoints))
@@ -158,15 +154,14 @@ def _ask_and_tell_based_on_loss(self, n):
158154
self._ask_cache[index] = self.learners[index].ask(n=1)
159155
points, loss_improvements = self._ask_cache[index]
160156

161-
chosen_points.append((index, points[0]))
162-
chosen_loss_improvements.append(loss_improvements[0])
163-
return chosen_points, chosen_loss_improvements
157+
selected.append(((index, points[0]), loss_improvements[0]))
158+
159+
points, loss_improvements = map(list, zip(*selected))
160+
return points, loss_improvements
164161

165162
def _ask_and_tell_based_on_npoints(self, n):
166-
chosen_points = []
167-
chosen_loss_improvements = []
168-
npoints = [l.npoints + len(l.pending_points)
169-
for l in self.learners]
163+
selected = [] # tuples ((learner_index, point), loss_improvement)
164+
npoints = [l.npoints + len(l.pending_points) for l in self.learners]
170165
n_left = n
171166
while n_left > 0:
172167
index = np.argmin(npoints)
@@ -176,9 +171,10 @@ def _ask_and_tell_based_on_npoints(self, n):
176171
points, loss_improvements = self._ask_cache[index]
177172
npoints[index] += 1
178173
n_left -= 1
179-
chosen_points.append((index, points[0]))
180-
chosen_loss_improvements.append(loss_improvements[0])
181-
return chosen_points, chosen_loss_improvements
174+
selected.append(((index, points[0]), loss_improvements[0]))
175+
176+
points, loss_improvements = map(list, zip(*selected))
177+
return points, loss_improvements
182178

183179
def ask(self, n, tell_pending=True):
184180
"""Chose points for learners."""

0 commit comments

Comments
 (0)