Skip to content

Commit a7e4ec7

Browse files
committed
Use ReadOnlyEntry for GetAuthorizedEntries API
Signed-off-by: Sorin Dumitru <[email protected]>
1 parent 38905a1 commit a7e4ec7

File tree

16 files changed

+119
-96
lines changed

16 files changed

+119
-96
lines changed

pkg/agent/client/client_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
svidv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/server/svid/v1"
2222
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
2323
"github.com/spiffe/spire/pkg/common/telemetry"
24+
"github.com/spiffe/spire/pkg/server/api"
2425
"github.com/spiffe/spire/pkg/server/api/entry/v1"
2526
"github.com/spiffe/spire/proto/spire/common"
2627
"github.com/spiffe/spire/test/spiretest"
@@ -1023,7 +1024,13 @@ func (c *fakeEntryServer) GetAuthorizedEntries(_ context.Context, in *entryv1.Ge
10231024

10241025
func (c *fakeEntryServer) SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer) error {
10251026
const entryPageSize = 2
1026-
return entry.SyncAuthorizedEntries(stream, c.entries, entryPageSize)
1027+
1028+
entries := []api.ReadOnlyEntry{}
1029+
for _, entry := range c.entries {
1030+
entries = append(entries, api.NewReadOnlyEntry(entry))
1031+
}
1032+
1033+
return entry.SyncAuthorizedEntries(stream, entries, entryPageSize)
10271034
}
10281035

10291036
type fakeBundleServer struct {

pkg/server/api/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ type AuthorizedEntryFetcher interface {
1818
LookupAuthorizedEntries(ctx context.Context, id spiffeid.ID, entryIDs map[string]struct{}) (map[string]ReadOnlyEntry, error)
1919
// FetchAuthorizedEntries fetches the entries that the specified
2020
// SPIFFE ID is authorized for
21-
FetchAuthorizedEntries(ctx context.Context, id spiffeid.ID) ([]*types.Entry, error)
21+
FetchAuthorizedEntries(ctx context.Context, id spiffeid.ID) ([]ReadOnlyEntry, error)
2222
}
2323

2424
// AuthorizedEntryFetcherFunc is an implementation of AuthorizedEntryFetcher

pkg/server/api/entry.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ func NewReadOnlyEntry(entry *types.Entry) ReadOnlyEntry {
2828
}
2929
}
3030

