|
1 | 1 | package gateway
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "context" |
| 5 | + "net" |
4 | 6 | "net/http"
|
| 7 | + "net/http/httptest" |
5 | 8 | "testing"
|
| 9 | + "time" |
6 | 10 |
|
7 | 11 | "github.com/stretchr/testify/require"
|
8 | 12 | "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
9 | 13 | "go.opentelemetry.io/otel"
|
10 | 14 | "go.opentelemetry.io/otel/propagation"
|
| 15 | + sdktrace "go.opentelemetry.io/otel/sdk/trace" |
| 16 | + "go.opentelemetry.io/otel/sdk/trace/tracetest" |
11 | 17 | "go.opentelemetry.io/otel/trace"
|
12 | 18 | "go.uber.org/goleak"
|
| 19 | + "google.golang.org/grpc" |
| 20 | + "google.golang.org/grpc/health" |
| 21 | + healthpb "google.golang.org/grpc/health/grpc_health_v1" |
13 | 22 |
|
14 | 23 | "github.com/authzed/spicedb/pkg/testutil"
|
15 | 24 | )
|
@@ -48,11 +57,158 @@ func TestOtelForwarding(t *testing.T) {
|
48 | 57 | func TestCloseConnections(t *testing.T) {
|
49 | 58 | defer goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), goleak.IgnoreCurrent())...)
|
50 | 59 |
|
51 |
| - gatewayHandler, err := NewHandler(t.Context(), "192.0.2.0:4321", "") |
| 60 | + gatewayHandler, err := NewHandler(t.Context(), "192.0.2.0:4321", "", false) |
52 | 61 | require.NoError(t, err)
|
53 | 62 | // 4 conns for permission+schema+watch+experimental services, 1 for health check
|
54 | 63 | require.Len(t, gatewayHandler.closers, 5)
|
55 | 64 |
|
56 | 65 | // if connections are not closed, goleak would detect it
|
57 | 66 | require.NoError(t, gatewayHandler.Close())
|
58 | 67 | }
|
| 68 | + |
| 69 | +func TestGatewayHealthCheckTracing(t *testing.T) { |
| 70 | + tests := []struct { |
| 71 | + name string |
| 72 | + disableHealthCheckTracing bool |
| 73 | + description string |
| 74 | + }{ |
| 75 | + { |
| 76 | + name: "health check tracing disabled", |
| 77 | + disableHealthCheckTracing: true, |
| 78 | + description: "gateway should skip tracing for health checks when flag is true", |
| 79 | + }, |
| 80 | + { |
| 81 | + name: "health check tracing enabled", |
| 82 | + disableHealthCheckTracing: false, |
| 83 | + description: "gateway should trace health checks when flag is false", |
| 84 | + }, |
| 85 | + } |
| 86 | + |
| 87 | + for _, tt := range tests { |
| 88 | + t.Run(tt.name, func(t *testing.T) { |
| 89 | + defer goleak.VerifyNone(t, append(testutil.GoLeakIgnores(), goleak.IgnoreCurrent())...) |
| 90 | + |
| 91 | + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) |
| 92 | + defer cancel() |
| 93 | + |
| 94 | + // Since we can't use a real gRPC server, we need to mock one |
| 95 | + mockServer := setupMockGRPCServer(t, ctx) |
| 96 | + defer mockServer.Stop() |
| 97 | + |
| 98 | + defaultProvider := otel.GetTracerProvider() |
| 99 | + defer otel.SetTracerProvider(defaultProvider) |
| 100 | + |
| 101 | + provider := sdktrace.NewTracerProvider( |
| 102 | + sdktrace.WithSampler(sdktrace.AlwaysSample()), |
| 103 | + ) |
| 104 | + spanrecorder := tracetest.NewSpanRecorder() |
| 105 | + provider.RegisterSpanProcessor(spanrecorder) |
| 106 | + otel.SetTracerProvider(provider) |
| 107 | + |
| 108 | + gatewayHandler, err := NewHandler(ctx, mockServer.Addr(), "", tt.disableHealthCheckTracing) |
| 109 | + require.NoError(t, err) |
| 110 | + defer gatewayHandler.Close() |
| 111 | + |
| 112 | + // Test health check endpoint |
| 113 | + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) |
| 114 | + w := httptest.NewRecorder() |
| 115 | + gatewayHandler.ServeHTTP(w, req) |
| 116 | + require.Equal(t, http.StatusOK, w.Code) |
| 117 | + |
| 118 | + healthSpans := spanrecorder.Ended() |
| 119 | + spanrecorder.Reset() |
| 120 | + |
| 121 | + // Test regular endpoint |
| 122 | + req = httptest.NewRequest(http.MethodGet, "/openapi.json", nil) |
| 123 | + w = httptest.NewRecorder() |
| 124 | + gatewayHandler.ServeHTTP(w, req) |
| 125 | + require.Equal(t, http.StatusOK, w.Code) |
| 126 | + |
| 127 | + regularSpans := spanrecorder.Ended() |
| 128 | + |
| 129 | + if tt.disableHealthCheckTracing { |
| 130 | + require.Len(t, healthSpans, 0, "Health check should not create spans when tracing is disabled") |
| 131 | + |
| 132 | + healthFound := false |
| 133 | + for _, span := range healthSpans { |
| 134 | + if span.Name() == "grpc.health.v1.Health/Check" { |
| 135 | + healthFound = true |
| 136 | + break |
| 137 | + } |
| 138 | + } |
| 139 | + require.False(t, healthFound, "Unexpectedly found health check span when tracing is disabled") |
| 140 | + |
| 141 | + gatewayFound := false |
| 142 | + for _, span := range healthSpans { |
| 143 | + if span.Name() == "gateway" { |
| 144 | + gatewayFound = true |
| 145 | + break |
| 146 | + } |
| 147 | + } |
| 148 | + require.False(t, gatewayFound, "Unexpectedly found gateway span for health check when tracing is disabled") |
| 149 | + } else { |
| 150 | + require.True(t, len(healthSpans) > 0, "Health check should create spans when tracing is enabled") |
| 151 | + |
| 152 | + healthFound := false |
| 153 | + for _, span := range healthSpans { |
| 154 | + if span.Name() == "grpc.health.v1.Health/Check" { |
| 155 | + healthFound = true |
| 156 | + break |
| 157 | + } |
| 158 | + } |
| 159 | + require.True(t, healthFound, "Expected to find health check span when tracing is enabled") |
| 160 | + } |
| 161 | + |
| 162 | + require.True(t, len(regularSpans) > 0, "Regular endpoints should create spans") |
| 163 | + regularFound := false |
| 164 | + for _, span := range regularSpans { |
| 165 | + if span.Name() == "gateway" { |
| 166 | + regularFound = true |
| 167 | + break |
| 168 | + } |
| 169 | + } |
| 170 | + require.True(t, regularFound, "Expected to find gateway span for regular endpoint") |
| 171 | + }) |
| 172 | + } |
| 173 | +} |
| 174 | + |
| 175 | +type mockGRPCServer struct { |
| 176 | + server *grpc.Server |
| 177 | + listener net.Listener |
| 178 | + addr string |
| 179 | +} |
| 180 | + |
| 181 | +func (m *mockGRPCServer) Addr() string { |
| 182 | + return m.addr |
| 183 | +} |
| 184 | + |
| 185 | +func (m *mockGRPCServer) Stop() { |
| 186 | + if m.server != nil { |
| 187 | + m.server.Stop() |
| 188 | + } |
| 189 | + if m.listener != nil { |
| 190 | + m.listener.Close() |
| 191 | + } |
| 192 | +} |
| 193 | + |
| 194 | +func setupMockGRPCServer(t *testing.T, ctx context.Context) *mockGRPCServer { |
| 195 | + listener, err := net.Listen("tcp", "127.0.0.1:0") |
| 196 | + require.NoError(t, err) |
| 197 | + |
| 198 | + server := grpc.NewServer() |
| 199 | + |
| 200 | + healthServer := health.NewServer() |
| 201 | + healthpb.RegisterHealthServer(server, healthServer) |
| 202 | + |
| 203 | + go func() { |
| 204 | + if err := server.Serve(listener); err != nil { |
| 205 | + t.Logf("Mock gRPC server error: %v", err) |
| 206 | + } |
| 207 | + }() |
| 208 | + |
| 209 | + return &mockGRPCServer{ |
| 210 | + server: server, |
| 211 | + listener: listener, |
| 212 | + addr: listener.Addr().String(), |
| 213 | + } |
| 214 | +} |
0 commit comments