Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions plugins/wasm-go/extensions/ai-security-guard/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,12 @@ func onHttpRequestBody(ctx wrapper.HttpContext, config AISecurityConfig, body []
}
ctx.DontReadResponseBody()
config.incrementCounter("ai_sec_request_deny", 1)
proxywasm.ResumeHttpRequest()
ctx.SetUserAttribute("safecheck_status", "request deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
Expand Down Expand Up @@ -385,6 +390,11 @@ func onHttpResponseHeaders(ctx wrapper.HttpContext, config AISecurityConfig, log
return types.ActionContinue
}
hdsMap := convertHeaders(headers)
if !strings.Contains(strings.Join(hdsMap[":status"], ";"), "200") {
log.Debugf("response is not 200, skip response body check")
ctx.DontReadResponseBody()
return types.ActionContinue
}
ctx.SetContext("headers", hdsMap)
return types.HeaderStopIteration
}
Expand Down Expand Up @@ -436,22 +446,24 @@ func onHttpResponseBody(ctx wrapper.HttpContext, config AISecurityConfig, body [
denyMessage = response.Data.Advice[0].Answer
}
marshalledDenyMessage := marshalStr(denyMessage, log)
var jsonData []byte
if config.protocolOriginal {
jsonData = []byte(marshalledDenyMessage)
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, []byte(marshalledDenyMessage), -1)
} else if isStreamingResponse {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
jsonData := []byte(fmt.Sprintf(OpenAIStreamResponseFormat, randomID, model, marshalledDenyMessage, randomID, model))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "text/event-stream;charset=UTF-8"}}, jsonData, -1)
} else {
randomID := generateRandomID()
jsonData = []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
jsonData := []byte(fmt.Sprintf(OpenAIResponseFormat, randomID, model, marshalledDenyMessage))
proxywasm.SendHttpResponse(uint32(config.denyCode), [][2]string{{"content-type", "application/json"}}, jsonData, -1)
}
delete(hdsMap, "content-length")
hdsMap[":status"] = []string{fmt.Sprint(config.denyCode)}
proxywasm.ReplaceHttpResponseHeaders(reconvertHeaders(hdsMap))
proxywasm.ReplaceHttpResponseBody(jsonData)
config.incrementCounter("ai_sec_response_deny", 1)
proxywasm.ResumeHttpResponse()
ctx.SetUserAttribute("safecheck_status", "response deny")
if response.Data.Advice != nil {
ctx.SetUserAttribute("safecheck_riskLabel", response.Data.Result[0].Label)
ctx.SetUserAttribute("safecheck_riskWords", response.Data.Result[0].RiskWords)
}
ctx.WriteUserAttributeToLogWithKey(wrapper.AILogKey)
}
singleCall = func() {
timestamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
Expand Down
Loading