Skip to content

Commit bac985d

Browse files
committed
Refactor common bridge protocol code for reuse
- Move common bridge protocol definitions to subpackage under internal/gcs - Move helper functions to internal/bridgeutils pkg so that they can be used by gcs-sidecar as well Signed-off-by: Kirtana Ashok <[email protected]>
1 parent eec851c commit bac985d

File tree

23 files changed

+364
-369
lines changed

23 files changed

+364
-369
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package commonutils
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"io"
7+
"math"
8+
"strconv"
9+
10+
"github.com/Microsoft/hcsshim/internal/bridgeutils/gcserr"
11+
"github.com/sirupsen/logrus"
12+
)
13+
14+
type ErrorRecord struct {
15+
Result int32 // HResult
16+
Message string
17+
StackTrace string `json:",omitempty"`
18+
ModuleName string
19+
FileName string
20+
Line uint32
21+
FunctionName string `json:",omitempty"`
22+
}
23+
24+
// UnmarshalJSONWithHresult unmarshals the given data into the given interface, and
25+
// wraps any error returned in an HRESULT error.
26+
func UnmarshalJSONWithHresult(data []byte, v interface{}) error {
27+
if err := json.Unmarshal(data, v); err != nil {
28+
return gcserr.WrapHresult(err, gcserr.HrVmcomputeInvalidJSON)
29+
}
30+
return nil
31+
}
32+
33+
// DecodeJSONWithHresult decodes the JSON from the given reader into the given
34+
// interface, and wraps any error returned in an HRESULT error.
35+
func DecodeJSONWithHresult(r io.Reader, v interface{}) error {
36+
if err := json.NewDecoder(r).Decode(v); err != nil {
37+
return gcserr.WrapHresult(err, gcserr.HrVmcomputeInvalidJSON)
38+
}
39+
return nil
40+
}
41+
42+
func SetErrorForResponseBaseUtil(errForResponse error, moduleName string) (hresult gcserr.Hresult, errorMessage string, newRecord ErrorRecord) {
43+
errorMessage = errForResponse.Error()
44+
stackString := ""
45+
fileName := ""
46+
// We use -1 as a sentinel if no line number found (or it cannot be parsed),
47+
// but that will ultimately end up as [math.MaxUint32], so set it to that explicitly.
48+
// (Still keep using -1 for backwards compatibility ...)
49+
lineNumber := uint32(math.MaxUint32)
50+
functionName := ""
51+
if stack := gcserr.BaseStackTrace(errForResponse); stack != nil {
52+
bottomFrame := stack[0]
53+
stackString = fmt.Sprintf("%+v", stack)
54+
fileName = fmt.Sprintf("%s", bottomFrame)
55+
lineNumberStr := fmt.Sprintf("%d", bottomFrame)
56+
if n, err := strconv.ParseUint(lineNumberStr, 10, 32); err == nil {
57+
lineNumber = uint32(n)
58+
} else {
59+
logrus.WithFields(logrus.Fields{
60+
"line-number": lineNumberStr,
61+
logrus.ErrorKey: err,
62+
}).Error("opengcs::bridge::setErrorForResponseBase - failed to parse line number, using -1 instead")
63+
}
64+
functionName = fmt.Sprintf("%n", bottomFrame)
65+
}
66+
hresult, err := gcserr.GetHresult(errForResponse)
67+
if err != nil {
68+
// Default to using the generic failure HRESULT.
69+
hresult = gcserr.HrFail
70+
}
71+
72+
newRecord = ErrorRecord{
73+
Result: int32(hresult),
74+
Message: errorMessage,
75+
StackTrace: stackString,
76+
ModuleName: moduleName,
77+
FileName: fileName,
78+
Line: lineNumber,
79+
FunctionName: functionName,
80+
}
81+
82+
return hresult, errorMessage, newRecord
83+
}
File renamed without changes.

internal/gcs/bridge.go

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"go.opencensus.io/trace"
2020
"golang.org/x/sys/windows"
2121

22+
"github.com/Microsoft/hcsshim/internal/gcs/prot"
2223
"github.com/Microsoft/hcsshim/internal/log"
2324
"github.com/Microsoft/hcsshim/internal/oc"
2425
)
@@ -38,16 +39,16 @@ const (
3839
)
3940

