Merged
1 change: 1 addition & 0 deletions README.md
@@ -92,6 +92,7 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
- `lora-modules`: a list of LoRA adapters (a list of space-separated JSON strings): '{"name": "name", "path": "lora_path", "base_model_name": "id"}', optional, empty by default
- `max-loras`: maximum number of LoRAs in a single batch, optional, default is one
- `max-cpu-loras`: maximum number of LoRAs to store in CPU memory, optional, must be >= max-loras, default is max-loras
- `max-model-len`: model's context window, maximum number of tokens in a single request including input and output, optional, default is 1024
- `max-num-seqs`: maximum number of sequences per iteration (maximum number of inference requests that could be processed at the same time), default is 5
- `mode`: the simulator mode, optional, by default `random`
- `echo`: returns the same text that was sent in the request
17 changes: 12 additions & 5 deletions pkg/llm-d-inference-sim/config.go
@@ -41,6 +41,9 @@ type configuration struct {
// MaxNumSeqs is maximum number of sequences per iteration (the maximum
// number of inference requests that could be processed at the same time)
MaxNumSeqs int `yaml:"max-num-seqs"`
// MaxModelLen is the model's context window, the maximum number of tokens
// in a single request including input and output. Default value is 1024.
MaxModelLen int `yaml:"max-model-len"`
// LoraModulesString is a list of LoRA adapters as strings
LoraModulesString []string `yaml:"lora-modules"`
// LoraModules is a list of LoRA adapters
@@ -97,11 +100,12 @@ func (c *configuration) unmarshalLoras() error {

func newConfig() *configuration {
return &configuration{
Port: vLLMDefaultPort,
MaxLoras: 1,
MaxNumSeqs: 5,
Mode: modeRandom,
Seed: time.Now().UnixNano(),
Port: vLLMDefaultPort,
MaxLoras: 1,
MaxNumSeqs: 5,
MaxModelLen: 1024,
Mode: modeRandom,
Seed: time.Now().UnixNano(),
}
}

@@ -151,6 +155,9 @@ func (c *configuration) validate() error {
if c.MaxCPULoras < c.MaxLoras {
return errors.New("max CPU LoRAs cannot be less than max LoRAs")
}
if c.MaxModelLen < 1 {
return errors.New("max model len cannot be less than 1")
}

for _, lora := range c.LoraModules {
if lora.Name == "" {
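As a quick illustration of the new validation rule above (a standalone sketch, not part of the diff): a context window smaller than one token can never fit a request, so validate() rejects it at startup, which is exactly what the "invalid max-model-len" test case added below exercises with --max-model-len 0.

package main

import (
	"errors"
	"fmt"
)

// checkMaxModelLen is a minimal stand-in for the new rule in validate():
// a context window smaller than one token is rejected.
func checkMaxModelLen(maxModelLen int) error {
	if maxModelLen < 1 {
		return errors.New("max model len cannot be less than 1")
	}
	return nil
}

func main() {
	fmt.Println(checkMaxModelLen(1024)) // <nil>, the default passes
	fmt.Println(checkMaxModelLen(0))    // error, mirrors the "invalid max-model-len" test below
}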
7 changes: 7 additions & 0 deletions pkg/llm-d-inference-sim/config_test.go
@@ -234,6 +234,12 @@ var _ = Describe("Simulator configuration", func() {
}
tests = append(tests, test)

test = testCase{
name: "invalid max-model-len",
args: []string{"cmd", "--max-model-len", "0", "--config", "../../manifests/config.yaml"},
}
tests = append(tests, test)

DescribeTable("check configurations",
func(args []string, expectedConfig *configuration) {
config, err := createSimConfig(args)
@@ -259,5 +265,6 @@ var _ = Describe("Simulator configuration", func() {
Entry(tests[9].name, tests[9].args),
Entry(tests[10].name, tests[10].args),
Entry(tests[11].name, tests[11].args),
Entry(tests[12].name, tests[12].args),
)
})
13 changes: 13 additions & 0 deletions pkg/llm-d-inference-sim/request.go
@@ -42,6 +42,8 @@ type completionRequest interface {
getTools() []tool
// getToolChoice() returns tool choice (in chat completion)
getToolChoice() string
// getMaxCompletionTokens returns the maximum completion tokens requested
getMaxCompletionTokens() *int64
}

// baseCompletionRequest contains base completion request related information
@@ -143,6 +145,13 @@ func (c *chatCompletionRequest) getToolChoice() string {
return c.ToolChoice
}

func (c *chatCompletionRequest) getMaxCompletionTokens() *int64 {
if c.MaxCompletionTokens != nil {
return c.MaxCompletionTokens
}
return c.MaxTokens
}

// getLastUserMsg returns the last message with the user role from this request's messages;
// if no such message exists, it returns an empty string
func (req *chatCompletionRequest) getLastUserMsg() string {
@@ -202,6 +211,10 @@ func (c *textCompletionRequest) getToolChoice() string {
return ""
}

func (c *textCompletionRequest) getMaxCompletionTokens() *int64 {
return c.MaxTokens
}

// createResponseText creates and returns response payload based on this request,
// i.e., an array of generated tokens, the finish reason, and the number of created
// tokens
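The two getMaxCompletionTokens implementations above encode a precedence rule: chat completions prefer max_completion_tokens and fall back to the legacy max_tokens field, while text completions only carry max_tokens. A minimal standalone sketch of that precedence (the struct below is a stand-in, not the simulator's actual request type):

package main

import "fmt"

// chatReq is a stand-in holding only the two fields relevant to the precedence rule.
type chatReq struct {
	MaxCompletionTokens *int64
	MaxTokens           *int64
}

// getMaxCompletionTokens prefers max_completion_tokens and falls back to max_tokens.
func (c *chatReq) getMaxCompletionTokens() *int64 {
	if c.MaxCompletionTokens != nil {
		return c.MaxCompletionTokens
	}
	return c.MaxTokens
}

func main() {
	newer, legacy := int64(8), int64(16)

	both := &chatReq{MaxCompletionTokens: &newer, MaxTokens: &legacy}
	fmt.Println(*both.getMaxCompletionTokens()) // 8, max_completion_tokens wins

	onlyLegacy := &chatReq{MaxTokens: &legacy}
	fmt.Println(*onlyLegacy.getMaxCompletionTokens()) // 16, falls back to max_tokens

	none := &chatReq{}
	fmt.Println(none.getMaxCompletionTokens() == nil) // true, no cap requested
}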
13 changes: 12 additions & 1 deletion pkg/llm-d-inference-sim/simulator.go
@@ -150,6 +150,7 @@ func (s *VllmSimulator) parseCommandParamsAndLoadConfig() error {
f.IntVar(&config.MaxNumSeqs, "max-num-seqs", config.MaxNumSeqs, "Maximum number of inference requests that could be processed at the same time (parameter to simulate requests waiting queue)")
f.IntVar(&config.MaxLoras, "max-loras", config.MaxLoras, "Maximum number of LoRAs in a single batch")
f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")

f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
@@ -372,6 +373,16 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple
return
}

// Validate context window constraints
promptTokens := vllmReq.getNumberOfPromptTokens()
completionTokens := vllmReq.getMaxCompletionTokens()
isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen)
if !isValid {
s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion",
s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest)
return
}

var wg sync.WaitGroup
wg.Add(1)
reqCtx := &completionReqCtx{
@@ -531,7 +542,7 @@ func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string
ctx.Error(err.Error(), fasthttp.StatusInternalServerError)
} else {
ctx.SetContentType("application/json")
ctx.SetStatusCode(fasthttp.StatusNotFound)
ctx.SetStatusCode(code)
ctx.SetBody(data)
}
}
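To see how the pieces above produce the numbers asserted in the tests below: getNumberOfPromptTokens is not part of this diff, but the expected counts (5 tokens for "This is a test message") only work out if prompt tokens are approximated as whitespace-separated words. Under that assumption, here is a standalone sketch of the whole check, reproducing the rejected request from the first test (max-model-len 10, max_tokens 8):

package main

import (
	"fmt"
	"strings"
)

// countPromptTokens is a hypothetical stand-in for getNumberOfPromptTokens,
// approximating the prompt's token count as its whitespace-separated words.
func countPromptTokens(prompt string) int {
	return len(strings.Fields(prompt))
}

// validateContextWindow mirrors the helper added in utils.go below.
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) {
	completionTokens := int64(0)
	if maxCompletionTokens != nil {
		completionTokens = *maxCompletionTokens
	}
	totalTokens := int64(promptTokens) + completionTokens
	return totalTokens <= int64(maxModelLen), completionTokens, totalTokens
}

func main() {
	maxModelLen := 10
	maxTokens := int64(8)
	promptTokens := countPromptTokens("This is a test message") // 5

	ok, completion, total := validateContextWindow(promptTokens, &maxTokens, maxModelLen)
	if !ok {
		// Prints: maximum context length is 10 tokens, requested 13 (5 in the messages, 8 in the completion)
		fmt.Printf("maximum context length is %d tokens, requested %d (%d in the messages, %d in the completion)\n",
			maxModelLen, total, promptTokens, completion)
	}
}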
107 changes: 107 additions & 0 deletions pkg/llm-d-inference-sim/simulator_test.go
@@ -382,4 +382,111 @@ var _ = Describe("Simulator", func() {
Expect(err).NotTo(HaveOccurred())
Expect(resp.StatusCode).To(Equal(http.StatusOK))
})

Context("max-model-len context window validation", func() {
It("Should reject requests exceeding context window", func() {
ctx := context.TODO()
// Start server with max-model-len=10
args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
client, err := startServerWithArgs(ctx, modeRandom, args)
Expect(err).NotTo(HaveOccurred())

// Test with raw HTTP to verify the error response format
reqBody := `{
"messages": [{"role": "user", "content": "This is a test message"}],
"model": "my_model",
"max_tokens": 8
}`

resp, err := client.Post("http://localhost/v1/chat/completions", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer func() {
err := resp.Body.Close()
Expect(err).NotTo(HaveOccurred())
}()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

Expect(resp.StatusCode).To(Equal(400))
Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
Expect(string(body)).To(ContainSubstring("However, you requested 13 tokens"))
Expect(string(body)).To(ContainSubstring("5 in the messages, 8 in the completion"))
Expect(string(body)).To(ContainSubstring("BadRequestError"))

// Also test with OpenAI client to ensure it gets an error
openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err = openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("This is a test message"),
},
Model: model,
MaxTokens: openai.Int(8),
})

Expect(err).To(HaveOccurred())
var apiErr *openai.Error
Expect(errors.As(err, &apiErr)).To(BeTrue())
Expect(apiErr.StatusCode).To(Equal(400))
})

It("Should accept requests within context window", func() {
ctx := context.TODO()
// Start server with max-model-len=50
args := []string{"cmd", "--model", model, "--mode", modeEcho, "--max-model-len", "50"}
client, err := startServerWithArgs(ctx, modeEcho, args)
Expect(err).NotTo(HaveOccurred())

openaiclient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

// Send a request within the context window
resp, err := openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage("Hello"),
},
Model: model,
MaxTokens: openai.Int(5),
})

Expect(err).NotTo(HaveOccurred())
Expect(resp.Choices).To(HaveLen(1))
Expect(resp.Model).To(Equal(model))
})

It("Should handle text completion requests exceeding context window", func() {
ctx := context.TODO()
// Start server with max-model-len=10
args := []string{"cmd", "--model", model, "--mode", modeRandom, "--max-model-len", "10"}
client, err := startServerWithArgs(ctx, modeRandom, args)
Expect(err).NotTo(HaveOccurred())

// Test with raw HTTP for text completion
reqBody := `{
"prompt": "This is a long test prompt with many words",
"model": "my_model",
"max_tokens": 5
}`

resp, err := client.Post("http://localhost/v1/completions", "application/json", strings.NewReader(reqBody))
Expect(err).NotTo(HaveOccurred())
defer func() {
err := resp.Body.Close()
Expect(err).NotTo(HaveOccurred())
}()

body, err := io.ReadAll(resp.Body)
Expect(err).NotTo(HaveOccurred())

Expect(resp.StatusCode).To(Equal(400))
Expect(string(body)).To(ContainSubstring("This model's maximum context length is 10 tokens"))
Expect(string(body)).To(ContainSubstring("BadRequestError"))
})
})
})
14 changes: 14 additions & 0 deletions pkg/llm-d-inference-sim/utils.go
@@ -58,6 +58,20 @@ func getMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error)
return tokens, nil
}

// validateContextWindow checks if the request fits within the model's context window
// Returns validation result, actual completion tokens, and total tokens
func validateContextWindow(promptTokens int, maxCompletionTokens *int64, maxModelLen int) (bool, int64, int64) {
completionTokens := int64(0)
if maxCompletionTokens != nil {
completionTokens = *maxCompletionTokens
}

totalTokens := int64(promptTokens) + completionTokens
isValid := totalTokens <= int64(maxModelLen)

return isValid, completionTokens, totalTokens
}

// getRandomResponseText returns random response text from the pre-defined list of responses
// considering max completion tokens if it is not nil, and a finish reason (stop or length)
func getRandomResponseText(maxCompletionTokens *int64) (string, string) {
44 changes: 44 additions & 0 deletions pkg/llm-d-inference-sim/utils_test.go
@@ -43,4 +43,48 @@ var _ = Describe("Utils", func() {
Expect(finishReason).Should(Equal(stopFinishReason))
})
})

Context("validateContextWindow", func() {
It("should pass when total tokens are within limit", func() {
promptTokens := 100
maxCompletionTokens := int64(50)
maxModelLen := 200

isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(isValid).Should(BeTrue())
Expect(actualCompletionTokens).Should(Equal(int64(50)))
Expect(totalTokens).Should(Equal(int64(150)))
})

It("should fail when total tokens exceed limit", func() {
promptTokens := 150
maxCompletionTokens := int64(100)
maxModelLen := 200

isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, &maxCompletionTokens, maxModelLen)
Expect(isValid).Should(BeFalse())
Expect(actualCompletionTokens).Should(Equal(int64(100)))
Expect(totalTokens).Should(Equal(int64(250)))
})

It("should handle nil max completion tokens", func() {
promptTokens := 100
maxModelLen := 200

isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, nil, maxModelLen)
Expect(isValid).Should(BeTrue())
Expect(actualCompletionTokens).Should(Equal(int64(0)))
Expect(totalTokens).Should(Equal(int64(100)))
})

It("should fail when only prompt tokens exceed limit", func() {
promptTokens := 250
maxModelLen := 200

isValid, actualCompletionTokens, totalTokens := validateContextWindow(promptTokens, nil, maxModelLen)
Expect(isValid).Should(BeFalse())
Expect(actualCompletionTokens).Should(Equal(int64(0)))
Expect(totalTokens).Should(Equal(int64(250)))
})
})
})