Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions vizier/_src/algorithms/designers/gp/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,13 +569,23 @@ class ScalarizedAcquisition(AcquisitionFunction):
reduction_fn: Callable[[jax.Array], jax.Array] = struct.field(
pytree_node=False, default=lambda x: x
)
max_scalarized: Optional[jax.Array] = struct.field(
pytree_node=False, default=None
)

def __call__(
self,
dist: tfd.Distribution,
seed: Optional[jax.Array] = None,
) -> jax.Array:
scalarized = self.scalarizer(self.acquisition_fn(dist, seed).squeeze())
# Broadcast max_scalarized to the same shape as scalarized and take max.
if self.max_scalarized is not None:
shape_mismatch = len(scalarized.shape) - len(self.max_scalarized.shape)
expand_max = jnp.expand_dims(
self.max_scalarized, axis=range(-shape_mismatch, 0)
)
scalarized = jnp.maximum(scalarized, expand_max)
return self.reduction_fn(scalarized)


Expand Down
16 changes: 14 additions & 2 deletions vizier/_src/algorithms/designers/gp/acquisitions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,24 @@ def test_scalarized_ucb(self):
jnp.array([[0.2, 0.3], [0.01, 0.5], [0.5, 0.01]])
)
reference_point = acquisitions.get_worst_labels(labels)
ucb = acquisitions.UCB(coefficient=2.0)
ucb = acquisitions.UCB(coefficient=0.1)
scalarizer = scalarization.HyperVolumeScalarization(
weights=jnp.array([0.1, 0.2]), reference_point=reference_point
)

acq = acquisitions.ScalarizedAcquisition(ucb, scalarizer)
self.assertAlmostEqual(
acq(tfd.Normal([0.1, 0.2], [1, 2])), jnp.array(436.81), delta=1e-2
acq(tfd.Normal([0.1, 0.2], [0.1, 0.1])), jnp.array([1.0]), delta=1e-2
)

# Tests that the scalarized acquisition is larger with max_scalarized.
scalarized_labels = scalarizer(labels.unpad())
max_scalarized = jnp.max(scalarized_labels, axis=-1)
acq = acquisitions.ScalarizedAcquisition(
ucb, scalarizer, max_scalarized=max_scalarized
)
self.assertAlmostEqual(
acq(tfd.Normal([0.1, 0.2], [0.1, 0.1])), jnp.array([2.10]), delta=1e-2
)

def test_ehvi_approximation(self):
Expand All @@ -101,6 +112,7 @@ def test_ehvi_approximation(self):
acquisitions.UCB(coefficient=0.0),
scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
max_scalarized=jnp.zeros(shape=(num_scalarizations,)),
)
# Expected hypervolume should be 2 * 1.5 = 3.0.
np.testing.assert_allclose(
Expand Down
22 changes: 20 additions & 2 deletions vizier/_src/algorithms/designers/gp_bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,29 +602,47 @@ def from_problem(

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Scalarized UCB.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, reference_scaling),
acq_lib.get_reference_point(data.labels, reference_scaling)
if has_labels
else None,
)

max_scalarized = (
jnp.max(scalarizer(labels_array), axis=-1) if has_labels else None
)
return acq_lib.ScalarizedAcquisition(
acq_lib.UCB(),
scalarizer,
reduction_fn=lambda x: jnp.mean(x, axis=0),
max_scalarized=max_scalarized,
)

else:

def acq_fn_factory(data: types.ModelData) -> acq_lib.AcquisitionFunction:
# Sampled EHVI.
labels_array = data.labels.padded_array
has_labels = labels_array.shape[0] > 0
scalarizer = scalarization.HyperVolumeScalarization(
weights,
acq_lib.get_reference_point(data.labels, reference_scaling),
acq_lib.get_reference_point(data.labels, reference_scaling)
if has_labels
else None,
)

max_scalarized = (
jnp.max(scalarizer(labels_array), axis=-1) if has_labels else None
)
return acq_lib.ScalarizedAcquisition(
acq_lib.Sample(num_samples),
scalarizer,
# We need to reduce across the scalarization and sample axes.
reduction_fn=lambda x: jnp.mean(jax.nn.relu(x), axis=[0, 1]),
max_scalarized=max_scalarized,
)

scoring_function_factory = acq_lib.bayesian_scoring_function_factory(
Expand Down
2 changes: 1 addition & 1 deletion vizier/_src/algorithms/designers/gp_bandit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _qei_factory(data: types.ModelData) -> acquisitions.AcquisitionFunction:
)

@parameterized.parameters(
dict(num_samples=10),
dict(num_samples=11),
dict(num_samples=None),
)
def test_multi_metrics(self, num_samples: int | None):
Expand Down