31+
func (e ReadOnlyEntry) GetId() string {
32+
return e.entry.Id
33+
}
34+
3135
func (e *ReadOnlyEntry) GetSpiffeId() *types.SPIFFEID {
3236
return &types.SPIFFEID{
3337
TrustDomain: e.entry.SpiffeId.TrustDomain,
@@ -51,6 +55,10 @@ func (e *ReadOnlyEntry) GetRevisionNumber() int64 {
5155
return e.entry.RevisionNumber
5256
}
5357

58+
func (e *ReadOnlyEntry) GetCreatedAt() int64 {
59+
return e.entry.CreatedAt
60+
}
61+
5462
// Manually clone the entry instead of using the protobuf helpers
5563
// since those are two times slower.
5664
func (e *ReadOnlyEntry) Clone(mask *types.EntryMask) *types.Entry {

pkg/server/api/entry/v1/service.go

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,12 @@ func (s *Service) GetAuthorizedEntries(ctx context.Context, req *entryv1.GetAuth
391391
return nil, err
392392
}
393393

394-
for i, entry := range entries {
395-
applyMask(entry, req.OutputMask)
396-
entries[i] = entry
397-
}
394+
resp := &entryv1.GetAuthorizedEntriesResponse{}
398395

399-
resp := &entryv1.GetAuthorizedEntriesResponse{
400-
Entries: entries,
396+
for _, entry := range entries {
397+
resp.Entries = append(resp.Entries, entry.Clone(req.OutputMask))
401398
}
399+
402400
rpccontext.AuditRPC(ctx)
403401

404402
return resp, nil
@@ -424,7 +422,7 @@ func (s *Service) SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntri
424422
return SyncAuthorizedEntries(stream, entries, s.entryPageSize)
425423
}
426424

427-
func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, entries []*types.Entry, entryPageSize int) (err error) {
425+
func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, entries []api.ReadOnlyEntry, entryPageSize int) (err error) {
428426
// Receive the initial request with the output mask.
429427
req, err := stream.Recv()
430428
if err != nil {
@@ -447,18 +445,17 @@ func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, ent
447445

448446
// Apply output mask to entries. The output mask field will be
449447
// intentionally ignored on subsequent requests.
450-
for i, entry := range entries {
451-
applyMask(entry, req.OutputMask)
452-
entries[i] = entry
453-
}
448+
initialOutputMask := req.OutputMask
454449

455450
// If the number of entries is less than or equal to the entry page size,
456451
// then just send the full list back. Otherwise, we'll send a sparse list
457452
// and then stream back full entries as requested.
458453
if len(entries) <= entryPageSize {
459-
return stream.Send(&entryv1.SyncAuthorizedEntriesResponse{
460-
Entries: entries,
461-
})
454+
resp := &entryv1.SyncAuthorizedEntriesResponse{}
455+
for _, entry := range entries {
456+
resp.Entries = append(resp.Entries, entry.Clone(initialOutputMask))
457+
}
458+
return stream.Send(resp)
462459
}
463460

464461
// Prepopulate the entry page used in the response with empty entry structs.
@@ -475,9 +472,9 @@ func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, ent
475472
more = true
476473
}
477474
for j, entry := range entries[i : i+n] {
478-
entryRevisions[j].Id = entry.Id
479-
entryRevisions[j].RevisionNumber = entry.RevisionNumber
480-
entryRevisions[j].CreatedAt = entry.CreatedAt
475+
entryRevisions[j].Id = entry.GetId()
476+
entryRevisions[j].RevisionNumber = entry.GetRevisionNumber()
477+
entryRevisions[j].CreatedAt = entry.GetCreatedAt()
481478
}
482479

483480
if err := stream.Send(&entryv1.SyncAuthorizedEntriesResponse{
@@ -530,7 +527,7 @@ func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, ent
530527
entriesToSearch := entries
531528
for _, id := range req.Ids {
532529
i, found := sort.Find(len(entriesToSearch), func(i int) int {
533-
return strings.Compare(id, entriesToSearch[i].Id)
530+
return strings.Compare(id, entriesToSearch[i].GetId())
534531
})
535532
if found {
536533
if len(resp.Entries) == entryPageSize {
@@ -543,7 +540,7 @@ func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, ent
543540
}
544541
resp.Entries = resp.Entries[:0]
545542
}
546-
resp.Entries = append(resp.Entries, entriesToSearch[i])
543+
resp.Entries = append(resp.Entries, entriesToSearch[i].Clone(initialOutputMask))
547544
}
548545
entriesToSearch = entriesToSearch[i:]
549546
if len(entriesToSearch) == 0 {
@@ -560,7 +557,7 @@ func SyncAuthorizedEntries(stream entryv1.Entry_SyncAuthorizedEntriesServer, ent
560557
}
561558

562559
// fetchEntries fetches authorized entries using caller ID from context
563-
func (s *Service) fetchEntries(ctx context.Context, log logrus.FieldLogger) ([]*types.Entry, error) {
560+
func (s *Service) fetchEntries(ctx context.Context, log logrus.FieldLogger) ([]api.ReadOnlyEntry, error) {
564561
callerID, ok := rpccontext.CallerID(ctx)
565562
if !ok {
566563
return nil, api.MakeErr(log, codes.Internal, "caller ID missing from request context", nil)
@@ -844,8 +841,8 @@ func fieldsFromCountEntryFilter(ctx context.Context, td spiffeid.TrustDomain, fi
844841
return fields
845842
}
846843

847-
func sortEntriesByID(entries []*types.Entry) {
844+
func sortEntriesByID(entries []api.ReadOnlyEntry) {
848845
sort.Slice(entries, func(a, b int) bool {
849-
return entries[a].Id < entries[b].Id
846+
return entries[a].GetId() < entries[b].GetId()
850847
})
851848
}

pkg/server/api/entry/v1/service_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4849,13 +4849,13 @@ func (f *entryFetcher) LookupAuthorizedEntries(ctx context.Context, agentID spif
48494849

48504850
entriesMap := make(map[string]api.ReadOnlyEntry)
48514851
for _, entry := range entries {
4852-
entriesMap[entry.GetId()] = api.NewReadOnlyEntry(entry)
4852+
entriesMap[entry.GetId()] = entry
48534853
}
48544854

48554855
return entriesMap, nil
48564856
}
48574857

4858-
func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]*types.Entry, error) {
4858+
func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]api.ReadOnlyEntry, error) {
48594859
if f.err != "" {
48604860
return nil, status.Error(codes.Internal, f.err)
48614861
}
@@ -4869,7 +4869,12 @@ func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiff
48694869
return nil, fmt.Errorf("provided caller id is different to expected")
48704870
}
48714871

4872-
return f.entries, nil
4872+
entries := []api.ReadOnlyEntry{}
4873+
for _, entry := range f.entries {
4874+
entries = append(entries, api.NewReadOnlyEntry(entry))
4875+
}
4876+
4877+
return entries, nil
48734878
}
48744879

48754880
type HasID interface {

pkg/server/api/svid/v1/service_test.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,6 +1158,7 @@ func TestServiceBatchNewX509SVID(t *testing.T) {
11581158
}
11591159
invalidEntry := &types.Entry{
11601160
Id: "invalid",
1161+
SpiffeId: &types.SPIFFEID{},
11611162
ParentId: api.ProtoFromID(agentID),
11621163
}
11631164
test.ef.entries = []*types.Entry{workloadEntry, dnsEntry, ttlEntry, x509TtlEntry, invalidEntry}
@@ -1324,7 +1325,7 @@ func TestServiceBatchNewX509SVID(t *testing.T) {
13241325
Message: "Entry has malformed SPIFFE ID",
13251326
Data: logrus.Fields{
13261327
telemetry.RegistrationID: "invalid",
1327-
logrus.ErrorKey: "request must specify SPIFFE ID",
1328+
logrus.ErrorKey: "trust domain is missing",
13281329
},
13291330
},
13301331
{
@@ -1336,7 +1337,7 @@ func TestServiceBatchNewX509SVID(t *testing.T) {
13361337
telemetry.RegistrationID: "invalid",
13371338
telemetry.Csr: api.HashByte(m["invalid"]),
13381339
telemetry.StatusCode: "Internal",
1339-
telemetry.StatusMessage: "entry has malformed SPIFFE ID: request must specify SPIFFE ID",
1340+
telemetry.StatusMessage: "entry has malformed SPIFFE ID: trust domain is missing",
13401341
telemetry.SPIFFEID: "",
13411342
},
13421343
},
@@ -1656,7 +1657,7 @@ func TestServiceBatchNewX509SVID(t *testing.T) {
16561657
Message: "Entry has malformed SPIFFE ID",
16571658
Data: logrus.Fields{
16581659
telemetry.RegistrationID: "invalid",
1659-
logrus.ErrorKey: "request must specify SPIFFE ID",
1660+
logrus.ErrorKey: "trust domain is missing",
16601661
},
16611662
},
16621663
{
@@ -1668,7 +1669,7 @@ func TestServiceBatchNewX509SVID(t *testing.T) {
16681669
telemetry.RegistrationID: "invalid",
16691670
telemetry.Csr: api.HashByte(m["invalid"]),
16701671
telemetry.StatusCode: "Internal",
1671-
telemetry.StatusMessage: "entry has malformed SPIFFE ID: request must specify SPIFFE ID",
1672+
telemetry.StatusMessage: "entry has malformed SPIFFE ID: trust domain is missing",
16721673
telemetry.SPIFFEID: "",
16731674
},
16741675
},
@@ -2181,27 +2182,32 @@ func (f *entryFetcher) LookupAuthorizedEntries(ctx context.Context, agentID spif
21812182

21822183
entriesMap := make(map[string]api.ReadOnlyEntry)
21832184
for _, entry := range entries {
2184-
entriesMap[entry.GetId()] = api.NewReadOnlyEntry(entry)
2185+
entriesMap[entry.GetId()] = entry
21852186
}
21862187

21872188
return entriesMap, nil
21882189
}
21892190

2190-
func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]*types.Entry, error) {
2191+
func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]api.ReadOnlyEntry, error) {
21912192
if f.err != "" {
21922193
return nil, status.Error(codes.Internal, f.err)
21932194
}
21942195

21952196
caller, ok := rpccontext.CallerID(ctx)
21962197
if !ok {
2197-
return nil, errors.New("no caller ID on context")
2198+
return nil, errors.New("missing caller ID")
21982199
}
21992200

22002201
if caller != agentID {
22012202
return nil, fmt.Errorf("provided caller id is different to expected")
22022203
}
22032204

2204-
return f.entries, nil
2205+
entries := []api.ReadOnlyEntry{}
2206+
for _, entry := range f.entries {
2207+
entries = append(entries, api.NewReadOnlyEntry(entry))
2208+
}
2209+
2210+
return entries, nil
22052211
}
22062212

22072213
type fakeRateLimiter struct {

pkg/server/authorizedentries/cache.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func (c *Cache) LookupAuthorizedEntries(agentID spiffeid.ID, requestedEntries ma
8888
return foundEntries
8989
}
9090

91-
func (c *Cache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry {
91+
func (c *Cache) GetAuthorizedEntries(agentID spiffeid.ID) []api.ReadOnlyEntry {
9292
c.mu.RLock()
9393
defer c.mu.RUnlock()
9494

pkg/server/authorizedentries/cache_test.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/spiffe/go-spiffe/v2/spiffeid"
1212
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
1313
"github.com/spiffe/spire/pkg/common/idutil"
14+
"github.com/spiffe/spire/pkg/common/protoutil"
1415
"github.com/spiffe/spire/pkg/server/api"
1516
"github.com/spiffe/spire/test/clock"
1617
"github.com/spiffe/spire/test/spiretest"
@@ -432,8 +433,16 @@ func assertAuthorizedEntries(tb testing.TB, cache *Cache, agentID spiffeid.ID, a
432433
return m
433434
}
434435

436+
readOnlyEntriesMap := func(entries []api.ReadOnlyEntry) map[string]*types.Entry {
437+
m := make(map[string]*types.Entry)
438+
for _, entry := range entries {
439+
m[entry.GetId()] = entry.Clone(protoutil.AllTrueEntryMask)
440+
}
441+
return m
442+
}
443+
435444
wantMap := entriesMap(wantEntries)
436-
gotMap := entriesMap(cache.GetAuthorizedEntries(agentID))
445+
gotMap := readOnlyEntriesMap(cache.GetAuthorizedEntries(agentID))
437446

438447
for id, want := range wantMap {
439448
got, ok := gotMap[id]

pkg/server/authorizedentries/entries.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package authorizedentries
22

33
import (
44
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
5-
"google.golang.org/protobuf/proto"
5+
"github.com/spiffe/spire/pkg/server/api"
66
)
77

88
type entryRecord struct {
@@ -29,17 +29,13 @@ func entryRecordByParentID(a, b entryRecord) bool {
2929
}
3030
}
3131

32-
func cloneEntriesFromRecords(entryRecords []entryRecord) []*types.Entry {
32+
func cloneEntriesFromRecords(entryRecords []entryRecord) []api.ReadOnlyEntry {
3333
if len(entryRecords) == 0 {
3434
return nil
3535
}
36-
cloned := make([]*types.Entry, 0, len(entryRecords))
36+
cloned := make([]api.ReadOnlyEntry, 0, len(entryRecords))
3737
for _, entryRecord := range entryRecords {
38-
cloned = append(cloned, cloneEntry(entryRecord.EntryCloneOnly))
38+
cloned = append(cloned, api.NewReadOnlyEntry(entryRecord.EntryCloneOnly))
3939
}
4040
return cloned
4141
}
42-
43-
func cloneEntry(entry *types.Entry) *types.Entry {
44-
return proto.Clone(entry).(*types.Entry)
45-
}

pkg/server/cache/entrycache/fullcache.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"github.com/spiffe/go-spiffe/v2/spiffeid"
88
"github.com/spiffe/spire-api-sdk/proto/spire/api/types"
99
"github.com/spiffe/spire/pkg/server/api"
10-
"google.golang.org/protobuf/proto"
1110
)
1211

1312
var (
@@ -30,7 +29,7 @@ var _ Cache = (*FullEntryCache)(nil)
3029
// at a particular moment in time.
3130
type Cache interface {
3231
LookupAuthorizedEntries(agentID spiffeid.ID, entries map[string]struct{}) map[string]api.ReadOnlyEntry
33-
GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry
32+
GetAuthorizedEntries(agentID spiffeid.ID) []api.ReadOnlyEntry
3433
}
3534

3635
// Selector is a key-value attribute of a node or workload.
@@ -190,13 +189,13 @@ func (c *FullEntryCache) LookupAuthorizedEntries(agentID spiffeid.ID, requestedE
190189
}
191190

192191
// GetAuthorizedEntries gets all authorized registration entries for a given Agent SPIFFE ID.
193-
func (c *FullEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry {
192+
func (c *FullEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []api.ReadOnlyEntry {
194193
seen := allocSeenSet()
195194
defer freeSeenSet(seen)
196195

197-
foundEntries := []*types.Entry{}
196+
foundEntries := []api.ReadOnlyEntry{}
198197
c.crawl(spiffeIDFromID(agentID), seen, func(entry *types.Entry) {
199-
foundEntries = append(foundEntries, proto.Clone(entry).(*types.Entry))
198+
foundEntries = append(foundEntries, api.NewReadOnlyEntry(entry))
200199
})
201200

202201
return foundEntries

0 commit comments

Comments
 (0)