@@ -119,7 +119,9 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
119
119
default = optimizers .DEFAULT_RANDOM_RESTARTS , kw_only = True
120
120
)
121
121
_num_seed_trials : int = attr .field (default = 1 , kw_only = True )
122
- _linear_coef : float = attr .field (default = 0.0 , kw_only = True )
122
+ # If used, should set to 1.0 as prior uses a sum of Matern and linear but ARD
123
+ # still tunes its amplitude. Only used for single-objective.
124
+ _linear_coef : Optional [float ] = attr .field (default = None , kw_only = True )
123
125
_scoring_function_factory : acq_lib .ScoringFunctionFactory = attr .field (
124
126
factory = lambda : default_scoring_function_factory ,
125
127
kw_only = True ,
@classmethod
def from_problem(
    cls,
    problem: vz.ProblemStatement,
    seed: Optional[int] = None,
    *,  # Below are multi-objective options for acquisition function.
    num_scalarizations: int = 1000,
    reference_scaling: float = 0.01,
    num_samples: int | None = None,
    **kwargs,
) -> 'VizierGPBandit':
  """Builds a designer configured for `problem`.

  Single-objective problems use the class defaults. Multi-objective problems
  get a scalarization-based scoring function: scalarized UCB when
  `num_samples` is None, otherwise a sampled EHVI-style acquisition.

  Args:
    problem: The problem statement to optimize.
    seed: Seed for the JAX PRNG key; `None` is treated as 0.
    num_scalarizations: Number of random scalarization weight vectors.
    reference_scaling: Scaling used when deriving the hypervolume reference
      point from the observed labels.
    num_samples: If set, number of posterior samples for the sampled
      EHVI-style acquisition; if None, scalarized UCB is used.
    **kwargs: Forwarded to the constructor.

  Returns:
    A configured VizierGPBandit.
  """
  rng = jax.random.PRNGKey(seed or 0)
  if problem.is_single_objective:
    return cls(problem, rng=rng, **kwargs)

  # Multi-objective: draw nonnegative weight vectors, normalized onto the
  # unit sphere, one per scalarization.
  num_obj = len(problem.metric_information.of_type(vz.MetricType.OBJECTIVE))
  rng, weights_rng = jax.random.split(rng)
  weights = jnp.abs(
      jax.random.normal(weights_rng, shape=(num_scalarizations, num_obj))
  )
  weights = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)

  # The two acquisition variants differ only in the base acquisition and the
  # reduction applied to the scalarized values; pick those here and share a
  # single factory instead of duplicating its body per branch.
  if num_samples is None:
    # Scalarized UCB: average over the scalarization axis.
    base_acquisition = acq_lib.UCB()
    reduction_fn = lambda x: jnp.mean(x, axis=0)
  else:
    # Sampled EHVI: reduce across both the scalarization and sample axes.
    base_acquisition = acq_lib.Sample(num_samples)
    reduction_fn = lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1])

  def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
    # Reference point depends on the observed labels, so it is recomputed per
    # model-data snapshot.
    scalarizer = scalarization.HyperVolumeScalarization(
        weights,
        acq_lib.get_reference_point(data.labels, reference_scaling),
    )
    return acq_lib.ScalarizedAcquisition(
        base_acquisition,
        scalarizer,
        reduction_fn=reduction_fn,
    )

  scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
      acq_fn_factory
  )
  return cls(
      problem,
      scoring_function_factory=scoring_function_factory,
      scoring_function_is_parallel=True,
      use_trust_region=False,
      rng=rng,
      **kwargs,
  )
0 commit comments