Skip to content

Commit 5930dc1

Browse files
xingyousongcopybara-github
authored andcommitted
Small update
PiperOrigin-RevId: 629803440
1 parent c48220d commit 5930dc1

File tree

1 file changed

+29
-25
lines changed

1 file changed

+29
-25
lines changed

vizier/_src/benchmarks/experimenters/shifting_experimenter.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@ def __init__(
5454
exptr_problem_statement = exptr.problem_statement()
5555

5656
if exptr_problem_statement.search_space.is_conditional:
57-
raise ValueError('Search space should not have conditional'
58-
f' parameters {exptr_problem_statement}')
57+
raise ValueError(
58+
'Search space should not have conditional'
59+
f' parameters {exptr_problem_statement}'
60+
)
5961
dimension = len(exptr_problem_statement.search_space.parameters)
6062
if dimension <= 0:
6163
raise ValueError(f'Invalid dimension: {dimension}')
@@ -64,8 +66,8 @@ def __init__(
6466
self._shift = np.broadcast_to(shift, (dimension,))
6567
except ValueError as broadcast_err:
6668
raise ValueError(
67-
f'Shift {shift} is not broadcastable for dim: {dimension}.'
68-
'\n') from broadcast_err
69+
f'Shift {shift} is not broadcastable for dim: {dimension}.\n'
70+
) from broadcast_err
6971

7072
# Converter should be in the underlying extpr space.
7173
self._converter = converters.TrialToArrayConverter.from_study_config(
@@ -83,26 +85,27 @@ def __init__(
8385
):
8486
if parameter.type != pyvizier.ParameterType.DOUBLE:
8587
raise ValueError(f'Non-double parameters {parameter}')
86-
if (bounds := parameter.bounds) is not None:
87-
if abs(shift) >= bounds[1] - bounds[0]:
88-
raise ValueError(
89-
f'Bounds {bounds} may need to be extended'
90-
f'as shift {shift} is too large '
91-
)
92-
# Shift the bounds to maintain valid bounds.
93-
if shift >= 0:
94-
new_bounds = (bounds[0] + shift, bounds[1])
95-
else:
96-
new_bounds = (bounds[0], bounds[1] + shift)
97-
self._problem_statement.search_space.add(
98-
pyvizier.ParameterConfig.factory(
99-
name=parameter.name,
100-
bounds=new_bounds,
101-
scale_type=parameter.scale_type,
102-
default_value=parameter.default_value,
103-
external_type=parameter.external_type,
104-
)
88+
89+
bounds = parameter.bounds
90+
if abs(shift) >= bounds[1] - bounds[0]:
91+
raise ValueError(
92+
f'Bounds {bounds} may need to be extended'
93+
f'as shift {shift} is too large '
10594
)
95+
# Shift the bounds to maintain valid bounds.
96+
if shift >= 0:
97+
new_bounds = (bounds[0] + shift, bounds[1])
98+
else:
99+
new_bounds = (bounds[0], bounds[1] + shift)
100+
self._problem_statement.search_space.add(
101+
pyvizier.ParameterConfig.factory(
102+
name=parameter.name,
103+
bounds=new_bounds,
104+
scale_type=parameter.scale_type,
105+
default_value=parameter.default_value,
106+
external_type=parameter.external_type,
107+
),
108+
)
106109

107110
def problem_statement(self) -> pyvizier.ProblemStatement:
108111
return copy.deepcopy(self._problem_statement)
@@ -116,8 +119,9 @@ def evaluate(self, suggestions: Sequence[pyvizier.Trial]) -> None:
116119
for parameters, suggestion in zip(previous_parameters, suggestions):
117120
suggestion.parameters = parameters
118121

119-
def _offset(self, suggestions: Sequence[pyvizier.Trial],
120-
shift: np.ndarray) -> None:
122+
def _offset(
123+
self, suggestions: Sequence[pyvizier.Trial], shift: np.ndarray
124+
) -> None:
121125
"""Offsets parameter values (OOB values are clipped)."""
122126
for suggestion in suggestions:
123127
features = self._converter.to_features([suggestion])

0 commit comments

Comments
 (0)