Merged
17 changes: 12 additions & 5 deletions pkg/llm-d-inference-sim/config.go
@@ -41,6 +41,9 @@ type configuration struct {
// MaxNumSeqs is maximum number of sequences per iteration (the maximum
// number of inference requests that could be processed at the same time)
MaxNumSeqs int `yaml:"max-num-seqs"`
// MaxModelLen is the model's context window, the maximum number of tokens
// in a single request including input and output. Default value is 1024.
MaxModelLen int `yaml:"max-model-len"`
// LoraModulesString is a list of LoRA adapters as strings
LoraModulesString []string `yaml:"lora-modules"`
// LoraModules is a list of LoRA adapters
@@ -97,11 +100,12 @@ func (c *configuration) unmarshalLoras() error {

func newConfig() *configuration {
return &configuration{
Port: vLLMDefaultPort,
MaxLoras: 1,
MaxNumSeqs: 5,
Mode: modeRandom,
Seed: time.Now().UnixNano(),
Port: vLLMDefaultPort,
MaxLoras: 1,
MaxNumSeqs: 5,
MaxModelLen: 1024,
Mode: modeRandom,
Seed: time.Now().UnixNano(),
}
}

@@ -151,6 +155,9 @@ func (c *configuration) validate() error {
if c.MaxCPULoras < c.MaxLoras {
return errors.New("max CPU LoRAs cannot be less than max LoRAs")
}
if c.MaxModelLen < 1 {
return errors.New("max model len cannot be less than 1")
}

for _, lora := range c.LoraModules {
if lora.Name == "" {
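For illustration, a minimal standalone Go sketch of how the new max-model-len YAML key maps onto the field added above; gopkg.in/yaml.v3 and the trimmed-down struct are assumed purely for this example and are not part of the change:

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// configuration mirrors only the field relevant to this example.
type configuration struct {
	MaxModelLen int `yaml:"max-model-len"`
}

func main() {
	cfg := configuration{MaxModelLen: 1024} // same default as newConfig()
	// Overriding the default from a YAML document, as manifests/config.yaml could:
	if err := yaml.Unmarshal([]byte("max-model-len: 2048"), &cfg); err != nil {
		panic(err)
	}
	fmt.Println(cfg.MaxModelLen) // 2048 — the context window used for validation
}

The same value can also be set on the command line via the --max-model-len flag registered in simulator.go below.
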
7 changes: 7 additions & 0 deletions pkg/llm-d-inference-sim/config_test.go
@@ -194,6 +194,12 @@ var _ = Describe("Simulator configuration", func() {
}
tests = append(tests, test)

test = testCase{
name: "invalid max-model-len",
args: []string{"cmd", "--max-model-len", "0", "--config", "../../manifests/config.yaml"},
}
tests = append(tests, test)

DescribeTable("check configurations",
func(args []string, expectedConfig *configuration) {
config, err := createSimConfig(args)
@@ -217,5 +223,6 @@
Entry(tests[7].name, tests[7].args),
Entry(tests[8].name, tests[8].args),
Entry(tests[9].name, tests[9].args),
Entry(tests[10].name, tests[10].args),
)
})
13 changes: 13 additions & 0 deletions pkg/llm-d-inference-sim/request.go
@@ -42,6 +42,8 @@ type completionRequest interface {
getTools() []tool
// getToolChoice() returns tool choice (in chat completion)
getToolChoice() string
// getMaxCompletionTokens returns the maximum completion tokens requested
getMaxCompletionTokens() *int64
}

// baseCompletionRequest contains base completion request related information
@@ -143,6 +145,13 @@ func (c *chatCompletionRequest) getToolChoice() string {
return c.ToolChoice
}

func (c *chatCompletionRequest) getMaxCompletionTokens() *int64 {
if c.MaxCompletionTokens != nil {
return c.MaxCompletionTokens
}
return c.MaxTokens
}

// getLastUserMsg returns last message from this request's messages with user role,
// if does not exist - returns an empty string
func (req *chatCompletionRequest) getLastUserMsg() string {
@@ -202,6 +211,10 @@ func (c *textCompletionRequest) getToolChoice() string {
return ""
}

func (c *textCompletionRequest) getMaxCompletionTokens() *int64 {
return c.MaxTokens
}

// createResponseText creates and returns response payload based on this request,
// i.e., an array of generated tokens, the finish reason, and the number of created
// tokens
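For chat requests, getMaxCompletionTokens prefers max_completion_tokens and falls back to the older max_tokens field. A minimal sketch of that precedence, written as it might appear in a test inside this package (values are illustrative):

var req chatCompletionRequest
mct, mt := int64(8), int64(16)
req.MaxCompletionTokens, req.MaxTokens = &mct, &mt
fmt.Println(*req.getMaxCompletionTokens()) // 8 — max_completion_tokens wins when both are set

Text completions have only max_tokens, so the text variant simply returns that field.
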
10 changes: 9 additions & 1 deletion pkg/llm-d-inference-sim/simulator.go
@@ -150,6 +150,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")

f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
@@ -368,6 +369,13 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
return
}

// Validate context window constraints
err = validateContextWindow(vllmReq.getNumberOfPromptTokens(), vllmReq.getMaxCompletionTokens(), s.config.MaxModelLen)
if err != nil {
s.sendCompletionError(ctx, err.Error(), "BadRequestError", fasthttp.StatusBadRequest)
return
}

var wg sync.WaitGroup
wg.Add(1)
reqCtx := &completionReqCtx{
@@ -527,7 +535,7 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string
ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
} else {
ctx.SetContentType("application/json")
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetStatusCode(code)
ctx.SetBody(data)
}
}
101 changes: 101 additions & 0 deletions pkg/llm-d-inference-sim/simulator_test.go
@@ -382,4 +382,105 @@ var _ = Describe("Simulator", func() {
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

Context("max-model-len context window validation", func() {
It("Should reject requests exceeding context window", func() {
ctx := context.TODO()
// Start server with max-model-len=10
args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
client, err := startServerWithArgs(ctx, modeRandom, args)
Expect(err).NotTo(HaveOccurred())

// Test with raw HTTP to verify the error response format
reqBody := `{
"messages": [{"role": "user", "content": "This is a test message"}],
"model": "my_model",
"max_tokens": 8
}`

resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

Expect(resp.StatusCode).To(Equal(400))
Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
Expect(string(body)).To(ContainSubstring("However, you requested 13 tokens"))
Expect(string(body)).To(ContainSubstring("5 in the messages, 8 in the completion"))
Expect(string(body)).To(ContainSubstring("BadRequestError"))

// Also test with OpenAI client to ensure it gets an error
openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err = openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("This is a test message"),
},
Model: model,
MaxTokens: openai.Int(8),
})

Expect(err).To(HaveOccurred())
var apiErr *openai.Error
Expect(errors.As(err, &apiErr)).To(BeTrue())
Expect(apiErr.StatusCode).To(Equal(400))
})

It("Should accept requests within context window", func() {
ctx := context.TODO()
// Start server with max-model-len=50
args := []string{"cmd", "--model", model, "--mode", modeEcho, "--max-model-len", "50"}
client, err := startServerWithArgs(ctx, modeEcho, args)
Expect(err).NotTo(HaveOccurred())

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

// Send a request within the context window
resp, err := openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Hello"),
},
Model: model,
MaxTokens: openai.Int(5),
})

Expect(err).NotTo(HaveOccurred())
Expect(resp.Choices).To(HaveLen(1))
Expect(resp.Model).To(Equal(model))
})

It("Should handle text completion requests exceeding context window", func() {
ctx := context.TODO()
// Start server with max-model-len=10
args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
client, err := startServerWithArgs(ctx, modeRandom, args)
Expect(err).NotTo(HaveOccurred())

// Test with raw HTTP for text completion
reqBody := `{
"prompt": "This is a long test prompt with many words",
"model": "my_model",
"max_tokens": 5
}`

resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

Expect(resp.StatusCode).To(Equal(400))
Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
Expect(string(body)).To(ContainSubstring("BadRequestError"))
})
})
})
21 changes: 21 additions & 0 deletions pkg/llm-d-inference-sim/utils.go
@@ -58,6 +58,27 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error)
return tokens, nil
}

// validateContextWindow checks if the request fits within the model's context window
// Returns an error if prompt tokens + max completion tokens exceed the max model length
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) error {
if maxModelLen <= 0 {
return nil // no limit configured
}

Review comment — irar2 (Collaborator):
Is it possible that maxModelLen is <= 0? You added a check in configuration validate():
if c.MaxModelLen < 1 { return errors.New("max model len cannot be less than 1") }

Reply — Contributor Author:
Hey @irar2, you're right. This is redundant. Removing this.

Reply — Contributor Author:
Done, can you please check now?

completionTokens := int64(0)
if maxCompletionTokens != nil {
completionTokens = *maxCompletionTokens
}

totalTokens := int64(promptTokens) + completionTokens
if totalTokens > int64(maxModelLen) {
return fmt.Errorf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion.",
maxModelLen, totalTokens, promptTokens, completionTokens)
}

return nil
}

// getRandomResponseText returns random response text from the pre-defined list of responses
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func getRandomResponseText(maxCompletionTokens *int64) (string, string) {
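For reference, a minimal sketch of the arithmetic validateContextWindow enforces, using the same numbers as the chat-completion test above; since the helper is unexported, this fragment would have to run inside the package:

maxTokens := int64(8)
err := validateContextWindow(5, &maxTokens, 10) // 5 prompt tokens, 8 completion tokens, max-model-len 10
// 5 + 8 = 13 > 10, so err is non-nil and carries the message the simulator tests assert:
// "This model's maximum context length is 10 tokens. However, you requested 13 tokens
// (5 in the messages, 8 in the completion). ..."

With maxCompletionTokens set to nil only the prompt is counted, which is what the "only prompt tokens exceed limit" test below exercises.
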
50 changes: 50 additions & 0 deletions pkg/llm-d-inference-sim/utils_test.go
@@ -43,4 +43,54 @@ var _ = Describe("Utils", func() {
Expect(finishReason).Should(Equal(stopFinishReason))
})
})

Context("validateContextWindow", func() {
It("should pass when total tokens are within limit", func() {
promptTokens := 100
maxCompletionTokens := int64(50)
maxModelLen := 200

err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(err).Should(BeNil())
})

It("should fail when total tokens exceed limit", func() {
promptTokens := 150
maxCompletionTokens := int64(100)
maxModelLen := 200

err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(err).ShouldNot(BeNil())
Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens"))
Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens"))
Expect(err.Error()).Should(ContainSubstring("150 in the messages, 100 in the completion"))
})

It("should pass when no max model length is configured", func() {
promptTokens := 1000
maxCompletionTokens := int64(1000)
maxModelLen := 0

Review comment — Collaborator:
Likewise.

Reply — Contributor Author:
Removing this.

err := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(err).Should(BeNil())
})

It("should handle nil max completion tokens", func() {
promptTokens := 100
maxModelLen := 200

err := validateContextWindow(promptTokens, nil, maxModelLen)
Expect(err).Should(BeNil())
})

It("should fail when only prompt tokens exceed limit", func() {
promptTokens := 250
maxModelLen := 200

err := validateContextWindow(promptTokens, nil, maxModelLen)
Expect(err).ShouldNot(BeNil())
Expect(err.Error()).Should(ContainSubstring("This model's maximum context length is 200 tokens"))
Expect(err.Error()).Should(ContainSubstring("However, you requested 250 tokens"))
})
})
})