Skip to content

Commit 9e8caeb

Browse files
committed
feat(wit-bindgen-go): align flags decoding
Flags encoding/decoding in Go is now rewritten to use the correct binary encoding. Flag structs now contain an addition ReadFromIndex method that can be used for decoding. This makes the generated code more testable.
1 parent 02c4350 commit 9e8caeb

File tree

6 files changed

+178
-55
lines changed

6 files changed

+178
-55
lines changed

crates/wit-bindgen-go/src/interface.rs

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use wit_bindgen_core::wit_parser::{
99
Variant, World, WorldItem, WorldKey,
1010
};
1111
use wit_bindgen_core::{uwrite, uwriteln, Source, TypeInfo};
12-
use wrpc_introspect::{async_paths_ty, flag_repr, is_list_of, is_ty, rpc_func_name};
12+
use wrpc_introspect::{async_paths_ty, is_list_of, is_ty, rpc_func_name};
1313

1414
use crate::{
1515
to_go_ident, to_package_ident, to_upper_camel_case, Deps, GoWrpc, Identifier, InterfaceName,
@@ -737,37 +737,20 @@ impl InterfaceGenerator<'_> {
737737
self.src.push_str(")");
738738
}
739739

740-
fn print_read_flags(&mut self, ty: &Flags, reader: &str, name: &str) {
741-
let fmt = self.deps.fmt();
742-
let io = self.deps.io();
743-
744-
let repr = flag_repr(ty);
740+
fn print_read_flags(&mut self, reader: &str, name: &str) {
741+
let wrpc = self.deps.wrpc();
745742

746743
uwrite!(
747744
self.src,
748-
r#"func(r {io}.ByteReader) (*{name}, error) {{
749-
v := &{name}{{}}
750-
n, err := "#
745+
r#"func(r {wrpc}.IndexReader) (*{name}, error) {{
746+
v := {name}{{}}
747+
if err := v.ReadFromIndex(r); err != nil {{
748+
return nil, err
749+
}}
750+
return &v, nil
751+
}}({reader})
752+
"#
751753
);
752-
self.print_read_discriminant(repr, "r");
753-
self.push_str("\n");
754-
self.push_str("if err != nil {\n");
755-
self.push_str("return nil, ");
756-
self.push_str(fmt);
757-
self.push_str(".Errorf(\"failed to read flag: %w\", err)\n");
758-
self.push_str("}\n");
759-
for (i, Flag { name, .. }) in ty.flags.iter().enumerate() {
760-
if i > 64 {
761-
break;
762-
}
763-
uwriteln!(self.src, "if n & (1 << {i}) > 0 {{");
764-
self.push_str("v.");
765-
self.push_str(&name.to_upper_camel_case());
766-
self.push_str(" = true\n");
767-
self.push_str("}\n");
768-
}
769-
self.push_str("return v, nil\n");
770-
uwrite!(self.src, "}}({reader})");
771754
}
772755

