Skip to content

Commit 5cb2392

Browse files
xingyousongcopybara-github
authored andcommitted
Allow arms to have alphabetical names.
PiperOrigin-RevId: 693518222
1 parent ed8cb96 commit 5cb2392

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

vizier/_src/benchmarks/experimenters/normalizing_experimenter_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def testNormalizationApply(self, func):
7474
self.assertBetween(normalized_value, -10, 10)
7575

7676
def test_NormalizingCategoricals(self):
77-
mab_exptr = multiarm.FixedMultiArmExperimenter(rewards=[-1e6, 0.0, 1e6])
77+
mab_exptr = multiarm.FixedMultiArmExperimenter(
78+
rewards=[-1e6, 0.0, 1e6], arms_as_chars=False
79+
)
7880
norm_exptr = normalizing_experimenter.NormalizingExperimenter(mab_exptr)
7981
metric_name = norm_exptr.problem_statement().metric_information.item().name
8082

vizier/_src/benchmarks/experimenters/synthetic/multiarm.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,55 +20,74 @@
2020
distributions.
2121
"""
2222

23+
import copy
2324
from typing import Optional, Sequence
2425

2526
import numpy as np
2627
from vizier import pyvizier as vz
2728
from vizier._src.benchmarks.experimenters import experimenter
2829

2930

30-
def _default_multiarm_problem(num_arms: int) -> vz.ProblemStatement:
31+
def _default_multiarm_problem(
32+
num_arms: int, arms_as_chars: bool
33+
) -> vz.ProblemStatement:
3134
"""Returns default multi-arm problem statement."""
3235
problem = vz.ProblemStatement()
3336
problem.metric_information.append(
3437
vz.MetricInformation(name="reward", goal=vz.ObjectiveMetricGoal.MAXIMIZE)
3538
)
39+
40+
if arms_as_chars:
41+
# Starts with 'a' character.
42+
feasible_values = [chr(i + 97) for i in range(num_arms)]
43+
else:
44+
feasible_values = [str(i) for i in range(num_arms)]
45+
3646
problem.search_space.root.add_categorical_param(
37-
name="arm", feasible_values=[str(i) for i in range(num_arms)]
47+
name="arm", feasible_values=feasible_values
3848
)
3949
return problem
4050

4151

4252
class BernoulliMultiArmExperimenter(experimenter.Experimenter):
4353
"""Uses a collection of Bernoulli arms with given probabilities."""
4454

45-
def __init__(self, probs: Sequence[float], seed: Optional[int] = None):
55+
def __init__(
56+
self,
57+
probs: Sequence[float],
58+
arms_as_chars: bool = True,
59+
seed: Optional[int] = None,
60+
):
4661
self._probs = probs
4762
self._rng = np.random.RandomState(seed)
63+
self._problem = _default_multiarm_problem(len(self._probs), arms_as_chars)
4864

4965
def problem_statement(self) -> vz.ProblemStatement:
50-
return _default_multiarm_problem(len(self._probs))
66+
return copy.deepcopy(self._problem)
5167

5268
def evaluate(self, suggestions: Sequence[vz.Trial]) -> None:
5369
"""Each arm has a fixed probability of outputting 0 or 1 reward."""
70+
feasibles = self._problem.search_space.parameters[0].feasible_values
5471
for suggestion in suggestions:
55-
arm = int(suggestion.parameters["arm"].value)
56-
prob = self._probs[arm]
72+
arm_index = feasibles.index(suggestion.parameters["arm"].value)
73+
prob = self._probs[arm_index]
5774
reward = self._rng.choice([0, 1], p=[1 - prob, prob])
5875
suggestion.final_measurement = vz.Measurement(metrics={"reward": reward})
5976

6077

6178
class FixedMultiArmExperimenter(experimenter.Experimenter):
6279
"""Rewards are deterministic."""
6380

64-
def __init__(self, rewards: Sequence[float]):
81+
def __init__(self, rewards: Sequence[float], arms_as_chars: bool = True):
6582
self._rewards = rewards
83+
self._problem = _default_multiarm_problem(len(self._rewards), arms_as_chars)
6684

6785
def problem_statement(self) -> vz.ProblemStatement:
68-
return _default_multiarm_problem(len(self._rewards))
86+
return copy.deepcopy(self._problem)
6987

7088
def evaluate(self, suggestions: Sequence[vz.Trial]) -> None:
89+
feasibles = self._problem.search_space.parameters[0].feasible_values
7190
for suggestion in suggestions:
72-
arm = int(suggestion.parameters["arm"].value)
73-
reward = self._rewards[arm]
91+
arm_index = feasibles.index(suggestion.parameters["arm"].value)
92+
reward = self._rewards[arm_index]
7493
suggestion.final_measurement = vz.Measurement(metrics={"reward": reward})

0 commit comments

Comments
 (0)