Skip to content

Commit 6f36d29

Browse files
committed
Fix Optuna tests
1 parent 5316709 commit 6f36d29

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

pkg/suggestion/v1beta1/goptuna/converter.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ func toGoptunaSampler(algorithm *api_v1_beta1.AlgorithmSpec) (goptuna.Sampler, g
7575
return nil, nil, err
7676
}
7777
opts = append(opts, tpe.SamplerOptionSeed(int64(seed)))
78-
} else if s.Name == "startup_trials" {
78+
} else if s.Name == "n_startup_trials" {
7979
n, err := strconv.Atoi(s.Value)
8080
if err != nil {
8181
return nil, nil, err
8282
}
8383
opts = append(opts, tpe.SamplerOptionNumberOfStartupTrials(n))
84-
} else if s.Name == "ei_candidates" {
84+
} else if s.Name == "n_ei_candidates" {
8585
n, err := strconv.Atoi(s.Value)
8686
if err != nil {
8787
return nil, nil, err

test/suggestion/v1beta1/test_optuna_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ def setup_method(self):
3333
servicers, grpc_testing.strict_real_time())
3434

3535
@pytest.mark.parametrize(
36-
["algorithm_name", "algorithm_settings"],
36+
["algorithm_name", "algorithm_settings"],
3737
[
38-
["tpe", {"startup_trials": "20", "ei_candidates": "10", "random_state": "71"}],
39-
["multivariate-tpe", {"startup_trials": "20", "ei_candidates": "10", "random_state": "71"}],
38+
["tpe", {"n_startup_trials": "20", "n_ei_candidates": "10", "random_state": "71"}],
39+
["multivariate-tpe", {"n_startup_trials": "20", "n_ei_candidates": "10", "random_state": "71"}],
4040
["cmaes", {"restart_strategy": "ipop", "sigma": "2", "random_state": "71"}],
4141
["random", {"random_state": "71"}],
4242
],
@@ -47,7 +47,7 @@ def test_get_suggestion(self, algorithm_name, algorithm_settings):
4747
spec=api_pb2.ExperimentSpec(
4848
algorithm=api_pb2.AlgorithmSpec(
4949
algorithm_name=algorithm_name,
50-
algorithm_settings = [
50+
algorithm_settings=[
5151
api_pb2.AlgorithmSetting(
5252
name=name,
5353
value=value

0 commit comments

Comments
 (0)