Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,8 @@ func NewAPI(treeID uint) (*API, error) {

cachedCheckpoints := make(map[int64]string)
for _, r := range ranges.GetInactive() {
tc := trillianclient.NewTrillianClient(ctx, logClient, r.TreeID)
resp := tc.GetLatest(0)
tc := trillianclient.NewTrillianClient(logClient, r.TreeID)
resp := tc.GetLatest(ctx, 0)
if resp.Status != codes.OK {
return nil, fmt.Errorf("error fetching latest tree head for inactive shard %d: resp code is %d, err is %w", r.TreeID, resp.Status, resp.Err)
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/api/entries.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ func createLogEntry(params entries.CreateLogEntryParams) (models.LogEntry, middl
return nil, handleRekorAPIError(params, http.StatusInternalServerError, err, failedToGenerateCanonicalEntry)
}

tc := trillianclient.NewTrillianClient(ctx, api.logClient, api.treeID)
tc := trillianclient.NewTrillianClient(api.logClient, api.treeID)

resp := tc.AddLeaf(leaf)
resp := tc.AddLeaf(ctx, leaf)
// this represents overall GRPC response state (not the results of insertion into the log)
if resp.Status != codes.OK {
return nil, handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianUnexpectedResult)
Expand Down Expand Up @@ -622,8 +622,8 @@ func SearchLogQueryHandler(params entries.SearchLogQueryParams) middleware.Respo
for i, hash := range searchHashes {
var results map[int64]*trillian.GetEntryAndProofResponse
for _, shard := range api.logRanges.AllShards() {
tcs := trillianclient.NewTrillianClient(httpReqCtx, api.logClient, shard)
resp := tcs.GetLeafAndProofByHash(hash)
tcs := trillianclient.NewTrillianClient(api.logClient, shard)
resp := tcs.GetLeafAndProofByHash(httpReqCtx, hash)
switch resp.Status {
case codes.OK:
leafResult := resp.GetLeafAndProofResult
Expand Down Expand Up @@ -677,10 +677,10 @@ func retrieveLogEntryByIndex(ctx context.Context, logIndex int) (models.LogEntry
log.ContextLogger(ctx).Infof("Retrieving log entry by index %d", logIndex)

tid, resolvedIndex := api.logRanges.ResolveVirtualIndex(logIndex)
tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
tc := trillianclient.NewTrillianClient(api.logClient, tid)
log.ContextLogger(ctx).Debugf("Retrieving resolved index %v from TreeID %v", resolvedIndex, tid)

resp := tc.GetLeafAndProofByIndex(resolvedIndex)
resp := tc.GetLeafAndProofByIndex(ctx, resolvedIndex)
switch resp.Status {
case codes.OK:
case codes.NotFound, codes.OutOfRange, codes.InvalidArgument:
Expand Down Expand Up @@ -744,10 +744,10 @@ func retrieveUUIDFromTree(ctx context.Context, uuid string, tid int64) (models.L
return models.LogEntry{}, &types.InputValidationError{Err: fmt.Errorf("parsing UUID: %w", err)}
}

tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
tc := trillianclient.NewTrillianClient(api.logClient, tid)
log.ContextLogger(ctx).Debugf("Attempting to retrieve UUID %v from TreeID %v", uuid, tid)

resp := tc.GetLeafAndProofByHash(hashValue)
resp := tc.GetLeafAndProofByHash(ctx, hashValue)
switch resp.Status {
case codes.OK:
result := resp.GetLeafAndProofResult
Expand Down
22 changes: 12 additions & 10 deletions pkg/api/tlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ import (

// GetLogInfoHandler returns the current size of the tree and the STH
func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID)
ctx := params.HTTPRequest.Context()
tc := trillianclient.NewTrillianClient(api.logClient, api.treeID)

// for each inactive shard, get the loginfo
var inactiveShards []*models.InactiveShardLogInfo
for _, shard := range api.logRanges.GetInactive() {
// Get details for this inactive shard
is, err := inactiveShardLogInfo(params.HTTPRequest.Context(), shard.TreeID, api.cachedCheckpoints)
is, err := inactiveShardLogInfo(ctx, shard.TreeID, api.cachedCheckpoints)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("inactive shard error: %w", err), unexpectedInactiveShardError)
}
Expand All @@ -53,7 +54,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
if swag.BoolValue(params.Stable) && redisClient != nil {
// key is treeID/latest
key := fmt.Sprintf("%d/latest", api.logRanges.GetActive().TreeID)
redisResult, err := redisClient.Get(params.HTTPRequest.Context(), key).Result()
redisResult, err := redisClient.Get(ctx, key).Result()
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError,
fmt.Errorf("error getting checkpoint from redis: %w", err), "error getting checkpoint from redis")
Expand Down Expand Up @@ -82,7 +83,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
return tlog.NewGetLogInfoOK().WithPayload(&logInfo)
}

resp := tc.GetLatest(0)
resp := tc.GetLatest(ctx, 0)
if resp.Status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError)
}
Expand All @@ -96,7 +97,7 @@ func GetLogInfoHandler(params tlog.GetLogInfoParams) middleware.Responder {
hashString := hex.EncodeToString(root.RootHash)
treeSize := int64(root.TreeSize)

scBytes, err := util.CreateAndSignCheckpoint(params.HTTPRequest.Context(),
scBytes, err := util.CreateAndSignCheckpoint(ctx,
viper.GetString("rekor_server.hostname"), api.logRanges.GetActive().TreeID, root.TreeSize, root.RootHash, api.logRanges.GetActive().Signer)
if err != nil {
return handleRekorAPIError(params, http.StatusInternalServerError, err, sthGenerateError)
Expand All @@ -123,17 +124,18 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
errMsg := fmt.Sprintf(firstSizeLessThanLastSize, *params.FirstSize, params.LastSize)
return handleRekorAPIError(params, http.StatusBadRequest, fmt.Errorf("consistency proof: %s", errMsg), errMsg)
}
tc := trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, api.treeID)
ctx := params.HTTPRequest.Context()
tc := trillianclient.NewTrillianClient(api.logClient, api.treeID)
if treeID := swag.StringValue(params.TreeID); treeID != "" {
id, err := strconv.Atoi(treeID)
if err != nil {
log.Logger.Infof("Unable to convert %s to string, skipping initializing client with Tree ID: %v", treeID, err)
} else {
tc = trillianclient.NewTrillianClient(params.HTTPRequest.Context(), api.logClient, int64(id))
tc = trillianclient.NewTrillianClient(api.logClient, int64(id))
}
}

resp := tc.GetConsistencyProof(*params.FirstSize, params.LastSize)
resp := tc.GetConsistencyProof(ctx, *params.FirstSize, params.LastSize)
if resp.Status != codes.OK {
return handleRekorAPIError(params, http.StatusInternalServerError, fmt.Errorf("grpc error: %w", resp.Err), trillianCommunicationError)
}
Expand Down Expand Up @@ -168,8 +170,8 @@ func GetLogProofHandler(params tlog.GetLogProofParams) middleware.Responder {
}

func inactiveShardLogInfo(ctx context.Context, tid int64, cachedCheckpoints map[int64]string) (*models.InactiveShardLogInfo, error) {
tc := trillianclient.NewTrillianClient(ctx, api.logClient, tid)
resp := tc.GetLatest(0)
tc := trillianclient.NewTrillianClient(api.logClient, tid)
resp := tc.GetLatest(ctx, 0)
if resp.Status != codes.OK {
return nil, fmt.Errorf("resp code is %d", resp.Status)
}
Expand Down
60 changes: 22 additions & 38 deletions pkg/trillianclient/trillian_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,15 @@ import (

// TrillianClient provides a wrapper around the Trillian client
type TrillianClient struct {
client trillian.TrillianLogClient
logID int64
context context.Context
client trillian.TrillianLogClient
logID int64
}

// NewTrillianClient creates a TrillianClient with the given Trillian client and log/tree ID.
func NewTrillianClient(ctx context.Context, logClient trillian.TrillianLogClient, logID int64) TrillianClient {
func NewTrillianClient(logClient trillian.TrillianLogClient, logID int64) TrillianClient {
return TrillianClient{
client: logClient,
logID: logID,
context: ctx,
client: logClient,
logID: logID,
}
}

Expand Down Expand Up @@ -76,26 +74,26 @@ func unmarshalLogRoot(logRoot []byte) (types.LogRootV1, error) {
return root, nil
}

func (t *TrillianClient) root() (types.LogRootV1, error) {
func (t *TrillianClient) root(ctx context.Context) (types.LogRootV1, error) {
rqst := &trillian.GetLatestSignedLogRootRequest{
LogId: t.logID,
}
resp, err := t.client.GetLatestSignedLogRoot(t.context, rqst)
resp, err := t.client.GetLatestSignedLogRoot(ctx, rqst)
if err != nil {
return types.LogRootV1{}, err
}
return unmarshalLogRoot(resp.SignedLogRoot.LogRoot)
}

func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
func (t *TrillianClient) AddLeaf(ctx context.Context, byteValue []byte) *Response {
leaf := &trillian.LogLeaf{
LeafValue: byteValue,
}
rqst := &trillian.QueueLeafRequest{
LogId: t.logID,
Leaf: leaf,
}
resp, err := t.client.QueueLeaf(t.context, rqst)
resp, err := t.client.QueueLeaf(ctx, rqst)

// check for error
if err != nil || (resp.QueuedLeaf.Status != nil && resp.QueuedLeaf.Status.Code != int32(codes.OK)) {
Expand All @@ -106,7 +104,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
}
}

root, err := t.root()
root, err := t.root(ctx)
if err != nil {
return &Response{
Status: status.Code(err),
Expand All @@ -131,7 +129,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
for {
root = *logClient.GetRoot()
if root.TreeSize >= 1 {
proofResp := t.getProofByHash(resp.QueuedLeaf.Leaf.MerkleLeafHash)
proofResp := t.getProofByHash(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash)
// if this call succeeds or returns an error other than "not found", return
if proofResp.Err == nil || (proofResp.Err != nil && status.Code(proofResp.Err) != codes.NotFound) {
return proofResp
Expand All @@ -148,7 +146,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
}
}

proofResp := waitForInclusion(t.context, resp.QueuedLeaf.Leaf.MerkleLeafHash)
proofResp := waitForInclusion(ctx, resp.QueuedLeaf.Leaf.MerkleLeafHash)
if proofResp.Err != nil {
return &Response{
Status: status.Code(proofResp.Err),
Expand All @@ -168,7 +166,7 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
}

leafIndex := proofs[0].LeafIndex
leafResp := t.GetLeafAndProofByIndex(leafIndex)
leafResp := t.GetLeafAndProofByIndex(ctx, leafIndex)
if leafResp.Err != nil {
return &Response{
Status: status.Code(leafResp.Err),
Expand All @@ -189,9 +187,9 @@ func (t *TrillianClient) AddLeaf(byteValue []byte) *Response {
}
}

func (t *TrillianClient) GetLeafAndProofByHash(hash []byte) *Response {
func (t *TrillianClient) GetLeafAndProofByHash(ctx context.Context, hash []byte) *Response {
// get inclusion proof for hash, extract index, then fetch leaf using index
proofResp := t.getProofByHash(hash)
proofResp := t.getProofByHash(ctx, hash)
if proofResp.Err != nil {
return &Response{
Status: status.Code(proofResp.Err),
Expand All @@ -208,14 +206,11 @@ func (t *TrillianClient) GetLeafAndProofByHash(hash []byte) *Response {
}
}

return t.GetLeafAndProofByIndex(proofs[0].LeafIndex)
return t.GetLeafAndProofByIndex(ctx, proofs[0].LeafIndex)
}

func (t *TrillianClient) GetLeafAndProofByIndex(index int64) *Response {
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
defer cancel()

rootResp := t.GetLatest(0)
func (t *TrillianClient) GetLeafAndProofByIndex(ctx context.Context, index int64) *Response {
rootResp := t.GetLatest(ctx, 0)
if rootResp.Err != nil {
return &Response{
Status: status.Code(rootResp.Err),
Expand Down Expand Up @@ -262,11 +257,7 @@ func (t *TrillianClient) GetLeafAndProofByIndex(index int64) *Response {
}
}

func (t *TrillianClient) GetLatest(leafSizeInt int64) *Response {

ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
defer cancel()

func (t *TrillianClient) GetLatest(ctx context.Context, leafSizeInt int64) *Response {
resp, err := t.client.GetLatestSignedLogRoot(ctx,
&trillian.GetLatestSignedLogRootRequest{
LogId: t.logID,
Expand All @@ -280,11 +271,7 @@ func (t *TrillianClient) GetLatest(leafSizeInt int64) *Response {
}
}

func (t *TrillianClient) GetConsistencyProof(firstSize, lastSize int64) *Response {

ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
defer cancel()

func (t *TrillianClient) GetConsistencyProof(ctx context.Context, firstSize, lastSize int64) *Response {
resp, err := t.client.GetConsistencyProof(ctx,
&trillian.GetConsistencyProofRequest{
LogId: t.logID,
Expand All @@ -299,11 +286,8 @@ func (t *TrillianClient) GetConsistencyProof(firstSize, lastSize int64) *Respons
}
}

func (t *TrillianClient) getProofByHash(hashValue []byte) *Response {
ctx, cancel := context.WithTimeout(t.context, 20*time.Second)
defer cancel()

rootResp := t.GetLatest(0)
func (t *TrillianClient) getProofByHash(ctx context.Context, hashValue []byte) *Response {
rootResp := t.GetLatest(ctx, 0)
if rootResp.Err != nil {
return &Response{
Status: status.Code(rootResp.Err),
Expand Down
4 changes: 2 additions & 2 deletions pkg/witness/publish_checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewCheckpointPublisher(ctx context.Context,
// before publishing the latest checkpoint. If this occurs due to a sporadic failure, this simply
// means that a witness will not see a fresh checkpoint for an additional period.
func (c *CheckpointPublisher) StartPublisher(ctx context.Context) {
tc := trillianclient.NewTrillianClient(context.Background(), c.logClient, c.treeID)
tc := trillianclient.NewTrillianClient(c.logClient, c.treeID)
sTreeID := strconv.FormatInt(c.treeID, 10)

// publish on startup to ensure a checkpoint is available the first time Rekor starts up
Expand All @@ -103,7 +103,7 @@ func (c *CheckpointPublisher) StartPublisher(ctx context.Context) {
// publish publishes the latest checkpoint to Redis once
func (c *CheckpointPublisher) publish(tc *trillianclient.TrillianClient, sTreeID string) {
// get latest checkpoint
resp := tc.GetLatest(0)
resp := tc.GetLatest(context.Background(), 0)
if resp.Status != codes.OK {
c.reqCounter.With(
map[string]string{
Expand Down