@@ -111,6 +111,7 @@ def strategy(self, strategy):
111
111
raise ValueError (
112
112
'Only strategy="loss_improvements", strategy="loss", or'
113
113
' strategy="npoints" is implemented.' )
114
+ self ._points = {} # reset the cache
114
115
115
116
def _ask_and_tell_based_on_loss_improvements (self , n ):
116
117
chosen_points = []
@@ -125,12 +126,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
125
126
self ._points [index ] = learner .ask (
126
127
n = 1 , tell_pending = False )
127
128
points , loss_improvements = self ._points [index ]
128
- npoints = npoints_per_learner [index ] + learner .npoints
129
+ npoints = (npoints_per_learner [index ]
130
+ + learner .npoints
131
+ + len (learner .pending_points ))
129
132
priority = (loss_improvements [0 ], - npoints )
130
133
improvements_per_learner .append (priority )
131
134
points_per_learner .append ((index , points [0 ]))
132
135
133
- # Chose the optimal improvement.
136
+ # Choose the optimal improvement.
134
137
(index , point ), (loss_improvement , _ ) = max (
135
138
zip (points_per_learner , improvements_per_learner ),
136
139
key = itemgetter (1 ))
@@ -142,15 +145,23 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
142
145
return chosen_points , chosen_loss_improvements
143
146
144
147
def _ask_and_tell_based_on_loss (self , n ):
145
- points = []
146
- loss_improvements = []
148
+ chosen_points = []
149
+ chosen_loss_improvements = []
150
+ npoints_per_learner = defaultdict (int )
151
+
147
152
for _ in range (n ):
148
153
losses = self ._losses (real = False )
149
- max_ind = np .argmax (losses )
150
- xs , ls = self .learners [max_ind ].ask (1 )
151
- points .append ((max_ind , xs [0 ]))
152
- loss_improvements .append (ls [0 ])
153
- return points , loss_improvements
154
+ npoints = [- (l .npoints
155
+ + npoints_per_learner [i ]
156
+ + len (l .pending_points ))
157
+ for i , l in enumerate (self .learners )]
158
+ priority = zip (losses , npoints )
159
+ index , (_ , npoints ) = max (enumerate (priority ), key = itemgetter (1 ))
160
+ npoints_per_learner [index ] += 1
161
+ points , loss_improvements = self .learners [index ].ask (1 )
162
+ chosen_points .append ((index , points [0 ]))
163
+ chosen_loss_improvements .append (loss_improvements [0 ])
164
+ return chosen_points , chosen_loss_improvements
154
165
155
166
def _ask_and_tell_based_on_npoints (self , n ):
156
167
points = []
0 commit comments