@@ -54,8 +54,10 @@ def __init__(
5454 exptr_problem_statement = exptr .problem_statement ()
5555
5656 if exptr_problem_statement .search_space .is_conditional :
57- raise ValueError ('Search space should not have conditional'
58- f' parameters { exptr_problem_statement } ' )
57+ raise ValueError (
58+ 'Search space should not have conditional'
59+ f' parameters { exptr_problem_statement } '
60+ )
5961 dimension = len (exptr_problem_statement .search_space .parameters )
6062 if dimension <= 0 :
6163 raise ValueError (f'Invalid dimension: { dimension } ' )
@@ -64,8 +66,8 @@ def __init__(
6466 self ._shift = np .broadcast_to (shift , (dimension ,))
6567 except ValueError as broadcast_err :
6668 raise ValueError (
67- f'Shift { shift } is not broadcastable for dim: { dimension } .'
68- ' \n ' ) from broadcast_err
69+ f'Shift { shift } is not broadcastable for dim: { dimension } .\n '
70+ ) from broadcast_err
6971
7072 # Converter should be in the underlying extpr space.
7173 self ._converter = converters .TrialToArrayConverter .from_study_config (
@@ -83,26 +85,27 @@ def __init__(
8385 ):
8486 if parameter .type != pyvizier .ParameterType .DOUBLE :
8587 raise ValueError (f'Non-double parameters { parameter } ' )
86- if (bounds := parameter .bounds ) is not None :
87- if abs (shift ) >= bounds [1 ] - bounds [0 ]:
88- raise ValueError (
89- f'Bounds { bounds } may need to be extended'
90- f'as shift { shift } is too large '
91- )
92- # Shift the bounds to maintain valid bounds.
93- if shift >= 0 :
94- new_bounds = (bounds [0 ] + shift , bounds [1 ])
95- else :
96- new_bounds = (bounds [0 ], bounds [1 ] + shift )
97- self ._problem_statement .search_space .add (
98- pyvizier .ParameterConfig .factory (
99- name = parameter .name ,
100- bounds = new_bounds ,
101- scale_type = parameter .scale_type ,
102- default_value = parameter .default_value ,
103- external_type = parameter .external_type ,
104- )
88+
89+ bounds = parameter .bounds
90+ if abs (shift ) >= bounds [1 ] - bounds [0 ]:
91+ raise ValueError (
92+ f'Bounds { bounds } may need to be extended'
93+ f'as shift { shift } is too large '
10594 )
95+ # Shift the bounds to maintain valid bounds.
96+ if shift >= 0 :
97+ new_bounds = (bounds [0 ] + shift , bounds [1 ])
98+ else :
99+ new_bounds = (bounds [0 ], bounds [1 ] + shift )
100+ self ._problem_statement .search_space .add (
101+ pyvizier .ParameterConfig .factory (
102+ name = parameter .name ,
103+ bounds = new_bounds ,
104+ scale_type = parameter .scale_type ,
105+ default_value = parameter .default_value ,
106+ external_type = parameter .external_type ,
107+ ),
108+ )
106109
107110 def problem_statement (self ) -> pyvizier .ProblemStatement :
108111 return copy .deepcopy (self ._problem_statement )
@@ -116,8 +119,9 @@ def evaluate(self, suggestions: Sequence[pyvizier.Trial]) -> None:
116119 for parameters , suggestion in zip (previous_parameters , suggestions ):
117120 suggestion .parameters = parameters
118121
119- def _offset (self , suggestions : Sequence [pyvizier .Trial ],
120- shift : np .ndarray ) -> None :
122+ def _offset (
123+ self , suggestions : Sequence [pyvizier .Trial ], shift : np .ndarray
124+ ) -> None :
121125 """Offsets parameter values (OOB values are clipped)."""
122126 for suggestion in suggestions :
123127 features = self ._converter .to_features ([suggestion ])
0 commit comments