@@ -119,7 +119,9 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
119
119
default = optimizers .DEFAULT_RANDOM_RESTARTS , kw_only = True
120
120
)
121
121
_num_seed_trials : int = attr .field (default = 1 , kw_only = True )
122
- _linear_coef : float = attr .field (default = 0.0 , kw_only = True )
122
+ # The linear coefficient defaults to 1.0 as a prior; the model then uses
123
+ # VizierLinearGaussianProcess (a sum of Matern and linear kernels), and ARD still tunes its amplitude.
124
+ _linear_coef : float = attr .field (default = 1.0 , kw_only = True )
123
125
_scoring_function_factory : acq_lib .ScoringFunctionFactory = attr .field (
124
126
factory = lambda : default_scoring_function_factory ,
125
127
kw_only = True ,
@@ -142,6 +144,11 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
142
144
factory = output_warpers .create_default_warper , kw_only = True
143
145
)
144
146
147
+ # Multi-objective settings.
148
+ _num_scalars : int = attr .field (default = 1000 , kw_only = True )
149
+ _ref_scaling : float = attr .field (default = 0.01 , kw_only = True )
150
+ _num_samples : Optional [int ] = attr .field (default = None , kw_only = True )
151
+
145
152
# ------------------------------------------------------------------
146
153
# Internal attributes which should not be set by callers.
147
154
# ------------------------------------------------------------------
@@ -182,6 +189,57 @@ def __attrs_post_init__(self):
182
189
seed = int (jax .random .randint (self ._rng , [], 0 , 2 ** 16 )),
183
190
)
184
191
192
+ m_info = self ._problem .metric_information
193
+ if not m_info .is_single_objective :
194
+ num_obj = len (m_info .of_type (vz .MetricType .OBJECTIVE ))
195
+ self ._rng , weights_rng = jax .random .split (self ._rng )
196
+ weights = jnp .abs (
197
+ jax .random .normal (weights_rng , shape = (self ._num_scalars , num_obj ))
198
+ )
199
+ weights = weights / jnp .linalg .norm (weights , axis = - 1 , keepdims = True )
200
+
201
+ if self ._num_samples is None :
202
+
203
+ def _scalarized_ucb (
204
+ data : types .ModelData ,
205
+ ) -> acq_lib .AcquisitionFunction :
206
+ scalarizer = scalarization .HyperVolumeScalarization (
207
+ weights ,
208
+ acq_lib .get_reference_point (data .labels , self ._ref_scaling ),
209
+ )
210
+ return acq_lib .ScalarizedAcquisition (
211
+ acq_lib .UCB (),
212
+ scalarizer ,
213
+ reduction_fn = lambda x : jnp .mean (x , axis = 0 ),
214
+ )
215
+
216
+ acq_fn_factory = _scalarized_ucb
217
+
218
+ else :
219
+
220
+ def _scalarized_sample_ehvi (
221
+ data : types .ModelData ,
222
+ ) -> acq_lib .AcquisitionFunction :
223
+ scalarizer = scalarization .HyperVolumeScalarization (
224
+ weights ,
225
+ acq_lib .get_reference_point (data .labels , self ._ref_scaling ),
226
+ )
227
+ return acq_lib .ScalarizedAcquisition (
228
+ acq_lib .Sample (self ._num_samples ),
229
+ scalarizer ,
230
+ # We need to reduce across the scalarization and sample axes.
231
+ reduction_fn = lambda x : jnp .mean (jax .nn .relu (x ), axis = [0 , 1 ]),
232
+ )
233
+
234
+ acq_fn_factory = _scalarized_sample_ehvi
235
+
236
+ # Multi-objective overrides.
237
+ self ._scoring_function_factory = (
238
+ acq_lib .bayesian_scoring_function_factory (acq_fn_factory )
239
+ )
240
+ self ._scoring_function_is_parallel = True
241
+ self ._use_trust_region = False
242
+
185
243
self ._acquisition_optimizer = self ._acquisition_optimizer_factory (
186
244
self ._converter
187
245
)
@@ -578,67 +636,7 @@ def from_problem(
578
636
cls ,
579
637
problem : vz .ProblemStatement ,
580
638
seed : Optional [int ] = None ,
581
- num_scalarizations : int = 1000 ,
582
- reference_scaling : float = 0.01 ,
583
- num_samples : int | None = None ,
584
639
** kwargs ,
585
640
) -> 'VizierGPBandit' :
586
641
rng = jax .random .PRNGKey (seed or 0 )
587
- # Linear coef is set to 1.0 as prior and uses VizierLinearGaussianProcess
588
- # which uses a sum of Matern and linear but ARD still tunes its amplitude.
589
- if problem .is_single_objective :
590
- return cls (problem , linear_coef = 1.0 , rng = rng , ** kwargs )
591
- else :
592
- num_obj = len (problem .metric_information .of_type (vz .MetricType .OBJECTIVE ))
593
- rng , weights_rng = jax .random .split (rng )
594
- weights = jnp .abs (
595
- jax .random .normal (weights_rng , shape = (num_scalarizations , num_obj ))
596
- )
597
- weights = weights / jnp .linalg .norm (weights , axis = - 1 , keepdims = True )
598
-
599
- if num_samples is None :
600
-
601
- def _scalarized_ucb (
602
- data : types .ModelData ,
603
- ) -> acq_lib .AcquisitionFunction :
604
- scalarizer = scalarization .HyperVolumeScalarization (
605
- weights ,
606
- acq_lib .get_reference_point (data .labels , reference_scaling ),
607
- )
608
- return acq_lib .ScalarizedAcquisition (
609
- acq_lib .UCB (),
610
- scalarizer ,
611
- reduction_fn = lambda x : jnp .mean (x , axis = 0 ),
612
- )
613
-
614
- acq_fn_factory = _scalarized_ucb
615
- else :
616
-
617
- def _scalarized_sample_ehvi (
618
- data : types .ModelData ,
619
- ) -> acq_lib .AcquisitionFunction :
620
- scalarizer = scalarization .HyperVolumeScalarization (
621
- weights ,
622
- acq_lib .get_reference_point (data .labels , reference_scaling ),
623
- )
624
- return acq_lib .ScalarizedAcquisition (
625
- acq_lib .Sample (num_samples ),
626
- scalarizer ,
627
- # We need to reduce across the scalarization and sample axes.
628
- reduction_fn = lambda x : jnp .mean (jax .nn .relu (x ), axis = [0 , 1 ]),
629
- )
630
-
631
- acq_fn_factory = _scalarized_sample_ehvi
632
-
633
- scoring_function_factory = acq_lib .bayesian_scoring_function_factory (
634
- acq_fn_factory
635
- )
636
- return cls (
637
- problem ,
638
- linear_coef = 1.0 ,
639
- scoring_function_factory = scoring_function_factory ,
640
- scoring_function_is_parallel = True ,
641
- use_trust_region = False ,
642
- rng = rng ,
643
- ** kwargs ,
644
- )
642
+ return cls (problem , rng = rng , ** kwargs )
0 commit comments