Skip to content

Commit 7025c0e

Browse files
committed
fix: fix body rewrite error & Fix non chat type model mapping
1 parent 849a53f commit 7025c0e

File tree

9 files changed

+77
-46
lines changed

9 files changed

+77
-46
lines changed

plugins/wasm-go/extensions/ai-proxy/main.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,6 @@ func getOpenAiApiName(path string) provider.ApiName {
274274
if strings.HasSuffix(path, "/v1/embeddings") {
275275
return provider.ApiNameEmbeddings
276276
}
277-
if strings.HasSuffix(path, "/v1/audio/transcriptions") {
278-
return provider.ApiNameAudioTranscription
279-
}
280277
if strings.HasSuffix(path, "/v1/audio/speech") {
281278
return provider.ApiNameAudioSpeech
282279
}

plugins/wasm-go/extensions/ai-proxy/provider/claude.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ func (c *claudeProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName,
139139
}
140140

141141
func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
142+
if apiName != ApiNameChatCompletion {
143+
return c.config.defaultTransformRequestBody(ctx, apiName, body, log)
144+
}
142145
request := &chatCompletionRequest{}
143146
if err := c.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
144147
return nil, err
@@ -148,6 +151,9 @@ func (c *claudeProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName A
148151
}
149152

150153
func (c *claudeProvider) TransformResponseBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
154+
if apiName != ApiNameChatCompletion {
155+
return body, nil
156+
}
151157
claudeResponse := &claudeTextGenResponse{}
152158
if err := json.Unmarshal(body, claudeResponse); err != nil {
153159
return nil, fmt.Errorf("unable to unmarshal claude response: %v", err)

plugins/wasm-go/extensions/ai-proxy/provider/cohere.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiNam
107107
}
108108

109109
func (m *cohereProvider) TransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
110+
if apiName != ApiNameChatCompletion {
111+
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
112+
}
110113
request := &chatCompletionRequest{}
111114
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {
112115
return nil, err

plugins/wasm-go/extensions/ai-proxy/provider/dify.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ import (
44
"encoding/json"
55
"errors"
66
"fmt"
7+
"net/http"
8+
"strings"
9+
"time"
10+
711
"github.com/alibaba/higress/plugins/wasm-go/extensions/ai-proxy/util"
812
"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
913
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
1014
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
11-
"net/http"
12-
"strings"
13-
"time"
1415
)
1516

1617
const (
@@ -83,6 +84,9 @@ func (d *difyProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName, b
8384
}
8485

8586
func (d *difyProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
87+
if apiName != ApiNameChatCompletion {
88+
return d.config.defaultTransformRequestBody(ctx, apiName, body, log)
89+
}
8690
request := &chatCompletionRequest{}
8791
err := d.config.parseRequestAndMapModel(ctx, request, body, log)
8892
if err != nil {

plugins/wasm-go/extensions/ai-proxy/provider/hunyuan.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ func (m *hunyuanProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiName
294294
// hunyuan 的 TransformRequestBodyHeaders 方法只在 failover 健康检查的时候会调用
295295
func (m *hunyuanProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
296296
if m.useOpenAICompatibleAPI() {
297-
return body, nil
297+
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
298298
}
299299
request := &chatCompletionRequest{}
300300
err := m.config.parseRequestAndMapModel(ctx, request, body, log)

plugins/wasm-go/extensions/ai-proxy/provider/model.go

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,21 @@ const (
1919
)
2020

2121
type chatCompletionRequest struct {
22-
Model string `json:"model"`
23-
Messages []chatMessage `json:"messages"`
24-
MaxTokens int `json:"max_tokens,omitempty"`
25-
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
26-
N int `json:"n,omitempty"`
27-
PresencePenalty float64 `json:"presence_penalty,omitempty"`
28-
Seed int `json:"seed,omitempty"`
29-
Stream bool `json:"stream,omitempty"`
30-
StreamOptions *streamOptions `json:"stream_options,omitempty"`
31-
Temperature float64 `json:"temperature,omitempty"`
32-
TopP float64 `json:"top_p,omitempty"`
33-
Tools []tool `json:"tools,omitempty"`
34-
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
35-
User string `json:"user,omitempty"`
36-
Stop []string `json:"stop,omitempty"`
22+
Model string `json:"model"`
23+
Messages []chatMessage `json:"messages"`
24+
MaxTokens int `json:"max_tokens,omitempty"`
25+
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
26+
N int `json:"n,omitempty"`
27+
PresencePenalty float64 `json:"presence_penalty,omitempty"`
28+
Seed int `json:"seed,omitempty"`
29+
Stream bool `json:"stream,omitempty"`
30+
StreamOptions *streamOptions `json:"stream_options,omitempty"`
31+
Temperature float64 `json:"temperature,omitempty"`
32+
TopP float64 `json:"top_p,omitempty"`
33+
Tools []tool `json:"tools,omitempty"`
34+
ToolChoice *toolChoice `json:"tool_choice,omitempty"`
35+
User string `json:"user,omitempty"`
36+
Stop []string `json:"stop,omitempty"`
3737
ResponseFormat map[string]interface{} `json:"response_format,omitempty"`
3838
}
3939

@@ -230,6 +230,21 @@ func (e *streamEvent) setValue(key, value string) {
230230
}
231231
}
232232

233+
// https://platform.openai.com/docs/guides/images
234+
type imageGenerationRequest struct {
235+
Model string `json:"model"`
236+
Prompt string `json:"prompt"`
237+
N int `json:"n,omitempty"`
238+
Size string `json:"size,omitempty"`
239+
}
240+
241+
// https://platform.openai.com/docs/guides/text-to-speech
242+
type audioSpeechRequest struct {
243+
Model string `json:"model"`
244+
Input string `json:"input"`
245+
Voice string `json:"voice"`
246+
}
247+
233248
type embeddingsRequest struct {
234249
Input interface{} `json:"input"`
235250
Model string `json:"model"`

plugins/wasm-go/extensions/ai-proxy/provider/moonshot.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ func (m *moonshotProvider) OnRequestBody(ctx wrapper.HttpContext, apiName ApiNam
8484
if !m.config.isSupportedAPI(apiName) {
8585
return types.ActionContinue, errUnsupportedApiName
8686
}
87+
// 非chat类型的请求,不做处理
88+
if apiName != ApiNameChatCompletion {
89+
return types.ActionContinue, nil
90+
}
8791

8892
request := &chatCompletionRequest{}
8993
if err := m.config.parseRequestAndMapModel(ctx, request, body, log); err != nil {

plugins/wasm-go/extensions/ai-proxy/provider/provider.go

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package provider
22

33
import (
4-
"encoding/json"
54
"errors"
65
"math/rand"
76
"net/http"
@@ -12,6 +11,7 @@ import (
1211
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm"
1312
"github.com/higress-group/proxy-wasm-go-sdk/proxywasm/types"
1413
"github.com/tidwall/gjson"
14+
"github.com/tidwall/sjson"
1515
)
1616

1717
type ApiName string
@@ -22,11 +22,10 @@ const (
2222
// ApiName 格式 {vendor}/{version}/{apitype}
2323
// 表示遵循 厂商/版本/接口类型 的格式
2424
// 目前openai是事实意义上的标准,但是也有其他厂商存在其他任务的一些可能的标准,比如cohere的rerank
25-
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
26-
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
27-
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
28-
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
29-
ApiNameAudioTranscription ApiName = "openai/v1/audiotranscription"
25+
ApiNameChatCompletion ApiName = "openai/v1/chatcompletions"
26+
ApiNameEmbeddings ApiName = "openai/v1/embeddings"
27+
ApiNameImageGeneration ApiName = "openai/v1/imagegeneration"
28+
ApiNameAudioSpeech ApiName = "openai/v1/audiospeech"
3029

3130
PathOpenAIChatCompletions = "/v1/chat/completions"
3231
PathOpenAIEmbeddings = "/v1/embeddings"
@@ -379,23 +378,19 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
379378
c.outputVariable = json.Get("outputVariable").String()
380379

381380
c.capabilities = make(map[string]string)
382-
for capability, pathJson := range json.Get("abilities").Map() {
381+
for capability, pathJson := range json.Get("capabilities").Map() {
383382
// 过滤掉不受支持的能力
384383
switch capability {
385384
case string(ApiNameChatCompletion),
386385
string(ApiNameEmbeddings),
387386
string(ApiNameImageGeneration),
388-
string(ApiNameAudioSpeech),
389-
string(ApiNameAudioTranscription):
387+
string(ApiNameAudioSpeech):
390388
c.capabilities[capability] = pathJson.String()
391389
}
392390
}
393391
}
394392

395393
func (c *ProviderConfig) Validate() error {
396-
if c.timeout < 0 {
397-
return errors.New("invalid timeout in config")
398-
}
399394
if c.protocol != protocolOpenAI && c.protocol != protocolOriginal {
400395
return errors.New("invalid protocol in config")
401396
}
@@ -528,7 +523,7 @@ func getMappedModel(model string, modelMapping map[string]string, log wrapper.Lo
528523
}
529524

530525
func doGetMappedModel(model string, modelMapping map[string]string, log wrapper.Log) string {
531-
if modelMapping == nil || len(modelMapping) == 0 {
526+
if len(modelMapping) == 0 {
532527
return ""
533528
}
534529

@@ -618,17 +613,21 @@ func (c *ProviderConfig) handleRequestHeaders(provider Provider, ctx wrapper.Htt
618613
}
619614
}
620615

616+
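// defaultTransformRequestBody is the default request-body transform. It applies the model mapping by rewriting the "model" field in place with sjson, avoiding a full unmarshal/marshal round trip for better performance. For chat completion requests it also records whether the request is streaming and, if so, sets the Accept header to text/event-stream.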
621617
func (c *ProviderConfig) defaultTransformRequestBody(ctx wrapper.HttpContext, apiName ApiName, body []byte, log wrapper.Log) ([]byte, error) {
622-
var request interface{}
623-
if apiName == ApiNameChatCompletion {
624-
request = &chatCompletionRequest{}
625-
} else {
626-
request = &embeddingsRequest{}
627-
}
628-
if err := c.parseRequestAndMapModel(ctx, request, body, log); err != nil {
629-
return nil, err
618+
switch apiName {
619+
case ApiNameChatCompletion:
620+
stream := gjson.GetBytes(body, "stream").Bool()
621+
if stream {
622+
_ = proxywasm.ReplaceHttpRequestHeader("Accept", "text/event-stream")
623+
ctx.SetContext(ctxKeyIsStreaming, true)
624+
} else {
625+
ctx.SetContext(ctxKeyIsStreaming, false)
626+
}
630627
}
631-
return json.Marshal(request)
628+
model := gjson.GetBytes(body, "model").String()
629+
ctx.SetContext(ctxKeyOriginalRequestModel, model)
630+
return sjson.SetBytes(body, "model", getMappedModel(model, c.modelMapping, log))
632631
}
633632

634633
func (c *ProviderConfig) DefaultTransformResponseHeaders(ctx wrapper.HttpContext, headers http.Header) {

plugins/wasm-go/extensions/ai-proxy/provider/qwen.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,13 @@ func (m *qwenProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName
8989
}
9090

9191
func (m *qwenProvider) TransformRequestBodyHeaders(ctx wrapper.HttpContext, apiName ApiName, body []byte, headers http.Header, log wrapper.Log) ([]byte, error) {
92-
if apiName == ApiNameChatCompletion {
92+
switch apiName {
93+
case ApiNameChatCompletion:
9394
return m.onChatCompletionRequestBody(ctx, body, headers, log)
94-
} else {
95+
case ApiNameEmbeddings:
9596
return m.onEmbeddingsRequestBody(ctx, body, log)
97+
default:
98+
return m.config.defaultTransformRequestBody(ctx, apiName, body, log)
9699
}
97100
}
98101

0 commit comments

Comments
 (0)