Merged
17 changes: 12 additions & 5 deletions pkg/llm-d-inference-sim/config.go
@@ -41,6 +41,9 @@ type configuration struct {
// MaxNumSeqs is maximum number of sequences per iteration (the maximum
// number of inference requests that could be processed at the same time)
MaxNumSeqs int `yaml:"max-num-seqs"`
// MaxModelLen is the model's context window, the maximum number of tokens
// in a single request including input and output. Default value is 1024.
MaxModelLen int `yaml:"max-model-len"`
// LoraModulesString is a list of LoRA adapters as strings
LoraModulesString []string `yaml:"lora-modules"`
// LoraModules is a list of LoRA adapters
@@ -97,11 +100,12 @@ func (c *configuration) unmarshalLoras() error {

func newConfig() *configuration {
return &configuration{
Port: vLLMDefaultPort,
MaxLoras: 1,
MaxNumSeqs: 5,
Mode: modeRandom,
Seed: time.Now().UnixNano(),
Port: vLLMDefaultPort,
MaxLoras: 1,
MaxNumSeqs: 5,
MaxModelLen: 1024,
Mode: modeRandom,
Seed: time.Now().UnixNano(),
}
}

@@ -151,6 +155,9 @@ func (c *configuration) validate() error {
if c.MaxCPULoras < c.MaxLoras {
return errors.New("max CPU LoRAs cannot be less than max LoRAs")
}
if c.MaxModelLen < 1 {
return errors.New("max model len cannot be less than 1")
}

for _, lora := range c.LoraModules {
if lora.Name == "" {
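For reference, the new option can be supplied either via the --max-model-len flag registered in simulator.go below or through the simulator's YAML config file. A minimal sketch of the YAML form (field names follow the yaml tags in the struct above; the file path and values are only illustrative):

# e.g. manifests/config.yaml
max-num-seqs: 5
max-model-len: 2048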
8 changes: 7 additions & 1 deletion pkg/llm-d-inference-sim/config_test.go
@@ -234,6 +234,12 @@ var _ = Describe("Simulator configuration", func() {
}
tests = append(tests, test)

test = testCase{
name: "invalid max-model-len",
args: []string{"cmd", "--max-model-len", "0", "--config", "../../manifests/config.yaml"},
}
tests = append(tests, test)

DescribeTable("check configurations",
func(args []string, expectedConfig *configuration) {
config, err := createSimConfig(args)
@@ -258,6 +264,6 @@ var _ = Describe("Simulator configuration", func() {
Entry(tests[8].name, tests[8].args),
Entry(tests[9].name, tests[9].args),
Entry(tests[10].name, tests[10].args),
Entry(tests[11].name, tests[11].args),
Entry(tests[11].name, tests[11].args),
Collaborator:

The merge is incorrect: there are 11 tests in main, and you added another one, so there should be an additional line with
Entry(tests[12].name, tests[12].args)

Contributor Author:

Missed this one, thanks.

Contributor Author:

resolved

)
})
13 changes: 13 additions & 0 deletions pkg/llm-d-inference-sim/request.go
@@ -42,6 +42,8 @@ type completionRequest interface {
getTools() []tool
// getToolChoice() returns tool choice (in chat completion)
getToolChoice() string
// getMaxCompletionTokens returns the maximum completion tokens requested
getMaxCompletionTokens() *int64
}

// baseCompletionRequest contains base completion request related information
@@ -143,6 +145,13 @@ func (c *chatCompletionRequest) getToolChoice() string {
return c.ToolChoice
}

func (c *chatCompletionRequest) getMaxCompletionTokens() *int64 {
if c.MaxCompletionTokens != nil {
return c.MaxCompletionTokens
}
return c.MaxTokens
}

// getLastUserMsg returns the last message with the user role from this request's messages;
// if no such message exists, it returns an empty string
func (req *chatCompletionRequest) getLastUserMsg() string {
@@ -202,6 +211,10 @@ func (c *textCompletionRequest) getToolChoice() string {
return ""
}

func (c *textCompletionRequest) getMaxCompletionTokens() *int64 {
return c.MaxTokens
}

// createResponseText creates and returns response payload based on this request,
// i.e., an array of generated tokens, the finish reason, and the number of created
// tokens
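The two accessors above implement a simple fallback: chat requests prefer max_completion_tokens and fall back to the legacy max_tokens, while text completions only carry max_tokens. A self-contained sketch of that fallback, using a stand-in struct rather than the simulator's real chatCompletionRequest type:

package main

import "fmt"

// chatReq is a stand-in for the simulator's chat completion request fields
type chatReq struct {
    MaxCompletionTokens *int64
    MaxTokens           *int64
}

// getMaxCompletionTokens mirrors the fallback added in request.go above
func (c *chatReq) getMaxCompletionTokens() *int64 {
    if c.MaxCompletionTokens != nil {
        return c.MaxCompletionTokens
    }
    return c.MaxTokens
}

func main() {
    legacy := int64(100)
    newer := int64(8)

    // only the legacy field is set -> 100
    fmt.Println(*(&chatReq{MaxTokens: &legacy}).getMaxCompletionTokens())
    // both are set -> the newer field wins -> 8
    fmt.Println(*(&chatReq{MaxCompletionTokens: &newer, MaxTokens: &legacy}).getMaxCompletionTokens())
}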
10 changes: 9 additions & 1 deletion pkg/llm-d-inference-sim/simulator.go
@@ -150,6 +150,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")

f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
@@ -372,6 +373,13 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
return
}

// Validate context window constraints
err = validateContextWindow(vllmReq.getNumberOfPromptTokens(), vllmReq.getMaxCompletionTokens(), s.config.MaxModelLen)
if err != nil {
s.sendCompletionError(ctx, err.Error(), "BadRequestError", fasthttp.StatusBadRequest)
return
}

var wg sync.WaitGroup
wg.Add(1)
reqCtx := &completionReqCtx{
Expand Down Expand Up @@ -531,7 +539,7 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string
ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
} else {
ctx.SetContentType("application/json")
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetStatusCode(code)
ctx.SetBody(data)
}
}
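With the validation wired into handleCompletions, an over-budget request is rejected with HTTP 400 before it ever reaches the waiting queue. A minimal client sketch against a simulator started with --model my_model --max-model-len 10; the localhost:8000 address is an assumption, substitute whatever port your instance listens on:

package main

import (
    "fmt"
    "io"
    "net/http"
    "strings"
)

func main() {
    // 5 prompt tokens plus max_tokens 8 exceeds the 10-token window
    body := `{
        "model": "my_model",
        "messages": [{"role": "user", "content": "This is a test message"}],
        "max_tokens": 8
    }`

    // assumed address; the simulator's default port is not shown in this diff
    resp, err := http.Post("http://localhost:8000/v1/chat/completions", "application/json", strings.NewReader(body))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    data, _ := io.ReadAll(resp.Body)
    fmt.Println(resp.StatusCode) // expected: 400
    fmt.Println(string(data))    // expected to mention the maximum context length and "BadRequestError"
}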
101 changes: 101 additions & 0 deletions pkg/llm-d-inference-sim/simulator_test.go
@@ -382,4 +382,105 @@ var _ = Describe("Simulator", func() {
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

Context("max-model-len context window validation", func() {
It("Should reject requests exceeding context window", func() {
ctx := context.TODO()
// Start server with max-model-len=10
args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
client, err := startServerWithArgs(ctx, modeRandom, args)
Expect(err).NotTo(HaveOccurred())

// Test with raw HTTP to verify the error response format
reqBody := `{
"messages": [{"role": "user", "content": "This is a test message"}],
"model": "my_model",
"max_tokens": 8
}`

resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

Expect(resp.StatusCode).To(Equal(400))
Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
Expect(string(body)).To(ContainSubstring("However, you requested 13 tokens"))
Expect(string(body)).To(ContainSubstring("5 in the messages, 8 in the completion"))
Expect(string(body)).To(ContainSubstring("BadRequestError"))

// Also test with OpenAI client to ensure it gets an error
openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err = openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("This is a test message"),
},
Model: model,
MaxTokens: openai.Int(8),
})

Expect(err).To(HaveOccurred())
var apiErr *openai.Error
Expect(errors.As(err, &apiErr)).To(BeTrue())
Expect(apiErr.StatusCode).To(Equal(400))
})

It("Should accept requests within context window", func() {
ctx := context.TODO()
// Start server with max-model-len=50
args := []string{"cmd", "--model", model, "--mode", modeEcho, "--max-model-len", "50"}
client, err := startServerWithArgs(ctx, modeEcho, args)
Expect(err).NotTo(HaveOccurred())

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

// Send a request within the context window
resp, err := openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Hello"),
},
Model: model,
MaxTokens: openai.Int(5),
})

Expect(err).NotTo(HaveOccurred())
Expect(resp.Choices).To(HaveLen(1))
Expect(resp.Model).To(Equal(model))
})

It("Should handle text completion requests exceeding context window", func() {
ctx := context.TODO()
// Start server with max-model-len=10
args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
client, err := startServerWithArgs(ctx, modeRandom, args)
Expect(err).NotTo(HaveOccurred())

// Test with raw HTTP for text completion
reqBody := `{
"prompt": "This is a long test prompt with many words",
"model": "my_model",
"max_tokens": 5
}`

resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

Expect(resp.StatusCode).To(Equal(400))
Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
Expect(string(body)).To(ContainSubstring("BadRequestError"))
})
})
})
17 changes: 17 additions & 0 deletions pkg/llm-d-inference-sim/utils.go
@@ -58,6 +58,23 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error)
return tokens, nil
}

// validateContextWindow checks if the request fits within the model's context window
// Returns an error if prompt tokens + max completion tokens exceeds the max model length
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) error {
completionTokens := int64(0)
if maxCompletionTokens != nil {
completionTokens = *maxCompletionTokens
}

totalTokens := int64(promptTokens) + completionTokens
if totalTokens > int64(maxModelLen) {
return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion.",
maxModelLen, totalTokens, promptTokens, completionTokens)
}

return nil
}

// getRandomResponseText returns random response text from the pre-defined list of responses
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func getRandomResponseText(maxCompletionTokens *int64) (string, string) {
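To make the arithmetic concrete: with max-model-len 10, a 5-token prompt plus max_tokens 8 asks for 13 tokens and is rejected, which is exactly the case the chat completion test above exercises. A self-contained sketch with the check copied from the diff so it can run outside the package:

package main

import "fmt"

// validateContextWindow is copied from the utils.go change above
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) error {
    completionTokens := int64(0)
    if maxCompletionTokens != nil {
        completionTokens = *maxCompletionTokens
    }
    totalTokens := int64(promptTokens) + completionTokens
    if totalTokens > int64(maxModelLen) {
        return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion.",
            maxModelLen, totalTokens, promptTokens, completionTokens)
    }
    return nil
}

func main() {
    over := int64(8)
    // 5 + 8 = 13 > 10 -> error
    fmt.Println(validateContextWindow(5, &over, 10))

    within := int64(50)
    // 100 + 50 = 150 <= 200 -> <nil>
    fmt.Println(validateContextWindow(100, &within, 200))
}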
41 changes: 41 additions & 0 deletions pkg/llm-d-inference-sim/utils_test.go
@@ -43,4 +43,45 @@ var _ = Describe("Utils", func() {
Expect(finishReason).Should(Equal(stopFinishReason))
})
})

Context("validateContextWindow", func() {
It("should pass when total tokens are within limit", func() {
promptTokens := 100
maxCompletionTokens := int64(50)
maxModelLen := 200

err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(err).Should(BeNil())
})

It("should fail when total tokens exceed limit", func() {
promptTokens := 150
maxCompletionTokens := int64(100)
maxModelLen := 200

err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(err).ShouldNot(BeNil())
Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens"))
Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens"))
Expect(err.Error()).Should(ContainSubstring("150 in the messages, 100 in the completion"))
})

It("should handle nil max completion tokens", func() {
promptTokens := 100
maxModelLen := 200

err := validateContextWindow(promptTokens, nil, maxModelLen)
Expect(err).Should(BeNil())
})

It("should fail when only prompt tokens exceed limit", func() {
promptTokens := 250
maxModelLen := 200

err := validateContextWindow(promptTokens, nil, maxModelLen)
Expect(err).ShouldNot(BeNil())
Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens"))
Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens"))
})
})
})