@@ -112,25 +112,33 @@ def strategy(self, strategy):
112
112
' strategy="npoints" is implemented.' )
113
113
114
114
def _ask_and_tell_based_on_loss_improvements (self , n ):
115
- points = []
116
- loss_improvements = []
115
+ chosen_points = []
116
+ chosen_loss_improvements = []
117
+ npoints_per_learner = defaultdict (int )
118
+
117
119
for _ in range (n ):
118
120
improvements_per_learner = []
119
- pairs = []
121
+ points_per_learner = []
120
122
for index , learner in enumerate (self .learners ):
121
123
if index not in self ._points :
122
124
self ._points [index ] = learner .ask (
123
125
n = 1 , tell_pending = False )
124
- point , loss_improvement = self ._points [index ]
125
- improvements_per_learner .append (loss_improvement [0 ])
126
- pairs .append ((index , point [0 ]))
127
- x , l = max (zip (pairs , improvements_per_learner ),
128
- key = itemgetter (1 ))
129
- points .append (x )
130
- loss_improvements .append (l )
131
- self .tell_pending (x )
132
-
133
- return points , loss_improvements
126
+ points , loss_improvements = self ._points [index ]
127
+ npoints = npoints_per_learner [index ] + learner .npoints
128
+ sort_tuple = (loss_improvements [0 ], - npoints )
129
+ improvements_per_learner .append (sort_tuple )
130
+ points_per_learner .append ((index , points [0 ]))
131
+
132
+ # Chose the optimal improvement.
133
+ (index , point ), (loss_improvement , _ ) = max (
134
+ zip (points_per_learner , improvements_per_learner ),
135
+ key = itemgetter (1 ))
136
+ npoints_per_learner [index ] += 1
137
+ chosen_points .append ((index , point ))
138
+ chosen_loss_improvements .append (loss_improvement )
139
+ self .tell_pending ((index , point ))
140
+
141
+ return chosen_points , chosen_loss_improvements
134
142
135
143
def _ask_and_tell_based_on_loss (self , n ):
136
144
points = []
@@ -160,19 +168,11 @@ def _ask_and_tell_based_on_npoints(self, n):
160
168
161
169
def ask (self , n , tell_pending = True ):
162
170
"""Chose points for learners."""
163
- if any (l .npoints for l in self .learners ):
164
- ask_and_tell = self ._ask_and_tell
165
- else :
166
- # If there are no data points yet,
167
- # distribute the points over all learners.
168
- # See https://github.com/python-adaptive/adaptive/issues/159
169
- ask_and_tell = self ._ask_and_tell_based_on_npoints
170
-
171
171
if not tell_pending :
172
172
with restore (* self .learners ):
173
- return ask_and_tell (n )
173
+ return self . _ask_and_tell (n )
174
174
else :
175
- return ask_and_tell (n )
175
+ return self . _ask_and_tell (n )
176
176
177
177
def tell (self , x , y ):
178
178
index , x = x
0 commit comments