
Commit c70464c

jvmncs authored and jimpang committed
multi-LoRA as extra models in OpenAI server (vllm-project#2775)
how to serve the loras (mimicking the [multilora inference example](https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py)):

```terminal
$ export LORA_PATH=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
$ python -m vllm.entrypoints.api_server \
    --model meta-llama/Llama-2-7b-hf \
    --enable-lora \
    --lora-modules sql-lora=$LORA_PATH sql-lora2=$LORA_PATH
```

the above server will list 3 separate models if the user queries `/models`: one for the base served model, and one each for the specified lora modules. in this case sql-lora and sql-lora2 point to the same underlying lora, but this need not be the case.

lora config flags take the same values they do in EngineArgs.

no work has been done here to scope client permissions to specific models.
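For readers who want to verify that behavior, here is a minimal sketch (not part of this commit) of listing the exposed models with the official `openai` Python client, assuming the server started above is running locally on the default port 8000 with no `--api-key` configured:

```python
# Minimal sketch: list the models exposed by the server started above.
# Assumes it is running locally on port 8000 with no --api-key set,
# so any placeholder key is accepted.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

for model in client.models.list().data:
    # expected ids: meta-llama/Llama-2-7b-hf, sql-lora, sql-lora2
    print(model.id)
```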
1 parent aad5c06 · commit c70464c

File tree

7 files changed: +200 -27 lines changed

docs/source/models/lora.rst

Lines changed: 40 additions & 1 deletion
```diff
@@ -49,4 +49,43 @@ the third parameter is the path to the LoRA adapter.
 
 
 Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
-for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
+for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
+
+Serving LoRA Adapters
+---------------------
+LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use
+``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kickoff the server:
+
+.. code-block:: bash
+
+    python -m vllm.entrypoints.api_server \
+        --model meta-llama/Llama-2-7b-hf \
+        --enable-lora \
+        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
+
+The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
+etc.), which will apply to all forthcoming requests. Upon querying the ``/models`` endpoint, we should see our LoRA along
+with its base model:
+
+.. code-block:: bash
+
+    curl localhost:8000/v1/models | jq .
+    {
+        "object": "list",
+        "data": [
+            {
+                "id": "meta-llama/Llama-2-7b-hf",
+                "object": "model",
+                ...
+            },
+            {
+                "id": "sql-lora",
+                "object": "model",
+                ...
+            }
+        ]
+    }
+
+Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
+processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
+LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
```
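The new docs section ends by noting that requests address a LoRA adapter through the ordinary ``model`` parameter. A minimal sketch of such a request with the official `openai` client, assuming the example server from the docs snippet is running on localhost:8000; the prompt text is illustrative only:

```python
# Illustrative sketch, not from the diff: route a completion request to the
# "sql-lora" adapter registered via --lora-modules. Assumes the server from
# the docs snippet is running on localhost:8000; the prompt is made up.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="sql-lora",  # the LoRA adapter name is used as the model id
    prompt="Write a SQL query that counts users by country.",
    max_tokens=64,
    temperature=0.0,
)
print(completion.choices[0].text)
```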

examples/multilora_inference.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -12,7 +12,9 @@
 from vllm.lora.request import LoRARequest
 
 
-def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
+def create_test_prompts(
+        lora_path: str
+) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
     """Create a list of test prompts with their sampling parameters.
 
     2 requests for base model, 4 requests for the LoRA. We define 2
```
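The widened return type means each test prompt now carries an `Optional[LoRARequest]`. A hedged sketch of what such tuples can look like, following the upstream multilora example's `LoRARequest(name, int_id, path)` usage; the prompt strings and parameters here are illustrative, not taken from the diff:

```python
# Hedged sketch of the new tuple shape: a base-model prompt carries None,
# a LoRA prompt carries a LoRARequest. LoRARequest(name, int_id, local_path)
# follows the upstream multilora example; the prompt strings are illustrative.
from typing import List, Optional, Tuple

from vllm import SamplingParams
from vllm.lora.request import LoRARequest


def example_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    return [
        ("A robot may not injure a human being",
         SamplingParams(temperature=0.0, max_tokens=128), None),
        ("Write a SQL query that lists all users",
         SamplingParams(temperature=0.0, max_tokens=128),
         LoRARequest("sql-lora", 1, lora_path)),
    ]
```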

tests/entrypoints/test_openai_server.py

Lines changed: 74 additions & 15 deletions
```diff
@@ -7,9 +7,11 @@
 import requests
 import ray  # using Ray for overall ease of process management, parallel requests, and debugging.
 import openai  # use the official client for correctness check
+from huggingface_hub import snapshot_download  # downloading lora to test lora requests
 
 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
 
 pytestmark = pytest.mark.asyncio
 
@@ -54,7 +56,12 @@ def __del__(self):
 
 
 @pytest.fixture(scope="session")
-def server():
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.fixture(scope="session")
+def server(zephyr_lora_files):
     ray.init()
     server_runner = ServerRunner.remote([
         "--model",
@@ -64,6 +71,17 @@ def server():
         "--max-model-len",
         "8192",
         "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128"
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -79,8 +97,25 @@ def client():
     yield client
 
 
-async def test_single_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(model=MODEL_NAME,
+async def test_check_models(server, client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    lora_models = models[1:]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in models)
+    assert lora_models[0].id == "zephyr-lora"
+    assert lora_models[1].id == "zephyr-lora2"
+
+
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_single_completion(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
+    completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
                                                  temperature=0.0)
@@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
         completion.choices[0].text) >= 5
 
 
-async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+                                   model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
 
     # test single completion
     chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
     )
@@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
     assert message.content is not None and len(message.content) >= 0
 
 
-async def test_completion_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_completion_streaming(server, client: openai.AsyncOpenAI,
+                                    model_name: str):
     prompt = "What is an LLM?"
 
     single_completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
@@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
     single_usage = single_completion.usage
 
     stream = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
@@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == single_output
 
 
-async def test_chat_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_chat_streaming(server, client: openai.AsyncOpenAI,
+                              model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
 
     # test single completion
    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
         temperature=0.0,
@@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
 
     # test streaming
     stream = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
         temperature=0.0,
@@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output
 
 
-async def test_batch_completions(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_batch_completions(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
     # test simple list
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         max_tokens=5,
         temperature=0.0,
@@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
 
     # test n = 2
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         n=2,
         max_tokens=5,
@@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
 
     # test streaming
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         max_tokens=5,
         temperature=0.0,
```

vllm/entrypoints/openai/api_server.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -23,6 +23,7 @@
 from vllm.logger import init_logger
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import LoRA
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
@@ -48,6 +49,16 @@ async def _force_log():
 app = fastapi.FastAPI(lifespan=lifespan)
 
 
+class LoRAParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        lora_list = []
+        for item in values:
+            name, path = item.split('=')
+            lora_list.append(LoRA(name, path))
+        setattr(namespace, self.dest, lora_list)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
@@ -81,6 +92,15 @@ def parse_args():
                         help="The model name used in the API. If not "
                         "specified, the model name will be the same as "
                         "the huggingface name.")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules can be specified."
+    )
     parser.add_argument("--chat-template",
                         type=str,
                         default=None,
@@ -217,8 +237,10 @@ async def authentication(request: Request, call_next):
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     openai_serving_chat = OpenAIServingChat(engine, served_model,
                                             args.response_role,
+                                            args.lora_modules,
                                             args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, served_model, args.lora_modules)
 
     # Register labels for metrics
     add_global_metrics_labels(model_name=engine_args.model)
```
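The new `--lora-modules` flag relies on `LoRAParserAction` to turn `name=path` strings into `LoRA` objects. A self-contained sketch of the same parsing outside the server, using a stand-in `LoRA` namedtuple since the real class lives in `vllm/entrypoints/openai/serving_engine.py`, whose diff is not shown on this page:

```python
# Standalone sketch of the name=path parsing used by --lora-modules.
# "LoRA" here is a stand-in namedtuple; the real class is defined in
# vllm/entrypoints/openai/serving_engine.py and is not shown in this diff.
import argparse
from collections import namedtuple

LoRA = namedtuple("LoRA", ["name", "local_path"])


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", type=str, default=None, nargs='+',
                    action=LoRAParserAction)

args = parser.parse_args(
    ["--lora-modules", "sql-lora=/path/to/lora", "sql-lora2=/path/to/lora"])
print(args.lora_modules)
# [LoRA(name='sql-lora', local_path='/path/to/lora'),
#  LoRA(name='sql-lora2', local_path='/path/to/lora')]
```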

vllm/entrypoints/openai/serving_chat.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -1,7 +1,7 @@
 import time
 import codecs
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -11,7 +11,7 @@
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
 
@@ -22,8 +22,11 @@ def __init__(self,
                  engine: AsyncLLMEngine,
                  served_model: str,
                  response_role: str,
+                 lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
-        super().__init__(engine=engine, served_model=served_model)
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
 
@@ -64,11 +67,13 @@ async def create_chat_completion(
             token_ids = self._validate_prompt_and_tokenize(request,
                                                            prompt=prompt)
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
         except ValueError as e:
             return self.create_error_response(str(e))
 
         result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids)
+                                                request_id, token_ids,
+                                                lora_request)
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
```
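Both serving classes now call `self._maybe_get_lora(request)`, which is defined in `serving_engine.py` (the seventh changed file, not displayed on this page). A speculative sketch of how such a lookup could resolve the request's model name; the names and attributes below are assumptions, not the committed code:

```python
# Speculative sketch (serving_engine.py is not shown on this page): one way
# _maybe_get_lora could resolve a request's model name to a LoRARequest.
# Attribute names here are assumptions, not the committed implementation.
from typing import Optional

from vllm.lora.request import LoRARequest


def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
    if request.model == self.served_model:
        return None  # base model: no LoRA applied
    for lora in self.lora_requests:
        if request.model == lora.lora_name:
            return lora
    # model name matches neither the base model nor a registered adapter
    raise ValueError(f"The model `{request.model}` does not exist.")
```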

vllm/entrypoints/openai/serving_completion.py

Lines changed: 11 additions & 4 deletions
```diff
@@ -15,7 +15,7 @@
     UsageInfo,
 )
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
 
@@ -249,8 +249,13 @@ async def consumer():
 
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        super().__init__(engine=engine, served_model=served_model)
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -284,6 +289,7 @@ async def create_completion(self, request: CompletionRequest,
         generators = []
         try:
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
@@ -298,7 +304,8 @@ async def create_completion(self, request: CompletionRequest,
                     self.engine.generate(None,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids))
+                                         prompt_token_ids=input_ids,
+                                         lora_request=lora_request))
 
         except ValueError as e:
             return self.create_error_response(str(e))
```
0 commit comments
