Skip to content

Commit 15329fc

Browse files
committed
fix bugs
1 parent 6207348 commit 15329fc

File tree

5 files changed

+139
-110
lines changed

5 files changed

+139
-110
lines changed

cmd/thv-operator/pkg/controllerutil/tokenexchange.go

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package controllerutil
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76

87
corev1 "k8s.io/api/core/v1"
@@ -12,7 +11,6 @@ import (
1211
mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1"
1312
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
1413
"github.com/stacklok/toolhive/pkg/runner"
15-
transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
1614
)
1715

1816
// GenerateOpenTelemetryEnvVars generates OpenTelemetry environment variables
@@ -93,7 +91,8 @@ func GenerateTokenExchangeEnvVars(
9391
}
9492

9593
// AddExternalAuthConfigOptions adds external authentication configuration options to builder options
96-
// This creates middleware configuration for token exchange and is shared between MCPServer and MCPRemoteProxy
94+
// This creates token exchange configuration which will be automatically converted to middleware by
95+
// PopulateMiddlewareConfigs() when the runner starts. This ensures correct middleware ordering.
9796
func AddExternalAuthConfigOptions(
9897
ctx context.Context,
9998
c client.Client,
@@ -138,55 +137,27 @@ func AddExternalAuthConfigOptions(
138137
}
139138
}
140139

141-
// Use scopes array directly from spec
142-
scopes := tokenExchangeSpec.Scopes
143-
144140
// Determine header strategy based on ExternalTokenHeaderName
145141
headerStrategy := "replace" // Default strategy
146142
if tokenExchangeSpec.ExternalTokenHeaderName != "" {
147143
headerStrategy = "custom"
148144
}
149145

150-
// Build token exchange middleware configuration
146+
// Build token exchange configuration
151147
// Client secret is provided via TOOLHIVE_TOKEN_EXCHANGE_CLIENT_SECRET environment variable
152148
// to avoid embedding plaintext secrets in the ConfigMap
153-
tokenExchangeConfig := map[string]interface{}{
154-
"token_url": tokenExchangeSpec.TokenURL,
155-
"client_id": tokenExchangeSpec.ClientID,
156-
"audience": tokenExchangeSpec.Audience,
157-
}
158-
159-
if len(scopes) > 0 {
160-
tokenExchangeConfig["scopes"] = scopes
161-
}
162-
163-
if headerStrategy != "" {
164-
tokenExchangeConfig["header_strategy"] = headerStrategy
165-
}
166-
167-
if tokenExchangeSpec.ExternalTokenHeaderName != "" {
168-
tokenExchangeConfig["external_token_header_name"] = tokenExchangeSpec.ExternalTokenHeaderName
169-
}
170-
171-
// Create middleware parameters
172-
middlewareParams := map[string]interface{}{
173-
"token_exchange_config": tokenExchangeConfig,
174-
}
175-
176-
// Marshal parameters to JSON
177-
paramsJSON, err := json.Marshal(middlewareParams)
178-
if err != nil {
179-
return fmt.Errorf("failed to marshal token exchange middleware parameters: %w", err)
180-
}
181-
182-
// Create middleware config
183-
middlewareConfig := transporttypes.MiddlewareConfig{
184-
Type: tokenexchange.MiddlewareType,
185-
Parameters: json.RawMessage(paramsJSON),
186-
}
187-
188-
// Use WithAppendMiddlewareConfig to append instead of replacing existing middlewares
189-
*options = append(*options, runner.WithAppendMiddlewareConfig(middlewareConfig))
149+
tokenExchangeConfig := &tokenexchange.Config{
150+
TokenURL: tokenExchangeSpec.TokenURL,
151+
ClientID: tokenExchangeSpec.ClientID,
152+
Audience: tokenExchangeSpec.Audience,
153+
Scopes: tokenExchangeSpec.Scopes,
154+
HeaderStrategy: headerStrategy,
155+
ExternalTokenHeaderName: tokenExchangeSpec.ExternalTokenHeaderName,
156+
}
157+
158+
// Use WithTokenExchangeConfig to add configuration
159+
// The middleware will be automatically created by PopulateMiddlewareConfigs() in the correct order
160+
*options = append(*options, runner.WithTokenExchangeConfig(tokenExchangeConfig))
190161

191162
return nil
192163
}

