
Commit 0e7bbb4

[feat] Add real and mock endpoints for Query and Chat completion (#4543)
This commit contains the following:
1. Mock and real endpoints for query.
2. Allows query when no DB is present.
3. Allows query when DB is present.

Signed-off-by: Varsha Prasad Narsing <[email protected]>
Co-authored-by: Matias Schimuneck <[email protected]>
1 parent d52e2e7 commit 0e7bbb4
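
The new endpoint takes a JSON body whose required fields are content and llm_model_id, with vector_db_ids optional (see the QueryRequest struct below). A minimal client-side sketch of the two request shapes, assuming the BFF listens on localhost:8080 and that ApiPathPrefix resolves to /api/v1 (the constant's value is not shown in this diff); the route is wrapped in RequireAuthRoute, so a real call would also need the credentials that middleware expects, omitted here, and the model and vector DB IDs are placeholders:

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    // Assumed URL; ApiPathPrefix is defined in app.go and not visible in this diff.
    const queryURL = "http://localhost:8080/api/v1/query"

    // 1. No vector DB present: omitting vector_db_ids makes the handler skip
    //    the RAG lookup and answer with a plain chat completion.
    noDB := map[string]any{
        "content":      "What is retrieval-augmented generation?",
        "llm_model_id": "llama-3-8b-instruct", // placeholder model ID
    }

    // 2. Vector DB present: chunks are retrieved first and passed to the
    //    chat completion as context.
    withDB := map[string]any{
        "content":       "What is retrieval-augmented generation?",
        "llm_model_id":  "llama-3-8b-instruct",    // placeholder model ID
        "vector_db_ids": []string{"my-vector-db"}, // placeholder vector DB ID
    }

    for _, body := range []map[string]any{noDB, withDB} {
        payload, _ := json.Marshal(body)
        resp, err := http.Post(queryURL, "application/json", bytes.NewReader(payload))
        if err != nil {
            fmt.Println("request failed:", err)
            continue
        }
        fmt.Println("status:", resp.Status)
        resp.Body.Close()
    }
}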

6 files changed (+392, -2 lines)

frontend/packages/llama-stack-modular-ui/bff/internal/api/app.go

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,7 @@ const (
 
     // making it simpler than /tool-runtime/rag-tool/insert
     UploadPath = ApiPathPrefix + "/upload"
+    QueryPath = ApiPathPrefix + "/query"
 )
 
 type App struct {

@@ -103,6 +104,7 @@ func (app *App) Routes() http.Handler {
     // POST to register the vectorDB (/v1/vector-dbs)
     apiRouter.POST(VectorDBListPath, app.RequireAuthRoute(app.AttachRESTClient(app.RegisterVectorDBHandler)))
     apiRouter.POST(UploadPath, app.RequireAuthRoute(app.AttachRESTClient(app.UploadHandler)))
+    apiRouter.POST(QueryPath, app.RequireAuthRoute(app.AttachRESTClient(app.QueryHandler)))
 
     // App Router
     appMux := http.NewServeMux()

frontend/packages/llama-stack-modular-ui/bff/internal/api/middleware.go

Lines changed: 3 additions & 2 deletions
@@ -3,12 +3,13 @@ package api
 import (
     "context"
     "fmt"
-    "github.com/julienschmidt/httprouter"
-    "github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations"
     "log/slog"
     "net/http"
     "runtime/debug"
 
+    "github.com/julienschmidt/httprouter"
+    "github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations"
+
     "github.com/google/uuid"
     "github.com/opendatahub-io/llama-stack-modular-ui/internal/constants"
     helper "github.com/opendatahub-io/llama-stack-modular-ui/internal/helpers"
Lines changed: 167 additions & 0 deletions (new file, package api)

@@ -0,0 +1,167 @@

package api

import (
    "encoding/json"
    "errors"
    "fmt"
    "net/http"

    "github.com/julienschmidt/httprouter"
    "github.com/opendatahub-io/llama-stack-modular-ui/internal/constants"
    "github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations"
    "github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations/llamastack"
)

// QueryRequest represents the request body for querying documents
type QueryRequest struct {
    Content     string                      `json:"content"`
    VectorDBIDs []string                    `json:"vector_db_ids,omitempty"`
    QueryConfig llamastack.QueryConfigParam `json:"query_config,omitempty"`
    // Chat completion options (LLM model for generating responses)
    LLMModelID     string                     `json:"llm_model_id"`
    SamplingParams *llamastack.SamplingParams `json:"sampling_params,omitempty"`
}

func (app *App) QueryHandler(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
    client, ok := r.Context().Value(constants.LlamaStackHttpClientKey).(integrations.HTTPClientInterface)

    if !ok {
        app.serverErrorResponse(w, r, errors.New("REST client not found"))
        return
    }

    // Parse the request body
    var queryRequest QueryRequest
    if err := json.NewDecoder(r.Body).Decode(&queryRequest); err != nil {
        app.badRequestResponse(w, r, err)
        return
    }

    // Validate required fields
    if queryRequest.Content == "" {
        app.badRequestResponse(w, r, errors.New("content is required"))
        return
    }
    if queryRequest.LLMModelID == "" {
        app.badRequestResponse(w, r, errors.New("llm_model_id is required"))
        return
    }

    // Check if we should perform RAG query
    hasVectorDBs := len(queryRequest.VectorDBIDs) > 0
    var response llamastack.QueryEmbeddingModelResponse
    var hasRAGContent bool

    if hasVectorDBs {
        // Create default query configuration if not provided
        queryConfig := queryRequest.QueryConfig
        if queryConfig.MaxChunks == 0 {
            queryConfig.MaxChunks = 5 // Default value
        }
        if queryConfig.MaxTokensInContext == 0 {
            queryConfig.MaxTokensInContext = 1000 // Default value
        }
        if queryConfig.ChunkTemplate == "" {
            queryConfig.ChunkTemplate = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" // Default template
        }

        // Create the query embedding model request
        queryEmbeddingModelRequest := llamastack.QueryEmbeddingModelRequest{
            Content:     queryRequest.Content,
            VectorDBIDs: queryRequest.VectorDBIDs,
            QueryConfig: queryConfig,
        }

        // Query the embedding model
        ragResponse, err := app.repositories.LlamaStackClient.QueryEmbeddingModel(client, queryEmbeddingModelRequest)
        if err != nil {
            app.serverErrorResponse(w, r, err)
            return
        }
        response = ragResponse

        // Check if we have RAG content
        for _, contentItem := range response.Content {
            if contentItem.Type == "text" {
                hasRAGContent = true
                break
            }
        }
    }

    // Extract text content from the query response
    var contextText string
    for _, contentItem := range response.Content {
        if contentItem.Type == "text" {
            contextText += contentItem.Text + "\n"
        }
    }

    // Create messages for chat completion
    var messages []llamastack.ChatMessage

    if hasRAGContent {
        // Use RAG context if available
        messages = []llamastack.ChatMessage{
            {
                Role:    "system",
                Content: "You are a helpful assistant that explains concepts based on the provided context. Use the context to answer the user's question accurately and concisely.",
            },
            {
                Role:    "user",
                Content: fmt.Sprintf("Based on this context, answer the following question:\n\nContext:\n%s\n\nQuestion: %s", contextText, queryRequest.Content),
            },
        }
    } else {
        // No RAG content available, provide a general response
        messages = []llamastack.ChatMessage{
            {
                Role:    "system",
                Content: "You are a helpful assistant that explains concepts. If you don't have specific context about the topic, provide a general explanation based on your knowledge.",
            },
            {
                Role:    "user",
                Content: fmt.Sprintf("Please explain: %s", queryRequest.Content),
            },
        }
    }

    // Use provided sampling params or defaults
    samplingParams := llamastack.SamplingParams{
        Strategy: llamastack.SamplingStrategy{
            Type: "greedy",
        },
        MaxTokens: 500,
    }
    if queryRequest.SamplingParams != nil {
        samplingParams = *queryRequest.SamplingParams
    }

    // Create chat completion request
    chatCompletionRequest := llamastack.ChatCompletionRequest{
        ModelID:        queryRequest.LLMModelID,
        Messages:       messages,
        SamplingParams: samplingParams,
    }

    // Generate chat completion
    chatResponse, err := app.repositories.LlamaStackClient.ChatCompletion(client, chatCompletionRequest)
    if err != nil {
        app.serverErrorResponse(w, r, err)
        return
    }

    // Return consistent response format regardless of RAG usage
    combinedResponse := map[string]interface{}{
        "rag_response":      response,
        "chat_completion":   chatResponse,
        "has_rag_content":   hasRAGContent,
        "used_vector_dbs":   hasVectorDBs,
        "assistant_message": chatResponse.CompletionMessage.Content,
    }

    w.WriteHeader(http.StatusOK)
    if err := json.NewEncoder(w).Encode(combinedResponse); err != nil {
        app.logger.Error("Failed to encode response", "error", err)
    }
}

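For UI consumers, the handler above returns the same envelope whether or not RAG was used. A small decoding sketch; the struct is redeclared locally so the snippet stands alone, and the payload values are illustrative:

package main

import (
    "encoding/json"
    "fmt"
)

// queryResponse mirrors the combinedResponse map built in QueryHandler.
type queryResponse struct {
    AssistantMessage string          `json:"assistant_message"`
    HasRAGContent    bool            `json:"has_rag_content"`
    UsedVectorDBs    bool            `json:"used_vector_dbs"`
    RAGResponse      json.RawMessage `json:"rag_response"`    // llamastack.QueryEmbeddingModelResponse
    ChatCompletion   json.RawMessage `json:"chat_completion"` // llamastack.ChatCompletionResponse
}

func main() {
    // Illustrative payload shaped like the handler's output.
    raw := []byte(`{
        "assistant_message": "RAG combines retrieval with generation.",
        "has_rag_content": true,
        "used_vector_dbs": true,
        "rag_response": {"content": [], "metadata": {"document_ids": [], "chunks": [], "scores": []}},
        "chat_completion": {"metrics": [], "completion_message": {"role": "assistant", "content": "...", "stop_reason": "stop", "tool_calls": []}, "logprobs": null}
    }`)

    var qr queryResponse
    if err := json.Unmarshal(raw, &qr); err != nil {
        fmt.Println("decode failed:", err)
        return
    }
    fmt.Println("assistant:", qr.AssistantMessage)
    fmt.Println("used RAG context:", qr.HasRAGContent)
}
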
frontend/packages/llama-stack-modular-ui/bff/internal/integrations/llamastack/llamastack_datatypes.go

Lines changed: 90 additions & 0 deletions
@@ -50,3 +50,93 @@ type DocumentInsertRequest struct {
    VectorDBID        string `json:"vector_db_id"`
    ChunkSizeInTokens *int   `json:"chunk_size_in_tokens,omitempty"`
}

// QueryEmbeddingModelRequest represents the request body for querying an embedding model
type QueryEmbeddingModelRequest struct {
    // The query content (text)
    Content     string   `json:"content"`
    VectorDBIDs []string `json:"vector_db_ids"`
    // Configuration for the RAG query generation.
    QueryConfig QueryConfigParam `json:"query_config"`
}

type QueryConfigParam struct {
    // Template for formatting each retrieved chunk in the context. Available
    // placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content
    // string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
    // {chunk.content}\nMetadata: {metadata}\n"
    ChunkTemplate string `json:"chunk_template"`
    // Maximum number of chunks to retrieve.
    MaxChunks int64 `json:"max_chunks"`
    // Maximum number of tokens in the context.
    MaxTokensInContext int64 `json:"max_tokens_in_context"`
}

type QueryEmbeddingModelResponse struct {
    Content  []ContentItem `json:"content"`
    Metadata Metadata      `json:"metadata"`
}

type ContentItem struct {
    Type string `json:"type"`
    Text string `json:"text"`
}

type Metadata struct {
    DocumentIDs []string  `json:"document_ids"`
    Chunks      []string  `json:"chunks"`
    Scores      []float64 `json:"scores"`
}

// Chat completion types
type ChatCompletionRequest struct {
    ModelID        string         `json:"model_id"`
    Messages       []ChatMessage  `json:"messages"`
    SamplingParams SamplingParams `json:"sampling_params"`
}

type ChatMessage struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

type SamplingParams struct {
    Strategy  SamplingStrategy `json:"strategy"`
    MaxTokens int64            `json:"max_tokens"`
}

type SamplingStrategy struct {
    Type string `json:"type"`
}

type ChatCompletionResponse struct {
    Metrics           []Metric          `json:"metrics"`
    CompletionMessage CompletionMessage `json:"completion_message"`
    Logprobs          interface{}       `json:"logprobs"`
}

type Metric struct {
    Metric string      `json:"metric"`
    Value  interface{} `json:"value"`
    Unit   interface{} `json:"unit"`
}

type CompletionMessage struct {
    Role       string        `json:"role"`
    Content    string        `json:"content"`
    StopReason string        `json:"stop_reason"`
    ToolCalls  []interface{} `json:"tool_calls"`
}

// Legacy types for backward compatibility (used in mock)
type ChatChoice struct {
    Index        int         `json:"index"`
    Message      ChatMessage `json:"message"`
    FinishReason string      `json:"finish_reason"`
}

type Usage struct {
    PromptTokens     int `json:"prompt_tokens"`
    CompletionTokens int `json:"completion_tokens"`
    TotalTokens      int `json:"total_tokens"`
}

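To make the wire format of the new types concrete, here is a short sketch that marshals a ChatCompletionRequest using the same defaults QueryHandler falls back to (greedy strategy, 500 max tokens). It assumes the snippet lives inside the BFF module, since the llamastack package is internal; the model ID is a placeholder:

package main

import (
    "encoding/json"
    "fmt"

    "github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations/llamastack"
)

func main() {
    req := llamastack.ChatCompletionRequest{
        ModelID: "llama-3-8b-instruct", // placeholder model ID
        Messages: []llamastack.ChatMessage{
            {Role: "system", Content: "You are a helpful assistant."},
            {Role: "user", Content: "Please explain: vector databases"},
        },
        SamplingParams: llamastack.SamplingParams{
            Strategy:  llamastack.SamplingStrategy{Type: "greedy"},
            MaxTokens: 500,
        },
    }

    out, _ := json.MarshalIndent(req, "", "  ")
    // Emits JSON keyed by model_id, messages and sampling_params,
    // matching the struct tags declared above.
    fmt.Println(string(out))
}
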
frontend/packages/llama-stack-modular-ui/bff/internal/mocks/llamastack_client_mock.go

Lines changed: 88 additions & 0 deletions
@@ -160,3 +160,91 @@ func (l *LlamastackClientMock) InsertDocuments(_ integrations.HTTPClientInterfac

    return nil
}

func (l *LlamastackClientMock) QueryEmbeddingModel(_ integrations.HTTPClientInterface, request llamastack.QueryEmbeddingModelRequest) (llamastack.QueryEmbeddingModelResponse, error) {
    l.mutex.Lock()
    defer l.mutex.Unlock()

    // Validate request
    if request.Content == "" {
        return llamastack.QueryEmbeddingModelResponse{}, fmt.Errorf("content is required")
    }

    if len(request.VectorDBIDs) == 0 {
        return llamastack.QueryEmbeddingModelResponse{}, fmt.Errorf("at least one vector_db_id is required")
    }

    // Simulate successful query response
    response := llamastack.QueryEmbeddingModelResponse{
        Content: []llamastack.ContentItem{
            {
                Type: "text",
                Text: fmt.Sprintf("Mock response for query: %s", request.Content),
            },
            {
                Type: "text",
                Text: "Additional mock content from vector database",
            },
            {
                Type: "text",
                Text: "More relevant content based on the query",
            },
        },
        Metadata: llamastack.Metadata{
            DocumentIDs: []string{"mock-doc-001", "mock-doc-002"},
            Chunks:      []string{"Mock chunk 1", "Mock chunk 2"},
            Scores:      []float64{0.95, 0.87},
        },
    }

    fmt.Printf("Mock: Query executed successfully for content: %s\n", request.Content)
    fmt.Printf("Mock: Searched in vector databases: %v\n", request.VectorDBIDs)
    fmt.Printf("Mock: Returning %d content items\n", len(response.Content))

    return response, nil
}

func (l *LlamastackClientMock) ChatCompletion(_ integrations.HTTPClientInterface, request llamastack.ChatCompletionRequest) (llamastack.ChatCompletionResponse, error) {
    l.mutex.Lock()
    defer l.mutex.Unlock()

    // Validate request
    if request.ModelID == "" {
        return llamastack.ChatCompletionResponse{}, fmt.Errorf("model_id is required")
    }

    if len(request.Messages) == 0 {
        return llamastack.ChatCompletionResponse{}, fmt.Errorf("messages are required")
    }

    // Get the last user message for context
    var lastUserMessage string
    for i := len(request.Messages) - 1; i >= 0; i-- {
        if request.Messages[i].Role == "user" {
            lastUserMessage = request.Messages[i].Content
            break
        }
    }

    // Simulate successful chat completion response
    response := llamastack.ChatCompletionResponse{
        Metrics: []llamastack.Metric{
            {Metric: "prompt_tokens", Value: 50, Unit: nil},
            {Metric: "completion_tokens", Value: 25, Unit: nil},
            {Metric: "total_tokens", Value: 75, Unit: nil},
        },
        CompletionMessage: llamastack.CompletionMessage{
            Role:       "assistant",
            Content:    fmt.Sprintf("Mock response to: %s. This is a simulated chat completion response based on the provided context and messages.", lastUserMessage),
            StopReason: "stop",
            ToolCalls:  []interface{}{},
        },
        Logprobs: nil,
    }

    fmt.Printf("Mock: Chat completion executed successfully for model: %s\n", request.ModelID)
    fmt.Printf("Mock: Processed %d messages\n", len(request.Messages))
    fmt.Printf("Mock: Returning response with completion message\n")

    return response, nil
}

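The mocks above make it possible to exercise the query flow without a running llama-stack instance. A minimal test sketch, assuming the test file sits in the same mocks package and that the mock's zero value (including its mutex) is usable; if the package provides a constructor, prefer it:

package mocks

import (
    "strings"
    "testing"

    "github.com/opendatahub-io/llama-stack-modular-ui/internal/integrations/llamastack"
)

func TestQueryAndChatCompletionMocks(t *testing.T) {
    mock := LlamastackClientMock{}

    // RAG query against the mock: needs content and at least one vector DB ID.
    ragResp, err := mock.QueryEmbeddingModel(nil, llamastack.QueryEmbeddingModelRequest{
        Content:     "what is RAG?",
        VectorDBIDs: []string{"test-db"}, // placeholder vector DB ID
    })
    if err != nil {
        t.Fatalf("QueryEmbeddingModel: %v", err)
    }
    if len(ragResp.Content) == 0 {
        t.Fatal("expected mock content items")
    }

    // Chat completion against the mock: echoes the last user message back.
    chatResp, err := mock.ChatCompletion(nil, llamastack.ChatCompletionRequest{
        ModelID: "test-model", // placeholder model ID
        Messages: []llamastack.ChatMessage{
            {Role: "user", Content: "what is RAG?"},
        },
    })
    if err != nil {
        t.Fatalf("ChatCompletion: %v", err)
    }
    if !strings.Contains(chatResp.CompletionMessage.Content, "what is RAG?") {
        t.Errorf("unexpected completion: %q", chatResp.CompletionMessage.Content)
    }
}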