@@ -7,12 +7,14 @@ import (
77 "net/http"
88 "net/url"
99 "reflect"
10+ "strings"
1011 "time"
1112
1213 "github.com/danielgtaylor/huma/v2"
1314 "github.com/danielgtaylor/huma/v2/adapters/humachi"
1415 "github.com/go-chi/chi/v5"
1516 "github.com/go-chi/chi/v5/middleware"
17+ "github.com/go-chi/cors"
1618 "github.com/hashicorp/go-hclog"
1719
1820 "github.com/mozilla-ai/mcpd/v2/internal/api"
@@ -26,13 +28,17 @@ type ApiServer struct {
2628 healthTracker contracts.MCPHealthMonitor
2729 logger hclog.Logger
2830 addr string
31+ enableCORS bool
32+ corsOrigins []string
2933}
3034
3135func NewApiServer (
3236 logger hclog.Logger ,
3337 accessor contracts.MCPClientAccessor ,
3438 monitor contracts.MCPHealthMonitor ,
3539 addr string ,
40+ enableCORS bool ,
41+ corsOrigins []string ,
3642) (* ApiServer , error ) {
3743 if logger == nil || reflect .ValueOf (logger ).IsNil () {
3844 return nil , fmt .Errorf ("logger cannot be nil" )
@@ -52,13 +58,42 @@ func NewApiServer(
5258 clientManager : accessor ,
5359 healthTracker : monitor ,
5460 addr : addr ,
61+ enableCORS : enableCORS ,
62+ corsOrigins : corsOrigins ,
5563 }, nil
5664}
5765
5866func (a * ApiServer ) Start (ctx context.Context ) error {
5967 // Create router.
6068 mux := chi .NewMux ()
6169 mux .Use (middleware .StripSlashes )
70+
71+ // Add CORS middleware if enabled
72+ if a .enableCORS {
73+ a .logger .Info ("Enabling CORS" , "origins" , a .corsOrigins )
74+
75+ corsOptions := cors.Options {
76+ AllowedOrigins : a .corsOrigins ,
77+ AllowedMethods : []string {"GET" , "POST" , "PUT" , "DELETE" , "OPTIONS" },
78+ AllowedHeaders : []string {"Accept" , "Authorization" , "Content-Type" , "X-CSRF-Token" , "X-Requested-With" },
79+ ExposedHeaders : []string {"Link" },
80+ AllowCredentials : false ,
81+ MaxAge : 300 , // Maximum value not ignored by any of major browsers
82+ }
83+
84+ // Handle wildcard origins properly
85+ for i , origin := range corsOptions .AllowedOrigins {
86+ if origin == "*" {
87+ corsOptions .AllowedOrigins = []string {"*" }
88+ corsOptions .AllowCredentials = false
89+ break
90+ }
91+ corsOptions .AllowedOrigins [i ] = strings .TrimSpace (origin )
92+ }
93+
94+ mux .Use (cors .Handler (corsOptions ))
95+ }
96+
6297 config := huma .DefaultConfig ("mcpd docs" , cmd .Version ())
6398 router := humachi .New (mux , config )
6499
@@ -85,6 +120,9 @@ func (a *ApiServer) Start(ctx context.Context) error {
85120 // Start the API.
86121 go func () {
87122 a .logger .Info ("Starting API server" , "address" , a .addr , "prefix" , apiPathPrefix )
123+ if a .enableCORS {
124+ a .logger .Info ("CORS enabled" , "origins" , a .corsOrigins )
125+ }
88126 if err := srv .ListenAndServe (); err != nil && ! stdErrors .Is (err , http .ErrServerClosed ) {
89127 errCh <- err
90128 }
0 commit comments