Skip to content

Commit 87d2a39

Browse files
committed
review: convert fileFormat to type FileSystemFileFormat
1 parent 1fb26bb commit 87d2a39

File tree

3 files changed

+25
-23
lines changed

3 files changed

+25
-23
lines changed

cmd/metricscollector/v1beta1/file-metricscollector/main.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ func printMetricsFile(mFile string) {
140140
}
141141
}
142142

143-
func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
143+
func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string, fileFormat commonv1beta1.FileSystemFileFormat) {
144144

145145
// metricStartStep is the dict where key = metric name, value = start step.
146146
// We should apply early stopping rule only if metric is reported at least "start_step" times.
@@ -174,7 +174,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
174174

175175
// Get list of regural expressions from filters.
176176
var metricRegList []*regexp.Regexp
177-
if *metricsFileFormat == commonv1beta1.TextFormat.String() {
177+
if fileFormat == commonv1beta1.TextFormat {
178178
metricRegList = filemc.GetFilterRegexpList(filters)
179179
}
180180

@@ -185,8 +185,8 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
185185
// Print log line
186186
klog.Info(logText)
187187

188-
switch *metricsFileFormat {
189-
case commonv1beta1.TextFormat.String():
188+
switch fileFormat {
189+
case commonv1beta1.TextFormat:
190190
// Check if log line contains metric from stop rules.
191191
isRuleLine := false
192192
for _, rule := range stopRules {
@@ -224,7 +224,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
224224
}
225225
}
226226
}
227-
case commonv1beta1.JsonFormat.String():
227+
case commonv1beta1.JsonFormat:
228228
var logJsonObj map[string]interface{}
229229
if err = json.Unmarshal([]byte(logText), &logJsonObj); err != nil {
230230
klog.Fatalf("Failed to unmarshal logs in JSON format, log: %s, error: %v", logText, err)
@@ -256,7 +256,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
256256
stopRules = updateStopRules(objMetric, stopRules, optimalObjValue, metricValue, objType, metricStartStep, rule, idx)
257257
}
258258
default:
259-
klog.Fatalf("format must be set %s or %s", commonv1beta1.TextFormat.String(), commonv1beta1.JsonFormat.String())
259+
klog.Fatalf("format must be set %v or %v", commonv1beta1.TextFormat, commonv1beta1.JsonFormat)
260260
}
261261

262262
// If stopRules array is empty, Trial is early stopped.
@@ -295,7 +295,7 @@ func watchMetricsFile(mFile string, stopRules stopRulesFlag, filters []string) {
295295
}
296296

297297
// Report metrics to DB.
298-
reportMetrics(filters)
298+
reportMetrics(filters, fileFormat)
299299

300300
// Wait until main process is completed.
301301
timeout := 60 * time.Second
@@ -400,9 +400,11 @@ func main() {
400400
filters = strings.Split(*metricFilters, ";")
401401
}
402402

403+
fileFormat := commonv1beta1.FileSystemFileFormat(*metricsFileFormat)
404+
403405
// If stop rule is set we need to parse metrics during run.
404406
if len(stopRules) != 0 {
405-
go watchMetricsFile(*metricsFilePath, stopRules, filters)
407+
go watchMetricsFile(*metricsFilePath, stopRules, filters, fileFormat)
406408
} else {
407409
go printMetricsFile(*metricsFilePath)
408410
}
@@ -421,11 +423,11 @@ func main() {
421423

422424
// If training was not early stopped, report the metrics.
423425
if !isEarlyStopped {
424-
reportMetrics(filters)
426+
reportMetrics(filters, fileFormat)
425427
}
426428
}
427429

428-
func reportMetrics(filters []string) {
430+
func reportMetrics(filters []string, fileFormat commonv1beta1.FileSystemFileFormat) {
429431

430432
conn, err := grpc.Dial(*dbManagerServiceAddr, grpc.WithInsecure())
431433
if err != nil {
@@ -438,7 +440,7 @@ func reportMetrics(filters []string) {
438440
if len(*metricNames) != 0 {
439441
metricList = strings.Split(*metricNames, ";")
440442
}
441-
olog, err := filemc.CollectObservationLog(*metricsFilePath, metricList, filters, *metricsFileFormat)
443+
olog, err := filemc.CollectObservationLog(*metricsFilePath, metricList, filters, fileFormat)
442444
if err != nil {
443445
klog.Fatalf("Failed to collect logs: %v", err)
444446
}

pkg/metricscollector/v1beta1/file-metricscollector/file-metricscollector.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import (
3434
"github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common"
3535
)
3636

37-
func CollectObservationLog(fileName string, metrics []string, filters []string, format string) (*v1beta1.ObservationLog, error) {
37+
func CollectObservationLog(fileName string, metrics []string, filters []string, fileFormat commonv1beta1.FileSystemFileFormat) (*v1beta1.ObservationLog, error) {
3838
file, err := os.Open(fileName)
3939
if err != nil {
4040
return nil, err
@@ -46,13 +46,13 @@ func CollectObservationLog(fileName string, metrics []string, filters []string,
4646
}
4747
logs := string(content)
4848

49-
switch format {
50-
case commonv1beta1.TextFormat.String():
49+
switch fileFormat {
50+
case commonv1beta1.TextFormat:
5151
return parseLogsInTextFormat(strings.Split(logs, "\n"), metrics, filters)
52-
case commonv1beta1.JsonFormat.String():
52+
case commonv1beta1.JsonFormat:
5353
return parseLogsInJsonFormat(strings.Split(logs, "\n"), metrics)
5454
}
55-
return nil, fmt.Errorf("format must be set %s or %s", commonv1beta1.TextFormat.String(), commonv1beta1.JsonFormat.String())
55+
return nil, fmt.Errorf("format must be set %v or %s", commonv1beta1.TextFormat, commonv1beta1.JsonFormat)
5656
}
5757

5858
func parseLogsInTextFormat(logs []string, metrics []string, filters []string) (*v1beta1.ObservationLog, error) {

pkg/metricscollector/v1beta1/file-metricscollector/file-metricscollector_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ func TestCollectObservationLog(t *testing.T) {
3838
fileName string
3939
metrics []string
4040
filters []string
41-
format string
41+
fileFormat commonv1beta1.FileSystemFileFormat
4242
err bool
4343
expected *v1beta1.ObservationLog
4444
}{
4545
{
4646
description: "Positive case for logs in JSON format",
4747
fileName: path.Join(testJsonDataPath, "good.json"),
4848
metrics: []string{"acc", "loss"},
49-
format: commonv1beta1.JsonFormat.String(),
49+
fileFormat: commonv1beta1.JsonFormat,
5050
expected: &v1beta1.ObservationLog{
5151
MetricLogs: []*v1beta1.MetricLog{
5252
{
@@ -95,19 +95,19 @@ func TestCollectObservationLog(t *testing.T) {
9595
{
9696
description: "Invalid file format",
9797
fileName: path.Join(testJsonDataPath, "good.json"),
98-
format: "invalid",
98+
fileFormat: "invalid",
9999
err: true,
100100
},
101101
{
102102
description: "Invalid formatted file for logs in JSON format",
103103
fileName: path.Join(testJsonDataPath, "invalid-format.json"),
104-
format: commonv1beta1.JsonFormat.String(),
104+
fileFormat: commonv1beta1.JsonFormat,
105105
err: true,
106106
},
107107
{
108108
description: "Invalid timestamp for logs in JSON format",
109109
fileName: path.Join(testJsonDataPath, "invalid-timestamp.json"),
110-
format: commonv1beta1.JsonFormat.String(),
110+
fileFormat: commonv1beta1.JsonFormat,
111111
metrics: []string{"acc", "loss"},
112112
expected: &v1beta1.ObservationLog{
113113
MetricLogs: []*v1beta1.MetricLog{
@@ -131,7 +131,7 @@ func TestCollectObservationLog(t *testing.T) {
131131
{
132132
description: "Missing objective metric in training logs",
133133
fileName: path.Join(testJsonDataPath, "missing-objective-metric.json"),
134-
format: commonv1beta1.JsonFormat.String(),
134+
fileFormat: commonv1beta1.JsonFormat,
135135
metrics: []string{"acc", "loss"},
136136
expected: &v1beta1.ObservationLog{
137137
MetricLogs: []*v1beta1.MetricLog{
@@ -149,7 +149,7 @@ func TestCollectObservationLog(t *testing.T) {
149149

150150
for _, test := range testCases {
151151
t.Run(test.description, func(t *testing.T) {
152-
actual, err := CollectObservationLog(test.fileName, test.metrics, test.filters, test.format)
152+
actual, err := CollectObservationLog(test.fileName, test.metrics, test.filters, test.fileFormat)
153153
if (err != nil) != test.err {
154154
t.Errorf("\nGOT: \n%v\nWANT: %v\n", err, test.err)
155155
} else {

0 commit comments

Comments
 (0)