Skip to content

Commit 8f1bdfe

Browse files
vizier-teamcopybara-github
authored andcommitted
Allows dependency injection of the GP model into the GP_UCB_PE designer
PiperOrigin-RevId: 705520724
1 parent bb140c4 commit 8f1bdfe

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vizier/_src/algorithms/designers/gp_ucb_pe.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ class VizierGPUCBPEBandit(vza.Designer):
457457
Attributes:
458458
problem: Must be a flat study with a single metric.
459459
acquisition_optimizer:
460+
gp_model_class: The GP model class, which must implement a `build_model`
461+
class method that takes `ModelInput` and returns a
462+
`StochasticProcessModel`.
460463
metadata_ns: Metadata namespace that this designer writes to.
461464
use_trust_region: Uses trust region.
462465
ard_optimizer: An optimizer object, which should return a batch of
@@ -475,6 +478,10 @@ class VizierGPUCBPEBandit(vza.Designer):
475478
kw_only=True,
476479
factory=lambda: VizierGPUCBPEBandit.default_acquisition_optimizer_factory,
477480
)
481+
_gp_model_class: sp.ModelCoroutine[tfd.GaussianProcess] = attr.field(
482+
kw_only=True,
483+
factory=lambda: tuned_gp_models.VizierGaussianProcess,
484+
)
478485
_metadata_ns: str = attr.field(
479486
default='google_gp_ucb_pe_bandit', kw_only=True
480487
)
@@ -611,7 +618,7 @@ def _build_gp_model_and_optimize_parameters(
611618
`data.labels`. If `data.features` is empty, the returned parameters are
612619
initial values picked by the GP model.
613620
"""
614-
coroutine = tuned_gp_models.VizierGaussianProcess.build_model(
621+
coroutine = self._gp_model_class.build_model( # pytype: disable=attribute-error
615622
data.features
616623
).coroutine
617624
model = sp.CoroutineWithData(coroutine, data)

0 commit comments

Comments
 (0)