pkg/runner/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/stacklok/toolhive/pkg/auth"
1212
authoauth "github.com/stacklok/toolhive/pkg/auth/oauth"
1313
"github.com/stacklok/toolhive/pkg/auth/remote"
14+
"github.com/stacklok/toolhive/pkg/auth/tokenexchange"
1415
"github.com/stacklok/toolhive/pkg/authz"
1516
"github.com/stacklok/toolhive/pkg/container"
1617
rt "github.com/stacklok/toolhive/pkg/container/runtime"
@@ -100,6 +101,9 @@ type RunConfig struct {
100101
// OIDCConfig contains OIDC configuration
101102
OIDCConfig *auth.TokenValidatorConfig `json:"oidc_config,omitempty" yaml:"oidc_config,omitempty"`
102103

104+
// TokenExchangeConfig contains token exchange configuration for external authentication
105+
TokenExchangeConfig *tokenexchange.Config `json:"token_exchange_config,omitempty" yaml:"token_exchange_config,omitempty"`
106+
103107
// AuthzConfig contains the authorization configuration
104108
AuthzConfig *authz.Config `json:"authz_config,omitempty" yaml:"authz_config,omitempty"`
105109

pkg/runner/config_builder.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,6 @@ func WithMiddlewareConfig(middlewareConfig []types.MiddlewareConfig) RunConfigBu
107107
}
108108
}
109109

