Skip to content

Commit 0169033

Browse files
committed
Implement Goptuna based suggestion service
1 parent 71a59ff commit 0169033

File tree

10 files changed

+704
-0
lines changed

10 files changed

+704
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Currently Katib supports the following exploration algorithms:
9292
* [Grid Search](https://en.wikipedia.org/wiki/Hyperparameter_optimization#Grid_search)
9393
* [Hyperband](https://arxiv.org/pdf/1603.06560.pdf)
9494
* [Bayesian Optimization](https://arxiv.org/pdf/1012.2599.pdf)
95+
* [CMA Evolution Strategy](https://arxiv.org/abs/1604.00772)
9596

9697
#### Neural Architecture Search
9798

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Build stage: compiles the Goptuna suggestion service and downloads the
# gRPC health probe binary.
FROM golang:alpine AS go-build
# The GOPATH in the image is /go.
COPY . /go/src/github.com/kubeflow/katib
WORKDIR /go/src/github.com/kubeflow/katib/cmd/suggestion/goptuna
# On ppc64le and aarch64 a C toolchain (gcc, musl-dev) is installed first;
# the build command itself is the same on every architecture.
RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \
        apk --update add gcc musl-dev; \
    fi && \
    go build -o goptuna-suggestion ./v1alpha3

# Pick the grpc_health_probe release asset matching the build architecture
# (amd64 is the fallback for anything that is not ppc64le or aarch64).
RUN GRPC_HEALTH_PROBE_VERSION=v0.3.1 && \
    case "$(uname -m)" in \
        ppc64le) GRPC_HEALTH_PROBE_ARCH=ppc64le ;; \
        aarch64) GRPC_HEALTH_PROBE_ARCH=arm64 ;; \
        *) GRPC_HEALTH_PROBE_ARCH=amd64 ;; \
    esac && \
    wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-${GRPC_HEALTH_PROBE_ARCH} && \
    chmod +x /bin/grpc_health_probe

# Runtime stage: only the health probe and the service binary are copied over.
FROM alpine:3.7
WORKDIR /app
COPY --from=go-build /bin/grpc_health_probe /bin/
COPY --from=go-build /go/src/github.com/kubeflow/katib/cmd/suggestion/goptuna/goptuna-suggestion /app/

ENTRYPOINT ["./goptuna-suggestion"]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"net"
6+
7+
health_pb "github.com/kubeflow/katib/pkg/apis/manager/health"
8+
"github.com/kubeflow/katib/pkg/apis/manager/v1alpha3"
9+
suggestion "github.com/kubeflow/katib/pkg/suggestion/v1alpha3/goptuna"
10+
"google.golang.org/grpc"
11+
"k8s.io/klog"
12+
)
13+
14+
// Listener and gRPC message-size settings used by main.
const (
	// address is the TCP address the gRPC server listens on.
	address = "0.0.0.0:6789"

	// recvMsgSize is passed to grpc.MaxRecvMsgSize: 2^31 - 1.
	recvMsgSize = (1 << 31) - 1

	// sendMsgSize is passed to grpc.MaxSendMsgSize: 2^31 - 1.
	sendMsgSize = (1 << 31) - 1
)
19+
20+
// healthService is the stateless receiver for the health-check RPC that main
// registers on the gRPC server.
type healthService struct{}
22+
23+
func (s *healthService) Check(ctx context.Context, in *health_pb.HealthCheckRequest) (*health_pb.HealthCheckResponse, error) {
24+
return &health_pb.HealthCheckResponse{
25+
Status: health_pb.HealthCheckResponse_SERVING,
26+
}, nil
27+
}
28+
29+
func main() {
30+
l, err := net.Listen("tcp", address)
31+
if err != nil {
32+
klog.Fatalf("Failed to listen: %v", err)
33+
}
34+
srv := grpc.NewServer(grpc.MaxRecvMsgSize(recvMsgSize), grpc.MaxSendMsgSize(sendMsgSize))
35+
api_v1_alpha3.RegisterSuggestionServer(srv, suggestion.NewSuggestionService())
36+
health_pb.RegisterHealthServer(srv, &healthService{})
37+
38+
klog.Infof("Start Goptuna suggestion service: %s", address)
39+
err = srv.Serve(l)
40+
if err != nil {
41+
klog.Fatalf("Failed to serve: %v", err)
42+
}
43+
return
44+
}

