Skip to content

Commit 6fb5325

Browse files
authored
Implement "import public" using type aliases. (#583)
Type aliases were added in Go 1.9, so this change bumps the minium required Go version for protos which use public imports.
1 parent 9bb8760 commit 6fb5325

File tree

19 files changed

+518
-1197
lines changed

19 files changed

+518
-1197
lines changed

protoc-gen-go/generator/generator.go

Lines changed: 28 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import (
4343
"crypto/sha256"
4444
"encoding/hex"
4545
"fmt"
46+
"go/build"
4647
"go/parser"
4748
"go/printer"
4849
"go/token"
@@ -345,8 +346,7 @@ type symbol interface {
345346
type messageSymbol struct {
346347
sym string
347348
hasExtensions, isMessageSet bool
348-
hasOneof bool
349-
getters []getterSymbol
349+
oneofTypes []string
350350
}
351351

352352
type getterSymbol struct {
@@ -357,146 +357,10 @@ type getterSymbol struct {
357357
}
358358

359359
func (ms *messageSymbol) GenerateAlias(g *Generator, pkg GoPackageName) {
360-
remoteSym := string(pkg) + "." + ms.sym
361-
362-
g.P("type ", ms.sym, " ", remoteSym)
363-
g.P("func (m *", ms.sym, ") Reset() { (*", remoteSym, ")(m).Reset() }")
364-
g.P("func (m *", ms.sym, ") String() string { return (*", remoteSym, ")(m).String() }")
365-
g.P("func (*", ms.sym, ") ProtoMessage() {}")
366-
g.P("func (m *", ms.sym, ") XXX_Unmarshal(buf []byte) error ",
367-
"{ return (*", remoteSym, ")(m).XXX_Unmarshal(buf) }")
368-
g.P("func (m *", ms.sym, ") XXX_Marshal(b []byte, deterministic bool) ([]byte, error) ",
369-
"{ return (*", remoteSym, ")(m).XXX_Marshal(b, deterministic) }")
370-
g.P("func (m *", ms.sym, ") XXX_Size() int ",
371-
"{ return (*", remoteSym, ")(m).XXX_Size() }")
372-
g.P("func (m *", ms.sym, ") XXX_DiscardUnknown() ",
373-
"{ (*", remoteSym, ")(m).XXX_DiscardUnknown() }")
374-
if ms.hasExtensions {
375-
g.P("func (*", ms.sym, ") ExtensionRangeArray() []", g.Pkg["proto"], ".ExtensionRange ",
376-
"{ return (*", remoteSym, ")(nil).ExtensionRangeArray() }")
377-
}
378-
if ms.hasOneof {
379-
// Oneofs and public imports do not mix well.
380-
// We can make them work okay for the binary format,
381-
// but they're going to break weirdly for text/JSON.
382-
enc := "_" + ms.sym + "_OneofMarshaler"
383-
dec := "_" + ms.sym + "_OneofUnmarshaler"
384-
size := "_" + ms.sym + "_OneofSizer"
385-
encSig := "(msg " + g.Pkg["proto"] + ".Message, b *" + g.Pkg["proto"] + ".Buffer) error"
386-
decSig := "(msg " + g.Pkg["proto"] + ".Message, tag, wire int, b *" + g.Pkg["proto"] + ".Buffer) (bool, error)"
387-
sizeSig := "(msg " + g.Pkg["proto"] + ".Message) int"
388-
g.P("func (m *", ms.sym, ") XXX_OneofFuncs() (func", encSig, ", func", decSig, ", func", sizeSig, ", []interface{}) {")
389-
g.P("_, _, _, x := (*", remoteSym, ")(nil).XXX_OneofFuncs()")
390-
g.P("return ", enc, ", ", dec, ", ", size, ", x")
391-
g.P("}")
392-
393-
g.P("func ", enc, encSig, " {")
394-
g.P("m := msg.(*", ms.sym, ")")
395-
g.P("m0 := (*", remoteSym, ")(m)")
396-
g.P("enc, _, _, _ := m0.XXX_OneofFuncs()")
397-
g.P("return enc(m0, b)")
398-
g.P("}")
399-
400-
g.P("func ", dec, decSig, " {")
401-
g.P("m := msg.(*", ms.sym, ")")
402-
g.P("m0 := (*", remoteSym, ")(m)")
403-
g.P("_, dec, _, _ := m0.XXX_OneofFuncs()")
404-
g.P("return dec(m0, tag, wire, b)")
405-
g.P("}")
406-
407-
g.P("func ", size, sizeSig, " {")
408-
g.P("m := msg.(*", ms.sym, ")")
409-
g.P("m0 := (*", remoteSym, ")(m)")
410-
g.P("_, _, size, _ := m0.XXX_OneofFuncs()")
411-
g.P("return size(m0)")
412-
g.P("}")
413-
}
414-
for _, get := range ms.getters {
415-
416-
if get.typeName != "" {
417-
g.RecordTypeUse(get.typeName)
418-
}
419-
typ := get.typ
420-
val := "(*" + remoteSym + ")(m)." + get.name + "()"
421-
if get.genType {
422-
// typ will be "*pkg.T" (message/group) or "pkg.T" (enum)
423-
// or "map[t]*pkg.T" (map to message/enum).
424-
// The first two of those might have a "[]" prefix if it is repeated.
425-
// Drop any package qualifier since we have hoisted the type into this package.
426-
rep := strings.HasPrefix(typ, "[]")
427-
if rep {
428-
typ = typ[2:]
429-
}
430-
isMap := strings.HasPrefix(typ, "map[")
431-
star := typ[0] == '*'
432-
if !isMap { // map types handled lower down
433-
typ = typ[strings.Index(typ, ".")+1:]
434-
}
435-
if star {
436-
typ = "*" + typ
437-
}
438-
if rep {
439-
// Go does not permit conversion between slice types where both
440-
// element types are named. That means we need to generate a bit
441-
// of code in this situation.
442-
// typ is the element type.
443-
// val is the expression to get the slice from the imported type.
444-
445-
ctyp := typ // conversion type expression; "Foo" or "(*Foo)"
446-
if star {
447-
ctyp = "(" + typ + ")"
448-
}
449-
450-
g.P("func (m *", ms.sym, ") ", get.name, "() []", typ, " {")
451-
g.In()
452-
g.P("o := ", val)
453-
g.P("if o == nil {")
454-
g.In()
455-
g.P("return nil")
456-
g.Out()
457-
g.P("}")
458-
g.P("s := make([]", typ, ", len(o))")
459-
g.P("for i, x := range o {")
460-
g.In()
461-
g.P("s[i] = ", ctyp, "(x)")
462-
g.Out()
463-
g.P("}")
464-
g.P("return s")
465-
g.Out()
466-
g.P("}")
467-
continue
468-
}
469-
if isMap {
470-
// Split map[keyTyp]valTyp.
471-
bra, ket := strings.Index(typ, "["), strings.Index(typ, "]")
472-
keyTyp, valTyp := typ[bra+1:ket], typ[ket+1:]
473-
// Drop any package qualifier.
474-
// Only the value type may be foreign.
475-
star := valTyp[0] == '*'
476-
valTyp = valTyp[strings.Index(valTyp, ".")+1:]
477-
if star {
478-
valTyp = "*" + valTyp
479-
}
480-
481-
typ := "map[" + keyTyp + "]" + valTyp
482-
g.P("func (m *", ms.sym, ") ", get.name, "() ", typ, " {")
483-
g.P("o := ", val)
484-
g.P("if o == nil { return nil }")
485-
g.P("s := make(", typ, ", len(o))")
486-
g.P("for k, v := range o {")
487-
g.P("s[k] = (", valTyp, ")(v)")
488-
g.P("}")
489-
g.P("return s")
490-
g.P("}")
491-
continue
492-
}
493-
// Convert imported type into the forwarding type.
494-
val = "(" + typ + ")(" + val + ")"
495-
}
496-
497-
g.P("func (m *", ms.sym, ") ", get.name, "() ", typ, " { return ", val, " }")
360+
g.P("type ", ms.sym, " = ", pkg, ".", ms.sym)
361+
for _, name := range ms.oneofTypes {
362+
g.P("type ", name, " = ", pkg, ".", name)
498363
}
499-
500364
}
501365

502366
type enumSymbol struct {
@@ -506,14 +370,9 @@ type enumSymbol struct {
506370

507371
func (es enumSymbol) GenerateAlias(g *Generator, pkg GoPackageName) {
508372
s := es.name
509-
g.P("type ", s, " ", pkg, ".", s)
373+
g.P("type ", s, " = ", pkg, ".", s)
510374
g.P("var ", s, "_name = ", pkg, ".", s, "_name")
511375
g.P("var ", s, "_value = ", pkg, ".", s, "_value")
512-
g.P("func (x ", s, ") String() string { return (", pkg, ".", s, ")(x).String() }")
513-
if !es.proto3 {
514-
g.P("func (x ", s, ") Enum() *", s, "{ return (*", s, ")((", pkg, ".", s, ")(x).Enum()) }")
515-
g.P("func (x *", s, ") UnmarshalJSON(data []byte) error { return (*", pkg, ".", s, ")(x).UnmarshalJSON(data) }")
516-
}
517376
}
518377

519378
type constOrVarSymbol struct {
@@ -1486,20 +1345,18 @@ func (g *Generator) generateImports() {
14861345
}
14871346

14881347
func (g *Generator) generateImported(id *ImportedDescriptor) {
1489-
// Don't generate public import symbols for files that we are generating
1490-
// code for, since those symbols will already be in this package.
1491-
// We can't simply avoid creating the ImportedDescriptor objects,
1492-
// because g.genFiles isn't populated at that stage.
14931348
tn := id.TypeName()
14941349
sn := tn[len(tn)-1]
14951350
df := id.o.File()
14961351
filename := *df.Name
1497-
for _, fd := range g.genFiles {
1498-
if *fd.Name == filename {
1499-
g.P("// Ignoring public import of ", sn, " from ", filename)
1500-
g.P()
1501-
return
1502-
}
1352+
if df.importPath == g.file.importPath {
1353+
// Don't generate type aliases for files in the same Go package as this one.
1354+
g.P("// Ignoring public import of ", sn, " from ", filename)
1355+
g.P()
1356+
return
1357+
}
1358+
if !supportTypeAliases {
1359+
g.Fail(fmt.Sprintf("%s: public imports require at least go1.9", filename))
15031360
}
15041361
g.P("// ", sn, " from public import ", filename)
15051362
g.usedPackages[df.importPath] = true
@@ -2232,6 +2089,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
22322089
g.P("}")
22332090
}
22342091
g.P()
2092+
var oneofTypes []string
22352093
for i, field := range message.Field {
22362094
if field.OneofIndex == nil {
22372095
continue
@@ -2241,6 +2099,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
22412099
fieldFullPath := fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i)
22422100
g.P("type ", Annotate(message.file, fieldFullPath, oneofTypeName[field]), " struct{ ", Annotate(message.file, fieldFullPath, fieldNames[field]), " ", fieldTypes[field], " `", tag, "` }")
22432101
g.RecordTypeUse(field.GetTypeName())
2102+
oneofTypes = append(oneofTypes, oneofTypeName[field])
22442103
}
22452104
g.P()
22462105
for _, field := range message.Field {
@@ -2261,7 +2120,6 @@ func (g *Generator) generateMessage(message *Descriptor) {
22612120
g.P()
22622121

22632122
// Field getters
2264-
var getters []getterSymbol
22652123
for i, field := range message.Field {
22662124
oneof := field.OneofIndex != nil
22672125

@@ -2278,42 +2136,6 @@ func (g *Generator) generateMessage(message *Descriptor) {
22782136
}
22792137
fieldFullPath := fmt.Sprintf("%s,%d,%d", message.path, messageFieldPath, i)
22802138

2281-
// Only export getter symbols for basic types,
2282-
// and for messages and enums in the same package.
2283-
// Groups are not exported.
2284-
// Foreign types can't be hoisted through a public import because
2285-
// the importer may not already be importing the defining .proto.
2286-
// As an example, imagine we have an import tree like this:
2287-
// A.proto -> B.proto -> C.proto
2288-
// If A publicly imports B, we need to generate the getters from B in A's output,
2289-
// but if one such getter returns something from C then we cannot do that
2290-
// because A is not importing C already.
2291-
var getter, genType bool
2292-
switch *field.Type {
2293-
case descriptor.FieldDescriptorProto_TYPE_GROUP:
2294-
getter = false
2295-
case descriptor.FieldDescriptorProto_TYPE_MESSAGE, descriptor.FieldDescriptorProto_TYPE_ENUM:
2296-
// Only export getter if its return type is in the same file.
2297-
//
2298-
// This should be the same package, not the same file.
2299-
// However, code elsewhere assumes that there's a 1-1 relationship
2300-
// between packages and files, so that's not safe.
2301-
//
2302-
// TODO: Tear out all of this complexity and just use type aliases.
2303-
getter = g.ObjectNamed(field.GetTypeName()).File() == message.File()
2304-
genType = true
2305-
default:
2306-
getter = true
2307-
}
2308-
if getter {
2309-
getters = append(getters, getterSymbol{
2310-
name: mname,
2311-
typ: typename,
2312-
typeName: field.GetTypeName(),
2313-
genType: genType,
2314-
})
2315-
}
2316-
23172139
if field.GetOptions().GetDeprecated() {
23182140
g.P(deprecationComment)
23192141
}
@@ -2416,8 +2238,7 @@ func (g *Generator) generateMessage(message *Descriptor) {
24162238
sym: ccTypeName,
24172239
hasExtensions: hasExtensions,
24182240
isMessageSet: isMessageSet,
2419-
hasOneof: len(message.OneofDecl) > 0,
2420-
getters: getters,
2241+
oneofTypes: oneofTypes,
24212242
}
24222243
g.file.addExport(message, ms)
24232244
}
@@ -3094,3 +2915,14 @@ const (
30942915
// tag numbers in EnumDescriptorProto
30952916
enumValuePath = 2 // value
30962917
)
2918+
2919+
var supportTypeAliases bool
2920+
2921+
func init() {
2922+
for _, tag := range build.Default.ReleaseTags {
2923+
if tag == "go1.9" {
2924+
supportTypeAliases = true
2925+
return
2926+
}
2927+
}
2928+
}

protoc-gen-go/golden_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"flag"
66
"fmt"
7+
"go/build"
78
"go/parser"
89
"go/token"
910
"io/ioutil"
@@ -38,8 +39,13 @@ func TestGolden(t *testing.T) {
3839

3940
// Find all the proto files we need to compile. We assume that each directory
4041
// contains the files for a single package.
42+
supportTypeAliases := hasReleaseTag("1.9")
4143
packages := map[string][]string{}
4244
err = filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error {
45+
if filepath.Base(path) == "import_public" && !supportTypeAliases {
46+
// Public imports require type alias support.
47+
return filepath.SkipDir
48+
}
4349
if !strings.HasSuffix(path, ".proto") {
4450
return nil
4551
}
@@ -405,3 +411,12 @@ func protoc(t *testing.T, args []string) {
405411
t.Fatalf("protoc: %v", err)
406412
}
407413
}
414+
415+
func hasReleaseTag(want string) bool {
416+
for _, tag := range build.Default.ReleaseTags {
417+
if tag == want {
418+
return true
419+
}
420+
}
421+
return false
422+
}

0 commit comments

Comments
 (0)