773756
fn print_read_enum(&mut self, ty: &Enum, reader: &str, name: &str) {
@@ -1328,8 +1311,8 @@ impl InterfaceGenerator<'_> {
13281311
TypeDefKind::Resource => self.print_read_string(reader),
13291312
TypeDefKind::Handle(Handle::Own(id)) => self.print_read_own(reader, *id),
13301313
TypeDefKind::Handle(Handle::Borrow(id)) => self.print_read_borrow(reader, *id),
1331-
TypeDefKind::Flags(ty) => {
1332-
self.print_read_flags(ty, reader, &name.expect("flag missing a name"));
1314+
TypeDefKind::Flags(_ty) => {
1315+
self.print_read_flags(reader, &name.expect("flag missing a name"));
13331316
}
13341317
TypeDefKind::Tuple(ty) => self.print_read_tuple(ty, reader, path),
13351318
TypeDefKind::Variant(ty) => {
@@ -3743,13 +3726,13 @@ func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) err
37433726
}
37443727

37453728
fn type_flags(&mut self, id: TypeId, _name: &str, ty: &Flags, docs: &Docs) {
3746-
let repr = flag_repr(ty);
3747-
37483729
let info = self.info(id);
37493730
if let Some(name) = self.name_of(id) {
37503731
let strings = self.deps.strings();
37513732
let wrpc = self.deps.wrpc();
3733+
let errors = self.deps.errors();
37523734

3735+
// Struct
37533736
self.godoc(docs);
37543737
uwriteln!(self.src, "type {name} struct {{");
37553738
for Flag { name, docs } in &ty.flags {
@@ -3759,6 +3742,7 @@ func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) err
37593742
}
37603743
self.push_str("}\n");
37613744

3745+
// String()
37623746
uwriteln!(self.src, "func (v *{name}) String() string {{");
37633747
uwriteln!(self.src, "flags := make([]string, 0, {})", ty.flags.len());
37643748
for Flag { name, .. } in &ty.flags {
@@ -3769,34 +3753,76 @@ func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) err
37693753
self.push_str("}\n");
37703754
}
37713755
uwriteln!(self.src, r#"return {strings}.Join(flags, " | ")"#);
3772-
self.push_str("}\n");
3756+
self.push_str("}\n\n");
3757+
3758+
// WriteToIndex()
3759+
let mut buf_len = ty.flags.len() / 8;
3760+
if ty.flags.len() % 8 > 0 {
3761+
buf_len += 1;
3762+
}
3763+
37733764
uwriteln!(
37743765
self.src,
3775-
"func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) error, error) {{"
3766+
r#"func (v *{name}) WriteToIndex(w {wrpc}.ByteWriter) (func({wrpc}.IndexWriter) error, error) {{
3767+
var p [{buf_len}]byte
3768+
"#
37763769
);
3777-
self.push_str("var n ");
3778-
self.int_repr(repr);
3779-
self.push_str("\n");
3770+
37803771
for (i, Flag { name, .. }) in ty.flags.iter().enumerate() {
37813772
self.push_str("if v.");
37823773
self.push_str(&name.to_upper_camel_case());
3783-
self.push_str(" {\n");
3784-
if i <= 64 {
3785-
uwriteln!(self.src, "n |= 1 << {i}");
3786-
} else {
3787-
let errors = self.deps.errors();
3788-
uwriteln!(
3789-
self.src,
3790-
r#"return nil, {errors}.New("encoding `{name}` flag value would overflow 64-bit integer, flags containing more than 64 members are not supported yet")"#
3791-
);
3792-
}
3793-
self.push_str("}\n");
3774+
uwriteln!(
3775+
self.src,
3776+
r#"{{
3777+
p[{}] |= 1 << {}
3778+
}}"#,
3779+
i / 8,
3780+
i % 8
3781+
);
37943782
}
3795-
self.push_str("return nil, ");
3796-
self.print_write_discriminant(repr, "n", "w");
3797-
self.push_str("\n");
3798-
self.push_str("}\n");
3783+
uwriteln!(
3784+
self.src,
3785+
r#"
3786+
_, err := w.Write(p[:])
3787+
return nil, err
3788+
}}
3789+
"#,
3790+
);
3791+
3792+
// ReadFromIndex()
3793+
uwrite!(
3794+
self.src,
3795+
r#"func (v *{name}) ReadFromIndex(r {wrpc}.IndexReader) error {{
3796+
var p [{buf_len}]byte
3797+
if _, err := r.Read(p[:]); err != nil {{
3798+
return err
3799+
}}
3800+
3801+
"#
3802+
);
3803+
3804+
for (i, Flag { name, .. }) in ty.flags.iter().enumerate() {
3805+
uwriteln!(
3806+
self.src,
3807+
"v.{} = p[{}] & (1 << {}) > 0",
3808+
name.to_upper_camel_case(),
3809+
i / 8,
3810+
i % 8
3811+
);
3812+
}
3813+
3814+
uwriteln!(
3815+
self.src,
3816+
r#"
3817+
if (p[{}] >> {}) > 0 {{
3818+
return {errors}.New("bit not associated with any flag is set")
3819+
}}"#,
3820+
buf_len - 1,
3821+
ty.flags.len() % 8,
3822+
);
3823+
self.push_str("return nil\n}\n");
37993824

