Skip to content

Commit c68e252

Browse files
authored
refactor: new caller from interfaces & new service from its option (#883)
# Description 1. A new `Service` struct has been introduced, which is responsible for handling the main logic. The `Service` struct has two main methods: `GetChatCompletions()` and `GetInvoke()` that handle the main logic. 2. The `CallerProvider` struct has been removed. 3. The `Service` struct has an additional method called `LoadOrCreateCaller()`, which is responsible for managing the creation and caching of the `Caller` instances. 4. The `NewCaller` function, which is used to create a new `Caller` instance, has a new parameter signature: `func (source yomo.Source, reducer yomo.StreamFunction, md metadata.M, callTimeout time.Duration) *Caller`. This means that when creating a `Service` instance, you need to inject the way to create the source, reducer, and how to exchange metadata. To facilitate this, a new struct called `ServiceOption` has been defined. The `service.LoadOrCreateCaller()` method will call the `ServiceOption.SourceBuilder()`, `ServiceOption.ReducerBuilder()`, `ServiceOption.MetadataExchanger() ` and then use the returned values as the parameters for the `NewCaller` function to create the `Caller` instance. the `ServiceOption` struct: ```go // ServiceOptions is the option for creating service type ServiceOptions struct { // Logger is the logger for the service Logger *slog.Logger // Tracer is the tracer for the service Tracer trace.Tracer // CredentialFunc is the function for getting the credential from the request CredentialFunc func(r *http.Request) (string, error) // CallerCacheSize is the size of the caller's cache CallerCacheSize int // CallerCacheTTL is the time to live of the callers cache CallerCacheTTL time.Duration // CallerCallTimeout is the timeout for awaiting the function response. CallerCallTimeout time.Duration // SourceBuilder should builds an unconnected source. SourceBuilder func(zipperAddr, credential string) yomo.Source // ReducerBuilder should builds an unconnected reducer. ReducerBuilder func(zipperAddr, credential string) yomo.StreamFunction // MetadataExchanger exchanges metadata from the credential. MetadataExchanger func(credential string) (metadata.M, error) } ``` 5. Besides,`ServiceOptions` also allows you to modify the default logger, tracer, and the method to get the credential, among other configuration parameters.
1 parent f438306 commit c68e252

File tree

8 files changed

+1288
-1228
lines changed

8 files changed

+1288
-1228
lines changed

pkg/bridge/ai/api_server.go

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ type BasicAPIServer struct {
3535
zipperAddr string
3636
credential string
3737
httpHandler http.Handler
38-
logger *slog.Logger
3938
}
4039

4140
// Serve starts the Basic API Server
@@ -44,19 +43,20 @@ func Serve(config *Config, zipperListenAddr string, credential string, logger *s
4443
if err != nil {
4544
return err
4645
}
47-
srv, err := NewBasicAPIServer(config, zipperListenAddr, provider, credential, logger)
46+
srv, err := NewBasicAPIServer(config, zipperListenAddr, credential, provider, logger)
4847
if err != nil {
4948
return err
5049
}
5150

5251
logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name())
53-
return srv.ServeAddr(config.Server.Addr)
52+
return http.ListenAndServe(config.Server.Addr, srv.httpHandler)
5453
}
5554

56-
func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handler) http.Handler) http.Handler {
55+
// NewServeMux creates a new http.ServeMux for the llm bridge server.
56+
func NewServeMux(service *Service) *http.ServeMux {
5757
var (
58+
h = &Handler{service}
5859
mux = http.NewServeMux()
59-
h = NewHandler(provider)
6060
)
6161
// GET /overview
6262
mux.HandleFunc("/overview", h.HandleOverview)
@@ -65,57 +65,59 @@ func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handle
6565
// POST /v1/chat/completions (OpenAI compatible interface)
6666
mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions)
6767

68-
return decorater(mux)
68+
return mux
69+
}
70+
71+
// DecorateHandler decorates the http.Handler.
72+
func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler {
73+
// decorate the http.Handler
74+
for i := len(decorates) - 1; i >= 0; i-- {
75+
h = decorates[i](h)
76+
}
77+
return h
6978
}
7079

