
Commit a9164e0

xingyousong authored and copybara-github committed

Move multiobjective to simplify from_problem factory

PiperOrigin-RevId: 650835016

1 parent 9461e45 commit a9164e0

File tree

1 file changed: +60 -62 lines changed

vizier/_src/algorithms/designers/gp_bandit.py

Lines changed: 60 additions & 62 deletions
@@ -119,7 +119,9 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
       default=optimizers.DEFAULT_RANDOM_RESTARTS, kw_only=True
   )
   _num_seed_trials: int = attr.field(default=1, kw_only=True)
-  _linear_coef: float = attr.field(default=0.0, kw_only=True)
+  # Linear coef is set to 1.0 as prior and uses VizierLinearGaussianProcess
+  # which uses a sum of Matern and linear but ARD still tunes its amplitude.
+  _linear_coef: float = attr.field(default=1.0, kw_only=True)
   _scoring_function_factory: acq_lib.ScoringFunctionFactory = attr.field(
       factory=lambda: default_scoring_function_factory,
       kw_only=True,
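For intuition, the kernel structure that the new comment refers to can be sketched standalone. This is a minimal illustration, not the actual VizierLinearGaussianProcess implementation; matern52, summed_kernel, and linear_amplitude are hypothetical names, with linear_amplitude standing in for the linear_coef prior that hyperparameter (ARD) tuning later adjusts.

import jax.numpy as jnp


def matern52(x, y, length_scale=1.0):
  # Matern-5/2 kernel between two feature vectors.
  d = jnp.sqrt(jnp.sum(((x - y) / length_scale) ** 2) + 1e-12)
  z = jnp.sqrt(5.0) * d
  return (1.0 + z + z**2 / 3.0) * jnp.exp(-z)


def summed_kernel(x, y, linear_amplitude=1.0, length_scale=1.0):
  # Sum of a Matern term and a linear term; linear_amplitude plays the
  # role of the linear_coef prior before tuning rescales it.
  return matern52(x, y, length_scale) + linear_amplitude * jnp.dot(x, y)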
@@ -142,6 +144,11 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
       factory=output_warpers.create_default_warper, kw_only=True
   )

+  # Multi-objective settings.
+  _num_scalars: int = attr.field(default=1000, kw_only=True)
+  _ref_scaling: float = attr.field(default=0.01, kw_only=True)
+  _num_samples: Optional[int] = attr.field(default=None, kw_only=True)
+
   # ------------------------------------------------------------------
   # Internal attributes which should not be set by callers.
   # ------------------------------------------------------------------
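Assuming the usual attrs convention of stripping the leading underscore from private fields to form __init__ argument names, the new knobs would be set at construction time roughly as follows (the values are purely illustrative):

designer = VizierGPBandit(
    problem,
    num_scalars=500,   # number of random scalarization weight vectors
    ref_scaling=0.05,  # scaling used when inferring the reference point
    num_samples=64,    # setting this switches to the sampled-EHVI path
)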
@@ -182,6 +189,57 @@ def __attrs_post_init__(self):
         seed=int(jax.random.randint(self._rng, [], 0, 2**16)),
     )

+    m_info = self._problem.metric_information
+    if not m_info.is_single_objective:
+      num_obj = len(m_info.of_type(vz.MetricType.OBJECTIVE))
+      self._rng, weights_rng = jax.random.split(self._rng)
+      weights = jnp.abs(
+          jax.random.normal(weights_rng, shape=(self._num_scalars, num_obj))
+      )
+      weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
+
+      if self._num_samples is None:
+
+        def _scalarized_ucb(
+            data: types.ModelData,
+        ) -> acq_lib.AcquisitionFunction:
+          scalarizer = scalarization.HyperVolumeScalarization(
+              weights,
+              acq_lib.get_reference_point(data.labels, self._ref_scaling),
+          )
+          return acq_lib.ScalarizedAcquisition(
+              acq_lib.UCB(),
+              scalarizer,
+              reduction_fn=lambda x: jnp.mean(x, axis=0),
+          )
+
+        acq_fn_factory = _scalarized_ucb
+
+      else:
+
+        def _scalarized_sample_ehvi(
+            data: types.ModelData,
+        ) -> acq_lib.AcquisitionFunction:
+          scalarizer = scalarization.HyperVolumeScalarization(
+              weights,
+              acq_lib.get_reference_point(data.labels, self._ref_scaling),
+          )
+          return acq_lib.ScalarizedAcquisition(
+              acq_lib.Sample(self._num_samples),
+              scalarizer,
+              # We need to reduce across the scalarization and sample axes.
+              reduction_fn=lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1]),
+          )
+
+        acq_fn_factory = _scalarized_sample_ehvi
+
+      # Multi-objective overrides.
+      self._scoring_function_factory = (
+          acq_lib.bayesian_scoring_function_factory(acq_fn_factory)
+      )
+      self._scoring_function_is_parallel = True
+      self._use_trust_region = False
+
     self._acquisition_optimizer = self._acquisition_optimizer_factory(
         self._converter
     )
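The weight sampling and scalarization in the hunk above can be illustrated in isolation. The sketch below is not Vizier's scalarization.HyperVolumeScalarization; hypervolume_scalarize assumes one common form of the scalarizer, the minimum over objectives of (y_i - r_i) / w_i, with the mean taken over weight vectors downstream.

import jax
import jax.numpy as jnp


def sample_weights(rng, num_scalars, num_obj):
  # |N(0, I)| draws normalized to unit norm: uniform directions on the
  # positive orthant of the sphere, as in __attrs_post_init__ above.
  w = jnp.abs(jax.random.normal(rng, shape=(num_scalars, num_obj)))
  return w / jnp.linalg.norm(w, axis=-1, keepdims=True)


def hypervolume_scalarize(y, weights, reference):
  # Scalarize a vector of objective values against each weight vector;
  # maximizing the mean over weights targets dominated hypervolume.
  return jnp.min((y - reference) / weights, axis=-1)


rng = jax.random.PRNGKey(0)
weights = sample_weights(rng, num_scalars=1000, num_obj=2)
scores = hypervolume_scalarize(
    jnp.array([0.7, 0.3]), weights, jnp.array([0.0, 0.0])
)  # shape (1000,): one scalarized value per weight vector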
@@ -578,67 +636,7 @@ def from_problem(
       cls,
       problem: vz.ProblemStatement,
       seed: Optional[int] = None,
-      num_scalarizations: int = 1000,
-      reference_scaling: float = 0.01,
-      num_samples: int | None = None,
       **kwargs,
   ) -> 'VizierGPBandit':
     rng = jax.random.PRNGKey(seed or 0)
-    # Linear coef is set to 1.0 as prior and uses VizierLinearGaussianProcess
-    # which uses a sum of Matern and linear but ARD still tunes its amplitude.
-    if problem.is_single_objective:
-      return cls(problem, linear_coef=1.0, rng=rng, **kwargs)
-    else:
-      num_obj = len(problem.metric_information.of_type(vz.MetricType.OBJECTIVE))
-      rng, weights_rng = jax.random.split(rng)
-      weights = jnp.abs(
-          jax.random.normal(weights_rng, shape=(num_scalarizations, num_obj))
-      )
-      weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
-
-      if num_samples is None:
-
-        def _scalarized_ucb(
-            data: types.ModelData,
-        ) -> acq_lib.AcquisitionFunction:
-          scalarizer = scalarization.HyperVolumeScalarization(
-              weights,
-              acq_lib.get_reference_point(data.labels, reference_scaling),
-          )
-          return acq_lib.ScalarizedAcquisition(
-              acq_lib.UCB(),
-              scalarizer,
-              reduction_fn=lambda x: jnp.mean(x, axis=0),
-          )
-
-        acq_fn_factory = _scalarized_ucb
-      else:
-
-        def _scalarized_sample_ehvi(
-            data: types.ModelData,
-        ) -> acq_lib.AcquisitionFunction:
-          scalarizer = scalarization.HyperVolumeScalarization(
-              weights,
-              acq_lib.get_reference_point(data.labels, reference_scaling),
-          )
-          return acq_lib.ScalarizedAcquisition(
-              acq_lib.Sample(num_samples),
-              scalarizer,
-              # We need to reduce across the scalarization and sample axes.
-              reduction_fn=lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1]),
-          )
-
-        acq_fn_factory = _scalarized_sample_ehvi
-
-      scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
-          acq_fn_factory
-      )
-      return cls(
-          problem,
-          linear_coef=1.0,
-          scoring_function_factory=scoring_function_factory,
-          scoring_function_is_parallel=True,
-          use_trust_region=False,
-          rng=rng,
-          **kwargs,
-      )
+    return cls(problem, rng=rng, **kwargs)
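After this change the factory is uniform across single- and multi-objective problems, since the multi-objective overrides now live in __attrs_post_init__. A call site would look roughly like this; the suggest call is illustrative of the Designer API rather than part of this diff:

designer = VizierGPBandit.from_problem(problem, seed=42)
suggestions = designer.suggest(count=5)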