110-
// WithAppendMiddlewareConfig appends a middleware configuration to the existing list
111-
func WithAppendMiddlewareConfig(middlewareConfig types.MiddlewareConfig) RunConfigBuilderOption {
112-
return func(b *runConfigBuilder) error {
113-
b.config.MiddlewareConfigs = append(b.config.MiddlewareConfigs, middlewareConfig)
114-
return nil
115-
}
116-
}
117-
118110
// WithCmdArgs sets the command arguments
119111
func WithCmdArgs(args []string) RunConfigBuilderOption {
120112
return func(b *runConfigBuilder) error {
@@ -365,6 +357,14 @@ func WithOIDCConfig(
365357
}
366358
}
367359

360+
// WithTokenExchangeConfig sets the token exchange configuration
361+
func WithTokenExchangeConfig(config *tokenexchange.Config) RunConfigBuilderOption {
362+
return func(b *runConfigBuilder) error {
363+
b.config.TokenExchangeConfig = config
364+
return nil
365+
}
366+
}
367+
368368
// WithTelemetryConfig configures telemetry settings (legacy - custom attributes handled via middleware)
369369
func WithTelemetryConfig(
370370
otelEndpoint string,

pkg/runner/middleware.go

Lines changed: 107 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -29,57 +29,68 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory {
2929
}
3030
}
3131

32-
// hasMiddlewareType checks if a middleware of the given type already exists
33-
func hasMiddlewareType(middlewares []types.MiddlewareConfig, middlewareType string) bool {
34-
for _, m := range middlewares {
35-
if m.Type == middlewareType {
36-
return true
37-
}
38-
}
39-
return false
40-
}
41-
4232
// PopulateMiddlewareConfigs populates the MiddlewareConfigs slice based on the RunConfig settings
4333
// This function serves as a bridge between the old configuration style and the new generic middleware system
44-
// It appends default middlewares to any existing middlewares, avoiding duplicates
34+
//
35+
//nolint:gocyclo // Function complexity is acceptable for middleware configuration
4536
func PopulateMiddlewareConfigs(config *RunConfig) error {
46-
// Start with existing middlewares (may already contain operator-provided middlewares)
47-
middlewareConfigs := config.MiddlewareConfigs
48-
var err error
37+
var middlewareConfigs []types.MiddlewareConfig
4938
// TODO: Consider extracting other middleware setup into helper functions like addUsageMetricsMiddleware
5039

51-
// Authentication middleware (add if not already present)
52-
if !hasMiddlewareType(middlewareConfigs, auth.MiddlewareType) {
53-
authParams := auth.MiddlewareParams{
54-
OIDCConfig: config.OIDCConfig,
55-
}
56-
authConfig, err := types.NewMiddlewareConfig(auth.MiddlewareType, authParams)
57-
if err != nil {
58-
return fmt.Errorf("failed to create auth middleware config: %w", err)
59-
}
60-
middlewareConfigs = append(middlewareConfigs, *authConfig)
40+
// Authentication middleware (always present)
41+
authParams := auth.MiddlewareParams{
42+
OIDCConfig: config.OIDCConfig,
43+
}
44+
authConfig, err := types.NewMiddlewareConfig(auth.MiddlewareType, authParams)
45+
if err != nil {
46+
return fmt.Errorf("failed to create auth middleware config: %w", err)
6147
}
48+
middlewareConfigs = append(middlewareConfigs, *authConfig)
6249

63-
// Tools filter and override middleware (add if enabled and not already present)
64-
hasToolFiltering := len(config.ToolsFilter) > 0 || len(config.ToolsOverride) > 0
65-
if hasToolFiltering && !hasMiddlewareType(middlewareConfigs, mcp.ToolFilterMiddlewareType) {
66-
middlewareConfigs = addToolFilterMiddlewares(
67-
middlewareConfigs,
68-
config.ToolsFilter,
69-
config.ToolsOverride,
70-
)
50+
// Token exchange middleware (if configured)
51+
middlewareConfigs, err = addTokenExchangeMiddleware(middlewareConfigs, config.TokenExchangeConfig)
52+
if err != nil {
53+
return err
7154
}
7255

73-
// MCP Parser middleware (add if not already present)
74-
if !hasMiddlewareType(middlewareConfigs, mcp.ParserMiddlewareType) {
75-
mcpParserParams := mcp.ParserMiddlewareParams{}
76-
mcpParserConfig, err := types.NewMiddlewareConfig(mcp.ParserMiddlewareType, mcpParserParams)
56+
// Tools filter and override middleware (if enabled)
57+
if len(config.ToolsFilter) > 0 || len(config.ToolsOverride) > 0 {
58+
// Prepare overrides map (convert runner.ToolOverride -> mcp.ToolOverride)
59+
overrides := make(map[string]mcp.ToolOverride)
60+
for actualName, tool := range config.ToolsOverride {
61+
overrides[actualName] = mcp.ToolOverride{
62+
Name: tool.Name,
63+
Description: tool.Description,
64+
}
65+
}
66+
67+
// Add tool filter middleware with both filter and overrides
68+
toolFilterParams := mcp.ToolFilterMiddlewareParams{
69+
FilterTools: config.ToolsFilter,
70+
ToolsOverride: overrides,
71+
}
72+
toolFilterConfig, err := types.NewMiddlewareConfig(mcp.ToolFilterMiddlewareType, toolFilterParams)
73+
if err != nil {
74+
return fmt.Errorf("failed to create tool filter middleware config: %w", err)
75+
}
76+
middlewareConfigs = append(middlewareConfigs, *toolFilterConfig)
77+
78+
// Add tool call filter middleware with same params
79+
toolCallFilterConfig, err := types.NewMiddlewareConfig(mcp.ToolCallFilterMiddlewareType, toolFilterParams)
7780
if err != nil {
78-
return fmt.Errorf("failed to create MCP parser middleware config: %w", err)
81+
return fmt.Errorf("failed to create tool call filter middleware config: %w", err)
7982
}
80-
middlewareConfigs = append(middlewareConfigs, *mcpParserConfig)
83+
middlewareConfigs = append(middlewareConfigs, *toolCallFilterConfig)
8184
}
8285

86+
// MCP Parser middleware (always present)
87+
mcpParserParams := mcp.ParserMiddlewareParams{}
88+
mcpParserConfig, err := types.NewMiddlewareConfig(mcp.ParserMiddlewareType, mcpParserParams)
89+
if err != nil {
90+
return fmt.Errorf("failed to create MCP parser middleware config: %w", err)
91+
}
92+
middlewareConfigs = append(middlewareConfigs, *mcpParserConfig)
93+
8394
// Load application config for global settings
8495
configProvider := cfg.NewDefaultProvider()
8596
appConfig := configProvider.GetConfig()
@@ -91,31 +102,74 @@ func PopulateMiddlewareConfigs(config *RunConfig) error {
91102
}
92103

93104
// Telemetry middleware (if enabled)
94-
middlewareConfigs = addTelemetryMiddleware(
95-
middlewareConfigs,
96-
config.TelemetryConfig,
97-
config.Name,
98-
config.Transport.String(),
99-
)
105+
if config.TelemetryConfig != nil {
106+
telemetryParams := telemetry.FactoryMiddlewareParams{
107+
Config: config.TelemetryConfig,
108+
ServerName: config.Name,
109+
Transport: config.Transport.String(),
110+
}
111+
telemetryConfig, err := types.NewMiddlewareConfig(telemetry.MiddlewareType, telemetryParams)
112+
if err != nil {
113+
return fmt.Errorf("failed to create telemetry middleware config: %w", err)
114+
}
115+
middlewareConfigs = append(middlewareConfigs, *telemetryConfig)
116+
}
100117

101118
// Authorization middleware (if enabled)
102-
middlewareConfigs = addAuthzMiddleware(middlewareConfigs, config.AuthzConfigPath)
119+
if config.AuthzConfig != nil {
120+
authzParams := authz.FactoryMiddlewareParams{
121+
ConfigPath: config.AuthzConfigPath, // Keep for backwards compatibility
122+
ConfigData: config.AuthzConfig, // Use the loaded config data
123+
}
124+
authzConfig, err := types.NewMiddlewareConfig(authz.MiddlewareType, authzParams)
125+
if err != nil {
126+
return fmt.Errorf("failed to create authorization middleware config: %w", err)
127+
}
128+
middlewareConfigs = append(middlewareConfigs, *authzConfig)
129+
}
103130

104131
// Audit middleware (if enabled)
105-
enableAudit := config.AuditConfig != nil
106-
middlewareConfigs = addAuditMiddleware(
107-
middlewareConfigs,
108-
enableAudit,
109-
config.AuditConfigPath,
110-
config.Name,
111-
config.Transport.String(),
112-
)
132+
if config.AuditConfig != nil {
133+
auditParams := audit.MiddlewareParams{
134+
ConfigPath: config.AuditConfigPath, // Keep for backwards compatibility
135+
ConfigData: config.AuditConfig, // Use the loaded config data
136+
Component: config.AuditConfig.Component,
137+
TransportType: config.Transport.String(), // Pass the actual transport type
138+
}
139+
auditConfig, err := types.NewMiddlewareConfig(audit.MiddlewareType, auditParams)
140+
if err != nil {
141+
return fmt.Errorf("failed to create audit middleware config: %w", err)
142+
}
143+
middlewareConfigs = append(middlewareConfigs, *auditConfig)
144+
}
113145

114146
// Set the populated middleware configs
115147
config.MiddlewareConfigs = middlewareConfigs
116148
return nil
117149
}
118150

151+
// addTokenExchangeMiddleware adds token exchange middleware if configured
152+
func addTokenExchangeMiddleware(
153+
middlewares []types.MiddlewareConfig,
154+
tokenExchangeConfig *tokenexchange.Config,
155+
) ([]types.MiddlewareConfig, error) {
156+
if tokenExchangeConfig == nil {
157+
return middlewares, nil
158+
}
159+
160+
tokenExchangeParams := tokenexchange.MiddlewareParams{
161+
TokenExchangeConfig: tokenExchangeConfig,
162+
}
163+
tokenExchangeMwConfig, err := types.NewMiddlewareConfig(
164+
tokenexchange.MiddlewareType,
165+
tokenExchangeParams,
166+
)
167+
if err != nil {
168+
return nil, fmt.Errorf("failed to create token exchange middleware config: %w", err)
169+
}
170+
return append(middlewares, *tokenExchangeMwConfig), nil
171+
}
172+
119173
// addUsageMetricsMiddleware adds usage metrics middleware if enabled
120174
func addUsageMetricsMiddleware(middlewares []types.MiddlewareConfig, configDisabled bool) ([]types.MiddlewareConfig, error) {
121175
if !usagemetrics.ShouldEnableMetrics(configDisabled) {

pkg/runner/runner.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ func (c *RunConfig) GetPort() int {
119119
//
120120
//nolint:gocyclo // This function is complex but manageable
121121
func (r *Runner) Run(ctx context.Context) error {
122-
// Populate default middlewares from old config fields
123-
// This now runs ALWAYS and appends to existing middlewares (avoiding duplicates)
124-
// This allows both operator-provided middlewares and default middlewares to coexist
125-
if err := PopulateMiddlewareConfigs(r.Config); err != nil {
126-
return fmt.Errorf("failed to populate middleware configs: %v", err)
122+
// Populate default middlewares from old config fields if not already populated
123+
if len(r.Config.MiddlewareConfigs) == 0 {
124+
if err := PopulateMiddlewareConfigs(r.Config); err != nil {
125+
return fmt.Errorf("failed to populate middleware configs: %v", err)
126+
}
127127
}
128128

129129
// Create transport with runtime

0 commit comments

Comments
 (0)