Skip to content

Commit 7749171

Browse files
committed
Add a yamlfmt cmd
* Can read 1 file (cmdline) * Can read multiple files (cmdline) * Can read stdin * Can write traditional YAML or KYAML * Can diff input vs output (-d) * Can write results to the input files (-w)
1 parent a932007 commit 7749171

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed

yamlfmt/yamlfmt.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
/*
2+
Copyright 2021 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package main
18+
19+
import (
20+
"bytes"
21+
"flag"
22+
"fmt"
23+
"io"
24+
"os"
25+
"path/filepath"
26+
"strings"
27+
28+
yaml "go.yaml.in/yaml/v3"
29+
"sigs.k8s.io/yaml/kyaml"
30+
)
31+
32+
const (
33+
fmtYAML = "yaml"
34+
fmtKYAML = "kyaml"
35+
)
36+
37+
func main() {
38+
fs := flag.NewFlagSet("yamlfmt", flag.ExitOnError)
39+
fs.Usage = func() {
40+
fmt.Fprintf(fs.Output(), "usage: %s [<yaml-files>...]\n", filepath.Base(os.Args[0]))
41+
fmt.Fprintf(fs.Output(), "If no files are specified, stdin will be used.\n")
42+
fs.PrintDefaults()
43+
}
44+
45+
diff := fs.Bool("d", false, "diff input files with their formatted versions")
46+
help := fs.Bool("h", false, "print help and exit")
47+
format := fs.String("o", "yaml", "output format: may be 'yaml' or 'kyaml'")
48+
write := fs.Bool("w", false, "write result to input files instead of stdout")
49+
fs.Parse(os.Args[1:])
50+
51+
if *help {
52+
fs.SetOutput(os.Stdout)
53+
fs.Usage()
54+
os.Exit(0)
55+
}
56+
57+
switch *format {
58+
case "yaml", "kyaml":
59+
// OK
60+
default:
61+
fmt.Fprintf(os.Stderr, "unknown output format %q, must be one of 'yaml' or 'kyaml'\n", *format)
62+
os.Exit(1)
63+
}
64+
if *diff && *write {
65+
fmt.Fprintln(os.Stderr, "cannot use -d and -w together")
66+
}
67+
68+
files := fs.Args()
69+
70+
if len(files) == 0 {
71+
if err := renderYAML("<stdin>", os.Stdin, *format, *diff, os.Stdout); err != nil {
72+
fmt.Fprintln(os.Stderr, err)
73+
os.Exit(1)
74+
}
75+
}
76+
77+
for i, path := range files {
78+
// use a func to catch defer'ed Close
79+
func() {
80+
// Read the YAML file
81+
sourceYaml, err := os.ReadFile(path)
82+
if err != nil {
83+
fmt.Fprintf(os.Stderr, "%s: %v\n", path, err)
84+
return
85+
}
86+
in := bytes.NewReader(sourceYaml)
87+
88+
out := os.Stdout
89+
finalize := func() {}
90+
if *write {
91+
// Write to a temp file and rename when done.
92+
tmp, err := os.CreateTemp(filepath.Dir(path), ".yamlfmt.tmp.")
93+
if err != nil {
94+
fmt.Fprintf(os.Stderr, "%v\n", err)
95+
os.Exit(1)
96+
}
97+
defer tmp.Close()
98+
finalize = func() {
99+
if err := os.Rename(tmp.Name(), path); err != nil {
100+
fmt.Fprintf(os.Stderr, "%v\n", err)
101+
os.Exit(1)
102+
}
103+
}
104+
out = tmp
105+
}
106+
if len(files) > 1 && !*write && !*diff {
107+
if i > 0 {
108+
fmt.Fprintln(out, "")
109+
}
110+
fmt.Fprintln(out, "# "+path)
111+
}
112+
if err := renderYAML(path, in, *format, *diff, out); err != nil {
113+
fmt.Fprintln(os.Stderr, err)
114+
os.Exit(1)
115+
}
116+
finalize()
117+
}()
118+
}
119+
}
120+
121+
func renderYAML(path string, in io.Reader, format string, printDiff bool, out io.Writer) error {
122+
if format == fmtKYAML {
123+
ky := &kyaml.Encoder{}
124+
125+
if printDiff {
126+
ibuf, err := io.ReadAll(in)
127+
if err != nil {
128+
return err
129+
}
130+
obuf := bytes.Buffer{}
131+
if err := ky.FromYAML(bytes.NewReader(ibuf), &obuf); err != nil {
132+
return err
133+
}
134+
d := trivialDiff(path, string(ibuf), obuf.String())
135+
fmt.Fprint(out, d)
136+
return nil
137+
}
138+
139+
return ky.FromYAML(in, out)
140+
}
141+
142+
// else format == fmtYAML
143+
144+
var decoder *yaml.Decoder
145+
var encoder *yaml.Encoder
146+
var finish func()
147+
148+
if printDiff {
149+
ibuf, err := io.ReadAll(in)
150+
if err != nil {
151+
return err
152+
}
153+
obuf := bytes.Buffer{}
154+
decoder = yaml.NewDecoder(bytes.NewReader(ibuf))
155+
encoder = yaml.NewEncoder(&obuf)
156+
finish = func() {
157+
d := trivialDiff(path, string(ibuf), obuf.String())
158+
fmt.Fprint(out, d)
159+
}
160+
} else {
161+
decoder = yaml.NewDecoder(in)
162+
encoder = yaml.NewEncoder(out)
163+
}
164+
encoder.SetIndent(2)
165+
166+
for {
167+
var node yaml.Node // to retain comments
168+
if err := decoder.Decode(&node); err != nil {
169+
if err == io.EOF {
170+
break // End of input
171+
}
172+
return fmt.Errorf("failed to decode input: %w", err)
173+
}
174+
setBlockStyle(&node) // In case we read KYAML as input.
175+
if err := encoder.Encode(&node); err != nil {
176+
return fmt.Errorf("failed to encode node: %w", err)
177+
}
178+
}
179+
if finish != nil {
180+
finish()
181+
}
182+
return nil
183+
}
184+
185+
func trivialDiff(path, a, b string) string {
186+
if a == b {
187+
return ""
188+
}
189+
190+
x := strings.Split(strings.TrimSuffix(a, "\n"), "\n")
191+
y := strings.Split(strings.TrimSuffix(b, "\n"), "\n")
192+
buf := bytes.Buffer{}
193+
buf.WriteString(fmt.Sprintf("--- %s\n+++ %s\n", path, path))
194+
buf.WriteString(fmt.Sprintf("@@ -%d,%d +%d,%d\n", 1, len(x), 1, len(y)))
195+
for {
196+
n := 0
197+
for ; n < len(x) && n < len(y) && x[n] == y[n]; n++ {
198+
buf.WriteString(" " + x[n] + "\n")
199+
}
200+
x = x[n:]
201+
y = y[n:]
202+
203+
nextX, nextY := nextCommon(x, y)
204+
for i := range nextX {
205+
buf.WriteString("-" + x[i] + "\n")
206+
}
207+
x = x[nextX:]
208+
for j := range nextY {
209+
buf.WriteString("+" + y[j] + "\n")
210+
}
211+
y = y[nextY:]
212+
213+
if len(x) == 0 && len(y) == 0 {
214+
break
215+
}
216+
}
217+
return buf.String()
218+
}
219+
220+
func nextCommon(x, y []string) (int, int) {
221+
for i := range len(x) {
222+
for j := range len(y) {
223+
if x[i] == y[j] {
224+
return i, j
225+
}
226+
}
227+
}
228+
return len(x), len(y)
229+
}
230+
231+
func setBlockStyle(node *yaml.Node) {
232+
node.Style = node.Style & (^yaml.FlowStyle)
233+
for _, child := range node.Content {
234+
setBlockStyle(child)
235+
}
236+
}

0 commit comments

Comments
 (0)