Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,9 @@ def predict(

@classmethod
def from_problem(
cls, problem: vz.ProblemStatement, seed: Optional[int] = None
cls,
problem: vz.ProblemStatement,
seed: Optional[int] = None,
) -> 'VizierGPBandit':
rng = jax.random.PRNGKey(seed or 0)
# Linear coef is set to 1.0 as prior and uses VizierLinearGaussianProcess
Expand Down
35 changes: 35 additions & 0 deletions vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
from vizier._src.algorithms.optimizers import eagle_strategy as es
from vizier._src.algorithms.optimizers import lbfgsb_optimizer as lo
from vizier._src.algorithms.optimizers import vectorized_base as vb
from vizier._src.algorithms.testing import simplekd_runner
from vizier._src.algorithms.testing import test_runners
from vizier._src.benchmarks.experimenters.synthetic import simplekd
from vizier._src.jax import types
from vizier.jax import optimizers
from vizier.pyvizier import converters
Expand Down Expand Up @@ -494,6 +496,39 @@ def test_multi_metrics(self):
)


class GPBanditSimplekDTest(parameterized.TestCase):
"""Simplekd convergence tests for gp bandit designer."""

@parameterized.parameters(
dict(best_category='corner', max_relative_error=0.5),
dict(best_category='center', max_relative_error=0.1),
dict(best_category='mixed', max_relative_error=0.1),
)
def test_convergence(
self,
best_category: simplekd.SimpleKDCategory,
*,
max_relative_error: float,
) -> None:
simplekd_runner.SimpleKDConvergenceTester(
best_category=best_category,
designer_factory=(
# pylint: disable=g-long-lambda
lambda problem, seed: gp_bandit.VizierGPBandit(
problem,
rng=jax.random.PRNGKey(seed),
padding_schedule=padding.PaddingSchedule(
num_trials=padding.PaddingType.MULTIPLES_OF_10
),
)
),
num_trials=20,
max_relative_error=max_relative_error,
num_repeats=1,
target_num_convergence=1,
).assert_convergence()


# TODO: Fix transfer learning and enable tests.
@unittest.skip('The current transfer learning seems broken and test failing.')
class GPBanditPriorsTest(parameterized.TestCase):
Expand Down
11 changes: 7 additions & 4 deletions vizier/_src/benchmarks/runners/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ def run(self, state: benchmark_state.BenchmarkState) -> None:
suggestions = state.algorithm.suggest(self.batch_size)
if not suggestions:
logging.info(
(
'Algorithm did not generate %d suggestions'
'because it returned nothing.'
),
'Algorithm returned 0 suggestions. Expected: %s.',
self.batch_size,
)
logging.info('Generated %s suggestions.', len(suggestions))
state.experimenter.evaluate(list(suggestions))
for t in suggestions:
logging.info('Trial %s: %s', t.id, t.final_measurement)


@attr.define
Expand Down Expand Up @@ -165,6 +165,9 @@ def run(self, state: benchmark_state.BenchmarkState) -> None:
evaluated_trials = active_trials[: self.num_evaluations]

state.experimenter.evaluate(evaluated_trials)
logging.info('Evaluated %s trials.', len(evaluated_trials))
for t in evaluated_trials:
logging.info('Trial %s: %s', t.id, t.final_measurement)


@attr.define
Expand Down