3825+
// Error()
38003826
if info.error {
38013827
uwriteln!(
38023828
self.src,

go.work.sum

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
2+
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
3+
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
24
go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0=
35
golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
46
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
57
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
8+
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
69
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
710
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
811
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=

tests/go/go.mod

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,23 @@ go 1.22.2
55
require (
66
github.com/google/uuid v1.6.0
77
github.com/nats-io/nats-server/v2 v2.10.14
8-
github.com/nats-io/nats.go v1.35.0
8+
github.com/nats-io/nats.go v1.36.0
99
github.com/wrpc/wrpc/go v0.0.0-unpublished
1010
)
1111

1212
require (
13+
github.com/davecgh/go-spew v1.1.1 // indirect
1314
github.com/klauspost/compress v1.17.8 // indirect
1415
github.com/minio/highwayhash v1.0.2 // indirect
1516
github.com/nats-io/jwt/v2 v2.5.5 // indirect
1617
github.com/nats-io/nkeys v0.4.7 // indirect
1718
github.com/nats-io/nuid v1.0.1 // indirect
19+
github.com/pmezard/go-difflib v1.0.0 // indirect
20+
github.com/stretchr/testify v1.9.0
1821
golang.org/x/crypto v0.23.0 // indirect
1922
golang.org/x/sys v0.20.0 // indirect
2023
golang.org/x/time v0.5.0 // indirect
24+
gopkg.in/yaml.v3 v3.0.1 // indirect
2125
)
2226

2327
replace github.com/wrpc/wrpc/go v0.0.0-unpublished => ../../go

tests/go/go.sum

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
13
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
24
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
35
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
@@ -8,16 +10,24 @@ github.com/nats-io/jwt/v2 v2.5.5 h1:ROfXb50elFq5c9+1ztaUbdlrArNFl2+fQWP6B8HGEq4=
810
github.com/nats-io/jwt/v2 v2.5.5/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A=
911
github.com/nats-io/nats-server/v2 v2.10.14 h1:98gPJFOAO2vLdM0gogh8GAiHghwErrSLhugIqzRC+tk=
1012
github.com/nats-io/nats-server/v2 v2.10.14/go.mod h1:a0TwOVBJZz6Hwv7JH2E4ONdpyFk9do0C18TEwxnHdRk=
11-
github.com/nats-io/nats.go v1.35.0 h1:XFNqNM7v5B+MQMKqVGAyHwYhyKb48jrenXNxIU20ULk=
12-
github.com/nats-io/nats.go v1.35.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
13+
github.com/nats-io/nats.go v1.36.0 h1:suEUPuWzTSse/XhESwqLxXGuj8vGRuPRoG7MoRN/qyU=
14+
github.com/nats-io/nats.go v1.36.0/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8=
1315
github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI=
1416
github.com/nats-io/nkeys v0.4.7/go.mod h1:kqXRgRDPlGy7nGaEDMuYzmiJCIAAWDK0IMBtDmGD0nc=
1517
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
1618
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
19+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
20+
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
21+
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
22+
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
1723
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
1824
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
1925
golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
2026
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
2127
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
2228
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
2329
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
30+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
31+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
32+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
33+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

tests/go/types_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//go:generate $WIT_BINDGEN_WRPC go --gofmt=false --world types --out-dir bindings/types --package github.com/wrpc/wrpc/tests/go/bindings/types ../wit
2+
3+
package integration_test
4+
5+
import (
6+
"bytes"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
11+
wrpc "github.com/wrpc/wrpc/go"
12+
"github.com/wrpc/wrpc/tests/go/bindings/types/wrpc_test/integration/get_types"
13+
)
14+
15+
type indexReader struct {
16+
*bytes.Buffer
17+
}
18+
19+
func (r *indexReader) Index(path ...uint32) (wrpc.IndexReader, error) {
20+
panic("not implemented")
21+
}
22+
23+
func TestTypes_Flags(t *testing.T) {
24+
t.Run("WriteToIndex", func(t *testing.T) {
25+
check := assert.New(t)
26+
27+
flags := get_types.FeatureFlags{
28+
ShowA: true,
29+
ShowC: true,
30+
}
31+
var b bytes.Buffer
32+
_, err := flags.WriteToIndex(&b)
33+
check.NoError(err)
34+
check.Equal([]byte{0b00000101}, b.Bytes())
35+
})
36+
37+
t.Run("ReadFromIndex", func(t *testing.T) {
38+
check := assert.New(t)
39+
40+
var flags get_types.FeatureFlags
41+
err := flags.ReadFromIndex(&indexReader{
42+
Buffer: bytes.NewBuffer([]byte{0b00000101}),
43+
})
44+
check.NoError(err)
45+
check.Equal(get_types.FeatureFlags{
46+
ShowA: true,
47+
ShowC: true,
48+
}, flags)
49+
50+
t.Run("invalid bit set", func(t *testing.T) {
51+
check := assert.New(t)
52+
53+
var flags get_types.FeatureFlags
54+
err := flags.ReadFromIndex(&indexReader{
55+
Buffer: bytes.NewBuffer([]byte{0b10000000}),
56+
})
57+
if check.Error(err) {
58+
check.Equal("bit not associated with any flag is set", err.Error())
59+
}
60+
check.Zero(flags)
61+
})
62+
})
63+
}

tests/wit/test.wit

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,20 @@ world resources-client {
9494
bar: func(v: borrow<foo>) -> u64;
9595
}
9696
}
97+
98+
interface get-types {
99+
flags feature-flags {
100+
show-a,
101+
show-b,
102+
show-c,
103+
show-d,
104+
show-e,
105+
show-f,
106+
}
107+
108+
get-features: func() -> feature-flags;
109+
}
110+
111+
world types {
112+
import get-types;
113+
}

0 commit comments

Comments
 (0)