Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 56 additions & 36 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
# pickle the whole learner.
self.function = partial(dispatch, [l.function for l in self.learners])

self._points = {}
self._ask_cache = {}
self._loss = {}
self._pending_loss = {}
self._cdims_default = cdims
Expand Down Expand Up @@ -113,54 +113,74 @@ def strategy(self, strategy):
' strategy="npoints" is implemented.')

def _ask_and_tell_based_on_loss_improvements(self, n):
points = []
loss_improvements = []
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
improvements_per_learner = []
pairs = []
to_select = []
for index, learner in enumerate(self.learners):
if index not in self._points:
self._points[index] = learner.ask(
# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = learner.ask(
n=1, tell_pending=False)
point, loss_improvement = self._points[index]
improvements_per_learner.append(loss_improvement[0])
pairs.append((index, point[0]))
x, l = max(zip(pairs, improvements_per_learner),
key=itemgetter(1))
points.append(x)
loss_improvements.append(l)
self.tell_pending(x)

points, loss_improvements = self._ask_cache[index]
to_select.append(
((index, points[0]),
(loss_improvements[0], -total_points[index]))
)

# Choose the optimal improvement.
(index, point), (loss_improvement, _) = max(
to_select, key=itemgetter(1))
total_points[index] += 1
selected.append(((index, point), loss_improvement))
self.tell_pending((index, point))

points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_loss(self, n):
points = []
loss_improvements = []
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
losses = self._losses(real=False)
max_ind = np.argmax(losses)
xs, ls = self.learners[max_ind].ask(1)
points.append((max_ind, xs[0]))
loss_improvements.append(ls[0])
index, _ = max(
enumerate(zip(losses, (-n for n in total_points))),
key=itemgetter(1)
)
total_points[index] += 1

# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = self.learners[index].ask(n=1)
points, loss_improvements = self._ask_cache[index]

selected.append(((index, points[0]), loss_improvements[0]))
self.tell_pending((index, points[0]))

points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def _ask_and_tell_based_on_npoints(self, n):
points = []
loss_improvements = []
npoints = [l.npoints + len(l.pending_points)
for l in self.learners]
n_left = n
while n_left > 0:
i = np.argmin(npoints)
xs, ls = self.learners[i].ask(1)
npoints[i] += 1
n_left -= 1
points.append((i, xs[0]))
loss_improvements.append(ls[0])
selected = [] # tuples ((learner_index, point), loss_improvement)
total_points = [l.npoints + len(l.pending_points) for l in self.learners]
for _ in range(n):
index = np.argmin(total_points)
# Take the points from the cache
if index not in self._ask_cache:
self._ask_cache[index] = self.learners[index].ask(n=1)
points, loss_improvements = self._ask_cache[index]
total_points[index] += 1
selected.append(((index, points[0]), loss_improvements[0]))
self.tell_pending((index, points[0]))

points, loss_improvements = map(list, zip(*selected))
return points, loss_improvements

def ask(self, n, tell_pending=True):
"""Chose points for learners."""
if n == 0:
return [], []

if not tell_pending:
with restore(*self.learners):
return self._ask_and_tell(n)
Expand All @@ -169,14 +189,14 @@ def ask(self, n, tell_pending=True):

def tell(self, x, y):
index, x = x
self._points.pop(index, None)
self._ask_cache.pop(index, None)
self._loss.pop(index, None)
self._pending_loss.pop(index, None)
self.learners[index].tell(x, y)

def tell_pending(self, x):
index, x = x
self._points.pop(index, None)
self._ask_cache.pop(index, None)
self._loss.pop(index, None)
self.learners[index].tell_pending(x)

Expand Down
37 changes: 37 additions & 0 deletions adaptive/tests/test_balancing_learner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# -*- coding: utf-8 -*-

import pytest

from adaptive.learner import Learner1D, BalancingLearner
from adaptive.runner import simple


def test_balancing_learner_loss_cache():
Expand All @@ -21,3 +24,37 @@ def test_balancing_learner_loss_cache():
bl = BalancingLearner([learner])
assert bl.loss(real=False) == pending_loss
assert bl.loss(real=True) == real_loss


@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
def test_distribute_first_points_over_learners(strategy):
for initial_points in [0, 3]:
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
learner = BalancingLearner(learners, strategy=strategy)

points = learner.ask(initial_points)[0]
learner.tell_many(points, points)

points, _ = learner.ask(100)
i_learner, xs = zip(*points)
# assert that are all learners in the suggested points
assert len(set(i_learner)) == len(learners)


@pytest.mark.parametrize('strategy', ['loss', 'loss_improvements', 'npoints'])
def test_ask_0(strategy):
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
learner = BalancingLearner(learners, strategy=strategy)
points, _ = learner.ask(0)
assert len(points) == 0


@pytest.mark.parametrize('strategy, goal', [
('loss', lambda l: l.loss() < 0.1),
('loss_improvements', lambda l: l.loss() < 0.1),
('npoints', lambda bl: all(l.npoints > 10 for l in bl.learners)),
])
def test_strategies(strategy, goal):
learners = [Learner1D(lambda x: x, bounds=(-1, 1)) for i in range(10)]
learner = BalancingLearner(learners, strategy=strategy)
simple(learner, goal=goal)