Skip to content

Commit d9e09e1

Browse files
xingyousongcopybara-github
authored andcommitted
Fix accidental bug in MultiObjectiveNumpyExperimenter
PiperOrigin-RevId: 646546642
1 parent 7265b64 commit d9e09e1

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

vizier/_src/benchmarks/experimenters/numpy_experimenter.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import Callable, Sequence
2323

2424
import numpy as np
25-
from vizier import pyvizier
25+
from vizier import pyvizier as vz
2626
from vizier._src.benchmarks.experimenters import experimenter
2727
from vizier.pyvizier import converters
2828

@@ -43,7 +43,7 @@ class NumpyExperimenter(experimenter.Experimenter):
4343
def __init__(
4444
self,
4545
impl: Callable[[np.ndarray], float],
46-
problem_statement: pyvizier.ProblemStatement,
46+
problem_statement: vz.ProblemStatement,
4747
):
4848
"""NumpyExperimenter with analytic function impl for one metric.
4949
@@ -80,7 +80,7 @@ def __init__(
8080
raise ValueError(f'Non-numeric parameters {parameter}')
8181

8282
objective_metrics = problem_statement.metric_information.of_type(
83-
pyvizier.MetricType.OBJECTIVE
83+
vz.MetricType.OBJECTIVE
8484
)
8585
self._metric_name = objective_metrics.item().name
8686

@@ -91,24 +91,22 @@ def __init__(
9191
flip_sign_for_minimization_metrics=False,
9292
)
9393

94-
def problem_statement(self) -> pyvizier.ProblemStatement:
94+
def problem_statement(self) -> vz.ProblemStatement:
9595
return copy.deepcopy(self._problem_statement)
9696

97-
def evaluate(self, suggestions: Sequence[pyvizier.Trial]):
97+
def evaluate(self, suggestions: Sequence[vz.Trial]):
9898
# Features has shape (num_trials, num_features).
9999
features = self._converter.to_features(suggestions)
100100
for idx, suggestion in enumerate(suggestions):
101101
val = self.impl(features[idx])
102102
if math.isfinite(val):
103-
suggestion.complete(
104-
pyvizier.Measurement(metrics={self._metric_name: val})
105-
)
103+
suggestion.complete(vz.Measurement(metrics={self._metric_name: val}))
106104
else:
107105
self.problem_statement().search_space.assert_contains(
108106
suggestions[idx].parameters
109107
)
110108
suggestion.complete(
111-
pyvizier.Measurement(),
109+
vz.Measurement(),
112110
infeasibility_reason='Objective value is not finite: %f' % val,
113111
)
114112

@@ -126,7 +124,7 @@ class MultiObjectiveNumpyExperimenter(experimenter.Experimenter):
126124
def __init__(
127125
self,
128126
impl: Callable[[np.ndarray], Sequence[float]],
129-
problem_statement: pyvizier.ProblemStatement,
127+
problem_statement: vz.ProblemStatement,
130128
):
131129
self._impl = impl
132130
self._problem_statement = copy.deepcopy(problem_statement)
@@ -136,13 +134,14 @@ def __init__(
136134
flip_sign_for_minimization_metrics=False,
137135
)
138136

139-
def evaluate(self, suggestions: Sequence[pyvizier.Trial]) -> None:
137+
def evaluate(self, suggestions: Sequence[vz.Trial]) -> None:
138+
metric_info = self._problem_statement.metric_information
140139
features = self._converter.to_features(suggestions)
141140
for i, suggestion in enumerate(suggestions):
142141
feat = features[i]
143142
values = self._impl(feat)
144-
for mc, value in zip(self._problem_statement.metric_information, values):
145-
suggestion.complete(pyvizier.Measurement(metrics={mc.name: value}))
143+
metrics = {mc.name: value for mc, value in zip(metric_info, values)}
144+
suggestion.complete(vz.Measurement(metrics))
146145

147-
def problem_statement(self) -> pyvizier.ProblemStatement:
146+
def problem_statement(self) -> vz.ProblemStatement:
148147
return copy.deepcopy(self._problem_statement)

0 commit comments

Comments
 (0)