7180
// NewBasicAPIServer creates a new restful service
72-
func NewBasicAPIServer(config *Config, zipperAddr string, provider provider.LLMProvider, credential string, logger *slog.Logger) (*BasicAPIServer, error) {
81+
func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) {
7382
zipperAddr = parseZipperAddr(zipperAddr)
7483

75-
cp := NewCallerProvider(zipperAddr, DefaultExchangeMetadataFunc)
84+
logger = logger.With("component", "bridge")
85+
86+
service := NewService(zipperAddr, provider, &ServiceOptions{
87+
Logger: logger,
88+
Tracer: otel.Tracer("yomo-llm-bridge"),
89+
CredentialFunc: func(r *http.Request) (string, error) { return credential, nil },
90+
})
91+
92+
mux := NewServeMux(service)
7693

7794
server := &BasicAPIServer{
7895
zipperAddr: zipperAddr,
7996
credential: credential,
80-
httpHandler: BridgeHTTPHanlder(provider, decorateReqContext(cp, logger, credential)),
81-
logger: logger.With("component", "bridge"),
97+
httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)),
8298
}
8399

84100
return server, nil
85101
}
86102

87-
// ServeAddr starts a http server that provides some endpoints to bridge up the http server and YoMo.
88-
// User can chat to the http server and interact with the YoMo's stream function.
89-
func (a *BasicAPIServer) ServeAddr(addr string) error {
90-
return http.ListenAndServe(addr, a.httpHandler)
91-
}
92-
93-
// decorateReqContext decorates the context of the request, it injects a transID and a caller into the context.
94-
func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential string) func(handler http.Handler) http.Handler {
95-
tracer := otel.Tracer("yomo-llm-bridge")
96-
97-
caller, err := cp.Provide(credential)
98-
if err != nil {
99-
logger.Info("can't load caller", "err", err)
100-
101-
return func(handler http.Handler) http.Handler {
102-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
103-
w.Header().Set("Content-Type", "application/json")
104-
RespondWithError(w, http.StatusInternalServerError, err)
105-
})
106-
}
107-
}
108-
109-
caller.SetTracer(tracer)
110-
103+
// decorateReqContext decorates the context of the request, it injects a transID into the request's context,
104+
// log the request information and start tracing the request.
105+
func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler {
111106
host, _ := os.Hostname()
112107

113108
return func(handler http.Handler) http.Handler {
114109
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
115110
ctx := r.Context()
116111

112+
caller, err := service.LoadOrCreateCaller(r)
113+
if err != nil {
114+
RespondWithError(w, http.StatusBadRequest, err)
115+
return
116+
}
117+
ctx = WithCallerContext(ctx, caller)
118+
117119
// trace every request
118-
ctx, span := tracer.Start(
120+
ctx, span := service.option.Tracer.Start(
119121
ctx,
120122
r.URL.Path,
121123
trace.WithSpanKind(trace.SpanKindServer),
@@ -125,7 +127,6 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin
125127

126128
transID := id.New(32)
127129
ctx = WithTransIDContext(ctx, transID)
128-
ctx = WithCallerContext(ctx, caller)
129130

130131
logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID)
131132

@@ -136,24 +137,16 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin
136137

137138
// Handler handles the http request.
138139
type Handler struct {
139-
provider provider.LLMProvider
140-
}
141-
142-
// NewHandler returns a new Handler.
143-
func NewHandler(provider provider.LLMProvider) *Handler {
144-
return &Handler{provider}
140+
service *Service
145141
}
146142

147143
// HandleOverview is the handler for GET /overview
148144
func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {
149-
caller := FromCallerContext(r.Context())
150-
151145
w.Header().Set("Content-Type", "application/json")
152146

153-
tcs, err := register.ListToolCalls(caller.Metadata())
147+
tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata())
154148
if err != nil {
155-
w.WriteHeader(http.StatusInternalServerError)
156-
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
149+
RespondWithError(w, http.StatusInternalServerError, err)
157150
return
158151
}
159152

@@ -172,7 +165,6 @@ var baseSystemMessage = `You are a very helpful assistant. Your job is to choose
172165
func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
173166
var (
174167
ctx = r.Context()
175-
caller = FromCallerContext(ctx)
176168
transID = FromTransIDContext(ctx)
177169
)
178170
defer r.Body.Close()
@@ -185,14 +177,14 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
185177
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
186178
defer cancel()
187179

188-
res, err := GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack, caller, h.provider)
180+
w.Header().Set("Content-Type", "application/json")
181+
182+
res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack)
189183
if err != nil {
190-
w.Header().Set("Content-Type", "application/json")
191184
RespondWithError(w, http.StatusInternalServerError, err)
192185
return
193186
}
194187

195-
w.Header().Set("Content-Type", "application/json")
196188
w.WriteHeader(http.StatusOK)
197189
_ = json.NewEncoder(w).Encode(res)
198190
}
@@ -201,7 +193,6 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
201193
func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
202194
var (
203195
ctx = r.Context()
204-
caller = FromCallerContext(ctx)
205196
transID = FromTransIDContext(ctx)
206197
)
207198
defer r.Body.Close()
@@ -214,7 +205,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
214205
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
215206
defer cancel()
216207

217-
if err := GetChatCompletions(ctx, req, transID, h.provider, caller, w); err != nil {
208+
if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil {
218209
RespondWithError(w, http.StatusBadRequest, err)
219210
return
220211
}
@@ -258,17 +249,17 @@ func getLocalIP() (string, error) {
258249
type callerContextKey struct{}
259250

260251
// WithCallerContext adds the caller to the request context
261-
func WithCallerContext(ctx context.Context, caller Caller) context.Context {
252+
func WithCallerContext(ctx context.Context, caller *Caller) context.Context {
262253
return context.WithValue(ctx, callerContextKey{}, caller)
263254
}
264255

265256
// FromCallerContext returns the caller from the request context
266-
func FromCallerContext(ctx context.Context) Caller {
267-
service, ok := ctx.Value(callerContextKey{}).(Caller)
257+
func FromCallerContext(ctx context.Context) *Caller {
258+
caller, ok := ctx.Value(callerContextKey{}).(*Caller)
268259
if !ok {
269260
return nil
270261
}
271-
return service
262+
return caller
272263
}
273264

274265
type transIDContextKey struct{}

pkg/bridge/ai/api_server_test.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ import (
44
"bytes"
55
"fmt"
66
"io"
7-
"log/slog"
87
"net/http"
98
"net/http/httptest"
109
"testing"
10+
"time"
1111

1212
"github.com/stretchr/testify/assert"
13+
"github.com/yomorun/yomo"
1314
"github.com/yomorun/yomo/ai"
15+
"github.com/yomorun/yomo/core/metadata"
1416
"github.com/yomorun/yomo/pkg/bridge/ai/provider"
1517
"github.com/yomorun/yomo/pkg/bridge/ai/register"
1618
)
@@ -38,11 +40,19 @@ func TestServer(t *testing.T) {
3840
t.Fatal(err)
3941
}
4042

41-
cp := newMockCallerProvider()
43+
flow := newMockDataFlow(newHandler(2 * time.Hour).handle)
4244

43-
cp.provideFunc = mockCallerProvideFunc(map[uint32][]mockFunctionCall{})
45+
newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) {
46+
return mockCaller(nil), err
47+
}
48+
49+
service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{
50+
SourceBuilder: func(_, _ string) yomo.Source { return flow },
51+
ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow },
52+
MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil },
53+
})
4454

45-
handler := BridgeHTTPHanlder(pd, decorateReqContext(cp, slog.Default(), ""))
55+
handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger))
4656

4757
// create a test server
4858
server := httptest.NewServer(handler)

pkg/bridge/ai/call_syncer.go

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@ import (
66
"time"
77

88
openai "github.com/sashabaranov/go-openai"
9-
"github.com/yomorun/yomo"
109
"github.com/yomorun/yomo/ai"
11-
"github.com/yomorun/yomo/serverless"
1210
)
1311

1412
// CallSyncer fires a bunch of function callings, and wait the result of these function callings.
@@ -223,39 +221,3 @@ func (f *callSyncer) background() {
223221
}
224222
}
225223
}
226-
227-
// ToReducer converts a stream function to a reducer that can reduce the function calling result.
228-
func ToReducer(sfn yomo.StreamFunction, logger *slog.Logger, ch chan ReduceMessage) {
229-
// set observe data tags
230-
sfn.SetObserveDataTags(ai.ReducerTag)
231-
// set reduce handler
232-
sfn.SetHandler(func(ctx serverless.Context) {
233-
invoke, err := ctx.LLMFunctionCall()
234-
if err != nil {
235-
ch <- ReduceMessage{ReqID: ""}
236-
logger.Error("parse function calling invoke", "err", err.Error())
237-
return
238-
}
239-
logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result))
240-
241-
message := openai.ChatCompletionMessage{
242-
Role: openai.ChatMessageRoleTool,
243-
Content: invoke.Result,
244-
ToolCallID: invoke.ToolCallID,
245-
}
246-
247-
ch <- ReduceMessage{ReqID: invoke.ReqID, Message: message}
248-
})
249-
}
250-
251-
// ToSource convert a yomo source to the source that can send function calling body to the llm function.
252-
func ToSource(source yomo.Source, logger *slog.Logger, ch chan TagFunctionCall) {
253-
go func() {
254-
for c := range ch {
255-
buf, _ := c.FunctionCall.Bytes()
256-
if err := source.Write(c.Tag, buf); err != nil {
257-
logger.Error("send data to zipper", "err", err.Error())
258-
}
259-
}
260-
}()
261-
}

pkg/bridge/ai/call_syncer_test.go

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,10 @@ func TestTimeoutCallSyncer(t *testing.T) {
2727
flow := newMockDataFlow(h.handle)
2828
defer flow.Close()
2929

30-
reqs := make(chan TagFunctionCall)
31-
ToSource(flow, slog.Default(), reqs)
30+
req, _ := sourceWriteToChan(flow, slog.Default())
31+
res, _ := reduceToChan(flow, slog.Default())
3232

33-
messages := make(chan ReduceMessage)
34-
ToReducer(flow, slog.Default(), messages)
35-
36-
syncer := NewCallSyncer(slog.Default(), reqs, messages, time.Millisecond)
33+
syncer := NewCallSyncer(slog.Default(), req, res, time.Millisecond)
3734
go flow.run()
3835

3936
var (
@@ -61,13 +58,10 @@ func TestCallSyncer(t *testing.T) {
6158
flow := newMockDataFlow(h.handle)
6259
defer flow.Close()
6360

64-
reqs := make(chan TagFunctionCall)
65-
ToSource(flow, slog.Default(), reqs)
66-
67-
messages := make(chan ReduceMessage)
68-
ToReducer(flow, slog.Default(), messages)
61+
req, _ := sourceWriteToChan(flow, slog.Default())
62+
res, _ := reduceToChan(flow, slog.Default())
6963

70-
syncer := NewCallSyncer(slog.Default(), reqs, messages, 0)
64+
syncer := NewCallSyncer(slog.Default(), req, res, 0)
7165
go flow.run()
7266

7367
var (
@@ -118,7 +112,7 @@ func (h *handler) result() []openai.ChatCompletionMessage {
118112
return want
119113
}
120114

121-
// mockDataFlow mocks the data flow of ai bridge.
115+
// mockDataFlow mocks the data flow of llm bridge.
122116
// The data flow is: source -> hander -> reducer,
123117
// It is `Write() -> handler() -> reducer()` in this mock implementation.
124118
type mockDataFlow struct {
@@ -160,11 +154,11 @@ var _ yomo.StreamFunction = (*mockDataFlow)(nil)
160154

161155
// The test will not use blowing function in this mock implementation.
162156
func (t *mockDataFlow) SetObserveDataTags(tag ...uint32) {}
157+
func (t *mockDataFlow) Connect() error { return nil }
163158
func (t *mockDataFlow) Init(fn func() error) error { panic("unimplemented") }
164159
func (t *mockDataFlow) SetCronHandler(spec string, fn core.CronHandler) error { panic("unimplemented") }
165160
func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error { panic("unimplemented") }
166161
func (t *mockDataFlow) SetWantedTarget(string) { panic("unimplemented") }
167162
func (t *mockDataFlow) Wait() { panic("unimplemented") }
168-
func (t *mockDataFlow) Connect() error { panic("unimplemented") }
169163
func (t *mockDataFlow) SetErrorHandler(fn func(err error)) { panic("unimplemented") }
170164
func (t *mockDataFlow) WriteWithTarget(_ uint32, _ []byte, _ string) error { panic("unimplemented") }

0 commit comments

Comments
 (0)