4041
type requestMessage interface {
41-
Base() *requestBase
42+
Base() *prot.RequestBase
4243
}
4344

4445
type responseMessage interface {
45-
Base() *responseBase
46+
Base() *prot.ResponseBase
4647
}
4748

4849
// rpc represents an outstanding rpc request to the guest
4950
type rpc struct {
50-
proc rpcProc
51+
proc prot.RpcProc
5152
id int64
5253
req requestMessage
5354
resp responseMessage
@@ -80,7 +81,7 @@ const (
8081
bridgeFailureTimeout = time.Minute * 5
8182
)
8283

83-
type notifyFunc func(*containerNotification) error
84+
type notifyFunc func(*prot.ContainerNotification) error
8485

8586
// newBridge returns a bridge on `conn`. It calls `notify` when a
8687
// notification message arrives from the guest. It logs transport errors and
@@ -143,7 +144,7 @@ func (brdg *bridge) Wait() error {
143144
// AsyncRPC sends an RPC request to the guest but does not wait for a response.
144145
// If the message cannot be sent before the context is done, then an error is
145146
// returned.
146-
func (brdg *bridge) AsyncRPC(ctx context.Context, proc rpcProc, req requestMessage, resp responseMessage) (*rpc, error) {
147+
func (brdg *bridge) AsyncRPC(ctx context.Context, proc prot.RpcProc, req requestMessage, resp responseMessage) (*rpc, error) {
147148
call := &rpc{
148149
ch: make(chan struct{}),
149150
proc: proc,
@@ -224,7 +225,7 @@ func (call *rpc) Wait() {
224225
// If allowCancel is set and the context becomes done, returns an error without
225226
// waiting for a response. Avoid this on messages that are not idempotent or
226227
// otherwise safe to ignore the response of.
227-
func (brdg *bridge) RPC(ctx context.Context, proc rpcProc, req requestMessage, resp responseMessage, allowCancel bool) error {
228+
func (brdg *bridge) RPC(ctx context.Context, proc prot.RpcProc, req requestMessage, resp responseMessage, allowCancel bool) error {
228229
call, err := brdg.AsyncRPC(ctx, proc, req, resp)
229230
if err != nil {
230231
return err
@@ -261,7 +262,7 @@ func (brdg *bridge) recvLoopRoutine() {
261262
}
262263
}
263264

264-
func readMessage(r io.Reader) (int64, msgType, []byte, error) {
265+
func readMessage(r io.Reader) (int64, prot.MsgType, []byte, error) {
265266
_, span := oc.StartSpan(context.Background(), "bridge receive read message", oc.WithClientSpanKind)
266267
defer span.End()
267268

@@ -270,7 +271,7 @@ func readMessage(r io.Reader) (int64, msgType, []byte, error) {
270271
if err != nil {
271272
return 0, 0, nil, fmt.Errorf("header read: %w", err)
272273
}
273-
typ := msgType(binary.LittleEndian.Uint32(h[hdrOffType:]))
274+
typ := prot.MsgType(binary.LittleEndian.Uint32(h[hdrOffType:]))
274275
n := binary.LittleEndian.Uint32(h[hdrOffSize:])
275276
id := int64(binary.LittleEndian.Uint64(h[hdrOffID:]))
276277
span.AddAttributes(
@@ -311,8 +312,8 @@ func (brdg *bridge) recvLoop() error {
311312
"type": typ.String(),
312313
"message-id": id}).Trace("bridge receive")
313314

314-
switch typ & msgTypeMask {
315-
case msgTypeResponse:
315+
switch typ & prot.MsgTypeMask {
316+
case prot.MsgTypeResponse:
316317
// Find the request associated with this response.
317318
brdg.mu.Lock()
318319
call := brdg.rpcs[id]
@@ -344,11 +345,11 @@ func (brdg *bridge) recvLoop() error {
344345
return err
345346
}
346347

347-
case msgTypeNotify:
348-
if typ != notifyContainer|msgTypeNotify {
348+
case prot.MsgTypeNotify:
349+
if typ != prot.NotifyContainer|prot.MsgTypeNotify {
349350
return fmt.Errorf("bridge received unknown unknown notification message %s", typ)
350351
}
351-
var ntf containerNotification
352+
var ntf prot.ContainerNotification
352353
ntf.ResultInfo.Value = &json.RawMessage{}
353354
err := json.Unmarshal(b, &ntf)
354355
if err != nil {
@@ -383,7 +384,7 @@ func (brdg *bridge) sendLoop() {
383384
}
384385
}
385386

386-
func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ msgType, id int64, req interface{}) error {
387+
func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ prot.MsgType, id int64, req interface{}) error {
387388
var err error
388389
_, span := oc.StartSpan(context.Background(), "bridge send", oc.WithClientSpanKind)
389390
defer span.End()
@@ -408,9 +409,9 @@ func (brdg *bridge) writeMessage(buf *bytes.Buffer, enc *json.Encoder, typ msgTy
408409
b := buf.Bytes()[hdrSize:]
409410
switch typ {
410411
// container environment vars are in rpCreate for linux; rpcExecuteProcess for windows
411-
case msgType(rpcCreate) | msgTypeRequest:
412+
case prot.MsgType(prot.RpcCreate) | prot.MsgTypeRequest:
412413
b, err = log.ScrubBridgeCreate(b)
413-
case msgType(rpcExecuteProcess) | msgTypeRequest:
414+
case prot.MsgType(prot.RpcExecuteProcess) | prot.MsgTypeRequest:
414415
b, err = log.ScrubBridgeExecProcess(b)
415416
}
416417
if err != nil {
@@ -443,7 +444,7 @@ func (brdg *bridge) sendRPC(buf *bytes.Buffer, enc *json.Encoder, call *rpc) err
443444
brdg.rpcs[id] = call
444445
brdg.nextID++
445446
brdg.mu.Unlock()
446-
typ := msgType(call.proc) | msgTypeRequest
447+
typ := prot.MsgType(call.proc) | prot.MsgTypeRequest
447448
err := brdg.writeMessage(buf, enc, typ, id, call.req)
448449
if err != nil {
449450
// Try to reclaim this request and fail it.

internal/gcs/bridge_test.go

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414
"time"
1515

16+
"github.com/Microsoft/hcsshim/internal/gcs/prot"
1617
"github.com/sirupsen/logrus"
1718
)
1819

@@ -33,7 +34,7 @@ func pipeConn() (*stitched, *stitched) {
3334
return &stitched{r1, w2}, &stitched{r2, w1}
3435
}
3536

36-
func sendMessage(t *testing.T, w io.Writer, typ msgType, id int64, msg []byte) {
37+
func sendMessage(t *testing.T, w io.Writer, typ prot.MsgType, id int64, msg []byte) {
3738
t.Helper()
3839
var h [16]byte
3940
binary.LittleEndian.PutUint32(h[:], uint32(typ))
@@ -63,18 +64,18 @@ func reflector(t *testing.T, rw io.ReadWriteCloser, delay time.Duration) {
6364
return
6465
}
6566
time.Sleep(delay) // delay is used to test timeouts (when non-zero)
66-
typ ^= msgTypeResponse ^ msgTypeRequest
67+
typ ^= prot.MsgTypeResponse ^ prot.MsgTypeRequest
6768
sendMessage(t, rw, typ, id, msg)
6869
}
6970
}
7071

7172
type testReq struct {
72-
requestBase
73+
prot.RequestBase
7374
X, Y int
7475
}
7576

7677
type testResp struct {
77-
responseBase
78+
prot.ResponseBase
7879
X, Y int
7980
}
8081

@@ -92,7 +93,7 @@ func TestBridgeRPC(t *testing.T) {
9293
defer b.Close()
9394
req := testReq{X: 5}
9495
var resp testResp
95-
err := b.RPC(context.Background(), rpcCreate, &req, &resp, false)
96+
err := b.RPC(context.Background(), prot.RpcCreate, &req, &resp, false)
9697
if err != nil {
9798
t.Fatal(err)
9899
}
@@ -107,7 +108,7 @@ func TestBridgeRPCResponseTimeout(t *testing.T) {
107108
b.Timeout = time.Millisecond * 100
108109
req := testReq{X: 5}
109110
var resp testResp
110-
err := b.RPC(context.Background(), rpcCreate, &req, &resp, false)
111+
err := b.RPC(context.Background(), prot.RpcCreate, &req, &resp, false)
111112
if err == nil || !strings.Contains(err.Error(), "bridge closed") {
112113
t.Fatalf("expected bridge disconnection, got %s", err)
113114
}
@@ -121,7 +122,7 @@ func TestBridgeRPCContextDone(t *testing.T) {
121122
defer cancel()
122123
req := testReq{X: 5}
123124
var resp testResp
124-
err := b.RPC(ctx, rpcCreate, &req, &resp, true)
125+
err := b.RPC(ctx, prot.RpcCreate, &req, &resp, true)
125126
if err != context.DeadlineExceeded { //nolint:errorlint
126127
t.Fatalf("expected deadline exceeded, got %s", err)
127128
}
@@ -135,7 +136,7 @@ func TestBridgeRPCContextDoneNoCancel(t *testing.T) {
135136
defer cancel()
136137
req := testReq{X: 5}
137138
var resp testResp
138-
err := b.RPC(ctx, rpcCreate, &req, &resp, false)
139+
err := b.RPC(ctx, prot.RpcCreate, &req, &resp, false)
139140
if err == nil || !strings.Contains(err.Error(), "bridge closed") {
140141
t.Fatalf("expected bridge disconnection, got %s", err)
141142
}
@@ -145,13 +146,13 @@ func TestBridgeRPCBridgeClosed(t *testing.T) {
145146
b := startReflectedBridge(t, 0)
146147
eerr := errors.New("forcibly terminated")
147148
b.kill(eerr)
148-
err := b.RPC(context.Background(), rpcCreate, nil, nil, false)
149+
err := b.RPC(context.Background(), prot.RpcCreate, nil, nil, false)
149150
if err != eerr { //nolint:errorlint
150151
t.Fatal("unexpected: ", err)
151152
}
152153
}
153154

154-
func sendJSON(t *testing.T, w io.Writer, typ msgType, id int64, msg interface{}) error {
155+
func sendJSON(t *testing.T, w io.Writer, typ prot.MsgType, id int64, msg interface{}) error {
155156
t.Helper()
156157
msgb, err := json.Marshal(msg)
157158
if err != nil {
@@ -161,7 +162,7 @@ func sendJSON(t *testing.T, w io.Writer, typ msgType, id int64, msg interface{})
161162
return nil
162163
}
163164

164-
func notifyThroughBridge(t *testing.T, typ msgType, msg interface{}, fn notifyFunc) error {
165+
func notifyThroughBridge(t *testing.T, typ prot.MsgType, msg interface{}, fn notifyFunc) error {
165166
t.Helper()
166167
s, c := pipeConn()
167168
b := newBridge(s, fn, logrus.NewEntry(logrus.StandardLogger()))
@@ -176,9 +177,9 @@ func notifyThroughBridge(t *testing.T, typ msgType, msg interface{}, fn notifyFu
176177
}
177178

178179
func TestBridgeNotify(t *testing.T) {
179-
ntf := &containerNotification{Operation: "testing"}
180+
ntf := &prot.ContainerNotification{Operation: "testing"}
180181
recvd := false
181-
err := notifyThroughBridge(t, msgTypeNotify|notifyContainer, ntf, func(nntf *containerNotification) error {
182+
err := notifyThroughBridge(t, prot.MsgTypeNotify|prot.NotifyContainer, ntf, func(nntf *prot.ContainerNotification) error {
182183
if !reflect.DeepEqual(ntf, nntf) {
183184
t.Errorf("%+v != %+v", ntf, nntf)
184185
}
@@ -194,9 +195,9 @@ func TestBridgeNotify(t *testing.T) {
194195
}
195196

196197
func TestBridgeNotifyFailure(t *testing.T) {
197-
ntf := &containerNotification{Operation: "testing"}
198+
ntf := &prot.ContainerNotification{Operation: "testing"}
198199
errMsg := "notify should have failed"
199-
err := notifyThroughBridge(t, msgTypeNotify|notifyContainer, ntf, func(nntf *containerNotification) error {
200+
err := notifyThroughBridge(t, prot.MsgTypeNotify|prot.NotifyContainer, ntf, func(nntf *prot.ContainerNotification) error {
200201
return errors.New(errMsg)
201202
})
202203
if err == nil || !strings.Contains(err.Error(), errMsg) {

0 commit comments

Comments
 (0)