@@ -457,6 +457,9 @@ class VizierGPUCBPEBandit(vza.Designer):
457
457
Attributes:
458
458
problem: Must be a flat study with a single metric.
459
459
acquisition_optimizer:
460
+ gp_model_class: The GP model class, which must implement a `build_model`
461
+ class method that takes `ModelInput` and returns a
462
+ `StochasticProcessModel`.
460
463
metadata_ns: Metadata namespace that this designer writes to.
461
464
use_trust_region: Uses trust region.
462
465
ard_optimizer: An optimizer object, which should return a batch of
@@ -475,6 +478,10 @@ class VizierGPUCBPEBandit(vza.Designer):
475
478
kw_only = True ,
476
479
factory = lambda : VizierGPUCBPEBandit .default_acquisition_optimizer_factory ,
477
480
)
481
+ _gp_model_class : sp .ModelCoroutine [tfd .GaussianProcess ] = attr .field (
482
+ kw_only = True ,
483
+ factory = lambda : tuned_gp_models .VizierGaussianProcess ,
484
+ )
478
485
_metadata_ns : str = attr .field (
479
486
default = 'google_gp_ucb_pe_bandit' , kw_only = True
480
487
)
@@ -611,7 +618,7 @@ def _build_gp_model_and_optimize_parameters(
611
618
`data.labels`. If `data.features` is empty, the returned parameters are
612
619
initial values picked by the GP model.
613
620
"""
614
- coroutine = tuned_gp_models . VizierGaussianProcess .build_model (
621
+ coroutine = self . _gp_model_class .build_model ( # pytype: disable=attribute-error
615
622
data .features
616
623
).coroutine
617
624
model = sp .CoroutineWithData (coroutine , data )
0 commit comments