docs/proposals/suggestion.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,3 +396,7 @@ We can use [hyperopt](https://github.com/hyperopt/hyperopt) to run Anneal.
396396
### SMAC
397397
398398
We can use [SMAC3](https://github.com/automl/SMAC3) to run SMAC.
399+
400+
### CMA-ES
401+
402+
We can use [goptuna](https://github.com/c-bata/goptuna) to run CMA-ES.

manifests/v1alpha3/katib-controller/katib-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ data:
3838
"tpe": {
3939
"image": "gcr.io/kubeflow-images-public/katib/v1alpha3/suggestion-hyperopt"
4040
},
41+
"cmaes": {
42+
"image": "gcr.io/kubeflow-images-public/katib/v1alpha3/suggestion-goptuna"
43+
},
4144
"nasrl": {
4245
"image": "gcr.io/kubeflow-images-public/katib/v1alpha3/suggestion-nasrl",
4346
"imagePullPolicy": "Always",
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
package suggestion_goptuna_v1alpha3
2+
3+
import (
4+
"errors"
5+
"strconv"
6+
"time"
7+
8+
"github.com/c-bata/goptuna"
9+
"github.com/c-bata/goptuna/cmaes"
10+
"github.com/c-bata/goptuna/tpe"
11+
api_v1_alpha3 "github.com/kubeflow/katib/pkg/apis/manager/v1alpha3"
12+
)
13+
14+
func toGoptunaDirection(t api_v1_alpha3.ObjectiveType) (goptuna.StudyDirection, error) {
15+
if t == api_v1_alpha3.ObjectiveType_MINIMIZE {
16+
return goptuna.StudyDirectionMinimize, nil
17+
} else if t == api_v1_alpha3.ObjectiveType_MAXIMIZE {
18+
return goptuna.StudyDirectionMaximize, nil
19+
}
20+
return "", errors.New("unexpected objective type")
21+
}
22+
23+
func toGoptunaSampler(algorithm *api_v1_alpha3.AlgorithmSpec) (goptuna.Sampler, goptuna.RelativeSampler, error) {
24+
if algorithm == nil {
25+
return nil, nil, errors.New("empty algorithm spec")
26+
}
27+
28+
name := algorithm.GetAlgorithmName()
29+
if name == AlgorithmCMAES {
30+
opts := make([]cmaes.SamplerOption, 0, len(algorithm.GetAlgorithmSetting()))
31+
for _, s := range algorithm.GetAlgorithmSetting() {
32+
if s.Name == "random_state" {
33+
seed, err := strconv.Atoi(s.Value)
34+
if err != nil {
35+
return nil, nil, err
36+
}
37+
opts = append(opts, cmaes.SamplerOptionSeed(int64(seed)))
38+
}
39+
}
40+
return nil, cmaes.NewSampler(opts...), nil
41+
} else if name == AlgorithmTPE {
42+
opts := make([]tpe.SamplerOption, 0, len(algorithm.GetAlgorithmSetting()))
43+
for _, s := range algorithm.GetAlgorithmSetting() {
44+
if s.Name == "random_state" {
45+
seed, err := strconv.Atoi(s.Value)
46+
if err != nil {
47+
return nil, nil, err
48+
}
49+
opts = append(opts, tpe.SamplerOptionSeed(int64(seed)))
50+
}
51+
}
52+
return tpe.NewSampler(opts...), nil, nil
53+
} else if name == AlgorithmRandom {
54+
opts := make([]goptuna.RandomSearchSamplerOption, 0, len(algorithm.GetAlgorithmSetting()))
55+
for _, s := range algorithm.GetAlgorithmSetting() {
56+
if s.Name == "random_state" {
57+
seed, err := strconv.Atoi(s.Value)
58+
if err != nil {
59+
return nil, nil, err
60+
}
61+
opts = append(opts, goptuna.RandomSearchSamplerOptionSeed(int64(seed)))
62+
}
63+
}
64+
return goptuna.NewRandomSearchSampler(opts...), nil, nil
65+
}
66+
return nil, nil, errors.New("invalid algorithm name")
67+
}
68+
69+
func toGoptunaSearchSpace(parameters []*api_v1_alpha3.ParameterSpec) (map[string]interface{}, error) {
70+
if parameters == nil {
71+
return nil, errors.New("empty search space")
72+
}
73+
74+
searchSpace := make(map[string]interface{}, len(parameters))
75+
for _, p := range parameters {
76+
if p.ParameterType == api_v1_alpha3.ParameterType_UNKNOWN_TYPE {
77+
return nil, errors.New("invalid parameter type")
78+
}
79+
80+
if p.ParameterType == api_v1_alpha3.ParameterType_DOUBLE {
81+
high, err := strconv.ParseFloat(p.GetFeasibleSpace().GetMax(), 64)
82+
if err != nil {
83+
return nil, err
84+
}
85+
low, err := strconv.ParseFloat(p.GetFeasibleSpace().GetMin(), 64)
86+
if err != nil {
87+
return nil, err
88+
}
89+
searchSpace[p.Name] = goptuna.UniformDistribution{
90+
High: high,
91+
Low: low,
92+
}
93+
} else if p.ParameterType == api_v1_alpha3.ParameterType_INT {
94+
high, err := strconv.Atoi(p.GetFeasibleSpace().GetMax())
95+
if err != nil {
96+
return nil, err
97+
}
98+
low, err := strconv.Atoi(p.GetFeasibleSpace().GetMin())
99+
if err != nil {
100+
return nil, err
101+
}
102+
searchSpace[p.Name] = goptuna.IntUniformDistribution{
103+
High: high,
104+
Low: low,
105+
}
106+
} else if p.ParameterType == api_v1_alpha3.ParameterType_CATEGORICAL {
107+
choices := p.GetFeasibleSpace().GetList()
108+
searchSpace[p.Name] = goptuna.CategoricalDistribution{
109+
Choices: choices,
110+
}
111+
} else if p.ParameterType == api_v1_alpha3.ParameterType_DISCRETE {
112+
// Use categorical distribution instead of goptuna.DiscreteUniformDistribution
113+
// because goptuna.UniformDistributions needs to declare the parameter space
114+
// with minimum value, maximum value and interval.
115+
choices := p.GetFeasibleSpace().GetList()
116+
searchSpace[p.Name] = goptuna.CategoricalDistribution{
117+
Choices: choices,
118+
}
119+
} else {
120+
return nil, errors.New("unsupported parameter type")
121+
}
122+
}
123+
return searchSpace, nil
124+
}
125+
126+
func toGoptunaState(condition api_v1_alpha3.TrialStatus_TrialConditionType) (goptuna.TrialState, error) {
127+
if condition == api_v1_alpha3.TrialStatus_CREATED {
128+
return goptuna.TrialStateWaiting, nil
129+
} else if condition == api_v1_alpha3.TrialStatus_RUNNING {
130+
return goptuna.TrialStateRunning, nil
131+
} else if condition == api_v1_alpha3.TrialStatus_SUCCEEDED {
132+
return goptuna.TrialStateComplete, nil
133+
} else if condition == api_v1_alpha3.TrialStatus_KILLED {
134+
return goptuna.TrialStateFail, nil
135+
} else if condition == api_v1_alpha3.TrialStatus_FAILED {
136+
return goptuna.TrialStateFail, nil
137+
}
138+
return goptuna.TrialStateFail, errors.New("unexpected trial condition")
139+
}
140+
141+
func toGoptunaTrials(
142+
ktrials []*api_v1_alpha3.Trial,
143+
study *goptuna.Study,
144+
searchSpace map[string]interface{},
145+
) ([]goptuna.FrozenTrial, error) {
146+
gtrials := make([]goptuna.FrozenTrial, 0, len(ktrials))
147+
for i, kt := range ktrials {
148+
datetimeStart, err := time.Parse(time.RFC3339Nano, kt.GetStatus().GetStartTime())
149+
if err != nil {
150+
return nil, err
151+
}
152+
datetimeComplete, err := time.Parse(time.RFC3339Nano, kt.GetStatus().GetCompletionTime())
153+
if err != nil {
154+
return nil, err
155+
}
156+
state, err := toGoptunaState(kt.GetStatus().GetCondition())
157+
if err != nil {
158+
return nil, err
159+
}
160+
161+
metrics := kt.GetStatus().GetObservation().GetMetrics()
162+
intermediateValues := make(map[int]float64, len(metrics))
163+
var finalValue float64
164+
for i, m := range metrics {
165+
v, err := strconv.ParseFloat(m.GetValue(), 64)
166+
if err != nil {
167+
return nil, err
168+
}
169+
intermediateValues[i] = v
170+
171+
if state == goptuna.TrialStateComplete {
172+
finalValue = v
173+
}
174+
}
175+
176+
assignments := kt.GetSpec().GetParameterAssignments().GetAssignments()
177+
internalParams, externalParams, err := toGoptunaParams(assignments, searchSpace)
178+
if err != nil {
179+
return nil, err
180+
}
181+
182+
gt := goptuna.FrozenTrial{
183+
ID: i, // dummy id
184+
StudyID: study.ID,
185+
Number: i, // dummy number
186+
State: state,
187+
Value: finalValue,
188+
IntermediateValues: intermediateValues,
189+
DatetimeStart: datetimeStart,
190+
DatetimeComplete: datetimeComplete,
191+
InternalParams: internalParams,
192+
Params: externalParams,
193+
Distributions: searchSpace,
194+
UserAttrs: nil,
195+
SystemAttrs: nil,
196+
}
197+
gtrials = append(gtrials, gt)
198+
}
199+
return gtrials, nil
200+
}
201+
202+
func toGoptunaParams(
203+
assignments []*api_v1_alpha3.ParameterAssignment,
204+
searchSpace map[string]interface{},
205+
) (
206+
internalParams map[string]float64,
207+
externalParams map[string]interface{},
208+
err error,
209+
) {
210+
internalParams = make(map[string]float64, len(assignments))
211+
externalParams = make(map[string]interface{}, len(assignments))
212+
213+
for i := range assignments {
214+
name := assignments[i].GetName()
215+
valueStr := assignments[i].GetValue()
216+
217+
switch d := searchSpace[name].(type) {
218+
case goptuna.UniformDistribution:
219+
p, err := strconv.ParseFloat(valueStr, 64)
220+
if err != nil {
221+
return nil, nil, err
222+
}
223+
internalParams[name] = p
224+
externalParams[name] = d.ToExternalRepr(p)
225+
case goptuna.DiscreteUniformDistribution:
226+
p, err := strconv.ParseFloat(valueStr, 64)
227+
if err != nil {
228+
return nil, nil, err
229+
}
230+
internalParams[name] = p
231+
externalParams[name] = d.ToExternalRepr(p)
232+
case goptuna.IntUniformDistribution:
233+
p, err := strconv.ParseInt(valueStr, 10, 64)
234+
if err != nil {
235+
return nil, nil, err
236+
}
237+
internalParams[name] = float64(p)
238+
externalParams[name] = p
239+
case goptuna.CategoricalDistribution:
240+
internalRepr := -1.0
241+
for i := range d.Choices {
242+
if d.Choices[i] == valueStr {
243+
internalRepr = float64(i)
244+
break
245+
}
246+
}
247+
if internalRepr == -1.0 {
248+
return nil, nil, errors.New("invalid categorical value")
249+
}
250+
internalParams[name] = internalRepr
251+
externalParams[name] = valueStr
252+
}
253+
}
254+
return internalParams, externalParams, nil
255+
}

0 commit comments

Comments
 (0)