Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions vizier/_src/pythia/singleton_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

"""A helper functionality to handle singleton parameters."""

import copy
import logging
from typing import Sequence

import attrs
from vizier import pyvizier as vz


@attrs.define
class SingletonParameterHandler:
"""A helper class to handle singleton parameters.

The class allows to remove singleton parameters (i.e. having a single value)
from the problem statement's search space and from trial, so that designers
won't have to handle them. In addition, it allows to re-add the singleton
values back to the trial suggestions, so that the users can access them if
needed.
"""

problem: vz.ProblemStatement
# ----------------------------------------------------------------------------
# Internal attributes
# ----------------------------------------------------------------------------
_singletons: dict[str, vz.ParameterValueTypes] = attrs.field(init=False)
_stripped_problem: vz.ProblemStatement = attrs.field(init=False)

@property
def stripped_problem(self) -> vz.ProblemStatement:
"""Returns the stripped problem."""
return self._stripped_problem

def __attrs_post_init__(self):
logging.info("problem: %s", self.problem)
self._singletons = self._find_singletons()
self._stripped_problem = self._strip_problem()

def _find_singletons(self) -> dict[str, vz.ParameterValueTypes]:
"""Finds the singleton parameters in the problem."""
singletons = {}
for param in self.problem.search_space.parameters:
if param.type == vz.ParameterType.DOUBLE:
if param.bounds[0] == param.bounds[1]:
singletons[param.name] = param.bounds[0]
elif param.type == vz.ParameterType.INTEGER:
if param.bounds[0] == param.bounds[1]:
singletons[param.name] = param.bounds[0]
elif param.type in (
vz.ParameterType.CATEGORICAL,
vz.ParameterType.DISCRETE,
):
if len(param.feasible_values) == 1:
singletons[param.name] = param.feasible_values[0]
elif param.type == vz.ParameterType.CUSTOM:
pass
else:
raise ValueError("Unknown parameter type: %s" % param.type)
return singletons

def _strip_problem(self) -> vz.ProblemStatement:
"""Strips the problem of the singleton parameters."""
stripped_problem = copy.deepcopy(self.problem)
for param in self.problem.search_space.parameters:
if param.name in self._singletons:
stripped_problem.search_space.pop(param.name)
return stripped_problem

def strip_trials(
self, trials: Sequence[vz.TrialSuggestion]
) -> Sequence[vz.TrialSuggestion]:
"""Strips the trials of the singleton parameters."""
if not self._singletons:
return trials
new_trials = []
for trial in trials:
new_trial = copy.deepcopy(trial)
for param_name in trial.parameters:
if param_name in self._singletons:
del new_trial.parameters[param_name]
new_trials.append(new_trial)
return new_trials

def augment_trials(
self, trials: Sequence[vz.TrialSuggestion]
) -> Sequence[vz.TrialSuggestion]:
"""Augments the trials with the singleton parameters."""
if not self._singletons:
return trials
new_trials = []
for trial in trials:
new_trial = copy.deepcopy(trial)
for singleton_name, singleton_value in self._singletons.items():
new_trial.parameters[singleton_name] = singleton_value
new_trials.append(new_trial)
return new_trials
186 changes: 186 additions & 0 deletions vizier/_src/pythia/singleton_params_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

"""Tests for singleton_params module which allows to handle singleton parameters."""

from vizier import pyvizier as vz
from vizier._src.pythia import singleton_params
from absl.testing import absltest
from absl.testing import parameterized


def get_problem_with_singletons() -> vz.ProblemStatement:
"""Returns a problem with singleton parameters."""
problem = vz.ProblemStatement()
problem.search_space.root.add_float_param('sf', min_value=1, max_value=1)
problem.search_space.root.add_int_param('si', min_value=5, max_value=5)
problem.search_space.root.add_categorical_param('sc', feasible_values=['a'])
problem.search_space.root.add_discrete_param('sd', feasible_values=[3])

problem.search_space.root.add_float_param('f', min_value=0.0, max_value=5.0)
problem.search_space.root.add_int_param('i', min_value=0, max_value=10)
problem.search_space.root.add_categorical_param(
'c', feasible_values=['a', '1']
)
problem.search_space.root.add_discrete_param('d', feasible_values=[3, 1])
problem.metric_information = [
vz.MetricInformation(name='obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE)
]
return problem


def get_problem_without_singletons() -> vz.ProblemStatement:
"""Returns a problem with singleton parameters."""
problem = vz.ProblemStatement()
problem.search_space.root.add_float_param('f', min_value=0.0, max_value=5.0)
problem.search_space.root.add_int_param('i', min_value=0, max_value=10)
problem.search_space.root.add_categorical_param(
'c', feasible_values=['a', '1']
)
problem.search_space.root.add_discrete_param('d', feasible_values=[3, 1])
problem.metric_information = [
vz.MetricInformation(name='obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE)
]
return problem


class SingletonParameterHandlerTest(parameterized.TestCase):
"""Tests for the `singletonParameterHandler`."""

@parameterized.named_parameters(
('with_singletons', get_problem_with_singletons()),
('without_singletons', get_problem_without_singletons()),
)
def test_stripped_problem(self, problem):
handler = singleton_params.SingletonParameterHandler(problem)
expected_search_search = get_problem_without_singletons().search_space
self.assertEqual(
handler.stripped_problem.search_space, expected_search_search
)

def test_strip_trials(self):
problem = get_problem_with_singletons()
handler = singleton_params.SingletonParameterHandler(problem)
trials = [
vz.TrialSuggestion(
parameters={
'sf': 1.0,
'si': 5,
'sc': 'a',
'sd': 3,
'f': 1.0,
'i': 1,
'c': 'a',
'd': 3,
}
)
]
stripped_trials = handler.strip_trials(trials)
self.assertEqual(
stripped_trials,
[
vz.TrialSuggestion(
parameters={
'f': 1.0,
'i': 1,
'c': 'a',
'd': 3,
}
)
],
)

def test_augment_trials(self):
problem = get_problem_with_singletons()
handler = singleton_params.SingletonParameterHandler(problem)
trials = [
vz.TrialSuggestion(
parameters={
'f': 1.0,
'i': 1,
'c': 'a',
'd': 3,
}
)
]
augmented_trials = handler.augment_trials(trials)
self.assertEqual(
augmented_trials,
[
vz.TrialSuggestion(
parameters={
'sf': 1.0,
'si': 5,
'sc': 'a',
'sd': 3,
'f': 1.0,
'i': 1,
'c': 'a',
'd': 3,
}
)
],
)

def test_find_singletons(self):
problem = get_problem_with_singletons()
handler = singleton_params.SingletonParameterHandler(problem)
self.assertEqual(
handler._singletons,
{
'sf': 1.0,
'si': 5,
'sc': 'a',
'sd': 3,
},
)

def test_apply_strip_multiple_times(self):
problem = get_problem_with_singletons()
handler = singleton_params.SingletonParameterHandler(problem)
trials = [
vz.TrialSuggestion(
parameters={
'sf': 1.0,
'si': 5,
'sc': 'a',
'sd': 3,
'f': 1.0,
'i': 1,
'c': 'a',
'd': 3,
}
)
]
for _ in range(5):
trials = handler.strip_trials(trials)
self.assertEqual(
trials,
[
vz.TrialSuggestion(
parameters={
'f': 1.0,
'i': 1,
'c': 'a',
'd': 3,
}
)
],
)


if __name__ == '__main__':
absltest.main()
Loading