multi-LoRA as extra models in OpenAI server #2775

Merged
31 commits merged on Feb 17, 2024
Changes from 4 commits

Commits (31):
919827a
use lora request in OpenAIServing; expose lora modules as parameter …
jvmncs Feb 5, 2024
a4881f9
add single completion test cases for lora adapters
jvmncs Feb 5, 2024
613c77a
bugfix
jvmncs Feb 6, 2024
19694ae
switch back to zephyr base
jvmncs Feb 6, 2024
53c096e
add extra lora test case to all openai completion tests
jvmncs Feb 6, 2024
b8ffada
[Minor] Fix benchmark_latency script (#2765)
WoosukKwon Feb 5, 2024
3faac81
[ROCm] Fix some kernels failed unit tests (#2498)
hongxiayang Feb 5, 2024
e6f0009
Set local logging level via env variable (#2774)
gardberg Feb 5, 2024
419b31d
[ROCm] Fixup arch checks for ROCM (#2627)
dllehr-amd Feb 5, 2024
36017aa
Add fused top-K softmax kernel for MoE (#2769)
WoosukKwon Feb 6, 2024
78dee0a
modelscope: fix issue when model parameter is not a model id but path…
liuyhwangyh Feb 6, 2024
585846d
[Minor] More fix of test_cache.py CI test failure (#2750)
LiuXiaoxuanPKU Feb 6, 2024
5cb2c3a
[ROCm] Fix build problem resulted from previous commit related to FP8…
hongxiayang Feb 7, 2024
e1152b1
Add documentation on how to do incremental builds (#2796)
pcmoritz Feb 7, 2024
593578c
[Ray] Integration compiled DAG off by default (#2471)
rkooo567 Feb 8, 2024
5c40715
Disable custom all reduce by default (#2808)
WoosukKwon Feb 8, 2024
5d228c1
[ROCm] support Radeon™ 7900 series (gfx1100) without using flash-atte…
hongxiayang Feb 11, 2024
2090924
Add documentation section about LoRA (#2834)
pcmoritz Feb 12, 2024
b440270
Refactor 2 awq gemm kernels into m16nXk32 (#2723)
zcnrex Feb 12, 2024
57b02d0
Serving Benchmark Refactoring (#2433)
ywang96 Feb 13, 2024
1f6c168
[CI] Ensure documentation build is checked in CI (#2842)
simon-mo Feb 13, 2024
9a2cbe1
Refactor llama family models (#2637)
esmeetu Feb 13, 2024
822b463
Revert "Refactor llama family models (#2637)" (#2851)
pcmoritz Feb 13, 2024
299b8cc
Use CuPy for CUDA graphs (#2811)
WoosukKwon Feb 13, 2024
44b28d2
Remove Yi model definition, please use `LlamaForCausalLM` instead (#2…
pcmoritz Feb 13, 2024
0525d72
Add LoRA support for Mixtral (#2831)
tterrysun Feb 13, 2024
70aa7d4
Migrate InternLMForCausalLM to LlamaForCausalLM (#2860)
pcmoritz Feb 14, 2024
8c3d97a
Fix internlm after https://github.com/vllm-project/vllm/pull/2860 (#2…
pcmoritz Feb 14, 2024
5d34102
[Fix] Fix memory profiling when GPU is used by multiple processes (#2…
WoosukKwon Feb 14, 2024
35952e4
append lora serving instructions to lora documentation
jvmncs Feb 14, 2024
a6deb54
Merge branch 'main' into openai-lora
jvmncs Feb 14, 2024
4 changes: 3 additions & 1 deletion examples/multilora_inference.py
@@ -12,7 +12,9 @@
from vllm.lora.request import LoRARequest


def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
def create_test_prompts(
lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.

2 requests for base model, 4 requests for the LoRA. We define 2
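The updated signature pairs each prompt with an optional LoRARequest, so base-model and adapter requests can live in one list. A minimal sketch of such a list follows; the adapter name, path, and sampling values are placeholders rather than the script's actual values.

# Sketch only: illustrative prompts matching the new return type.
# "sql-lora" and the sampling values are placeholders, not the script's.
from typing import List, Optional, Tuple

from vllm import SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    base_params = SamplingParams(temperature=0.0, max_tokens=32)
    lora_request = LoRARequest("sql-lora", 1, lora_path)
    return [
        ("Hello, my name is", base_params, None),  # served by the base model
        ("SELECT count(*) FROM ", base_params, lora_request),  # routed through the LoRA
    ]
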
41 changes: 38 additions & 3 deletions tests/entrypoints/test_openai_server.py
@@ -7,9 +7,11 @@
import requests
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests

MAX_SERVER_START_WAIT_S = 600 # wait for server to start for up to 600 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

pytestmark = pytest.mark.asyncio

@@ -54,7 +56,12 @@ def __del__(self):


@pytest.fixture(scope="session")
def server():
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="session")
def server(zephyr_lora_files):
ray.init()
server_runner = ServerRunner.remote([
"--model",
@@ -64,6 +71,17 @@ def server():
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
"128"
])
ray.get(server_runner.ready.remote())
yield server_runner
@@ -79,8 +97,25 @@ def client():
yield client


async def test_single_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(model=MODEL_NAME,
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
model_name: str):
completion = await client.completions.create(model=model_name,

[Inline review comment from the author (jvmncs)]: For some reason this test was failing for all cases when I switched to MODEL_NAME="mistralai/Mistral-7B-v0.1"; the model consistently emitted "1999" no matter the prompt or temperature I tried. Not really sure why, but reverting to the zephyr model fixed it.

prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
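End to end, the behaviour these tests exercise looks roughly like the sketch below from the client side, assuming a local server started with --enable-lora and --lora-modules zephyr-lora=<path> on the default port (both assumptions, mirroring the fixture above).

# Sketch: querying a LoRA adapter through the OpenAI-compatible API.
# Assumes a local vLLM server started with --enable-lora and
# --lora-modules zephyr-lora=<path>, listening on localhost:8000.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The adapters appear alongside the base model in the model list.
for model in client.models.list().data:
    print(model.id)

# Requesting an adapter by name routes the request through its LoRA weights.
completion = client.completions.create(
    model="zephyr-lora",
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
)
print(completion.choices[0].text)

Because adapters are surfaced as extra models, an unmodified OpenAI client can target a LoRA simply by changing the model name.
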
24 changes: 23 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -23,6 +23,7 @@
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA

TIMEOUT_KEEP_ALIVE = 5 # seconds

@@ -48,6 +49,16 @@ async def _force_log():
app = fastapi.FastAPI(lifespan=lifespan)


class LoRAParserAction(argparse.Action):

def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)


def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
@@ -81,6 +92,15 @@ def parse_args():
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
parser.add_argument("--chat-template",
type=str,
default=None,
@@ -217,8 +237,10 @@ async def authentication(request: Request, call_next):
engine = AsyncLLMEngine.from_engine_args(engine_args)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(engine, served_model)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)

# Register labels for metrics
add_global_metrics_labels(model_name=engine_args.model)
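The custom argparse action converts each name=path token into a LoRA entry before the serving objects are built. Below is a self-contained sketch of the same pattern; the LoRA stand-in mirrors the dataclass added in serving_engine.py and is defined locally for illustration, not imported from vLLM.

# Standalone illustration of the name=path parsing used by --lora-modules.
import argparse
from dataclasses import dataclass


@dataclass
class LoRA:
    name: str
    local_path: str


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        # Each value looks like "adapter-name=/path/to/adapter".
        lora_list = [LoRA(*item.split("=", 1)) for item in values]
        setattr(namespace, self.dest, lora_list)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", type=str, nargs="+",
                    action=LoRAParserAction, default=None)
args = parser.parse_args(
    ["--lora-modules", "zephyr-lora=/path/a", "zephyr-lora2=/path/b"])
print(args.lora_modules)
# [LoRA(name='zephyr-lora', local_path='/path/a'),
#  LoRA(name='zephyr-lora2', local_path='/path/b')]
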
13 changes: 9 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,7 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Union
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -11,7 +11,7 @@
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

logger = init_logger(__name__)

@@ -22,8 +22,11 @@ def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
response_role: str,
lora_modules: Optional[List[LoRA]] = None,
chat_template=None):
super().__init__(engine=engine, served_model=served_model)
super().__init__(engine=engine,
served_model=served_model,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)

@@ -64,11 +67,13 @@ async def create_chat_completion(
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
except ValueError as e:
return self.create_error_response(str(e))

result_generator = self.engine.generate(prompt, sampling_params,
request_id, token_ids)
request_id, token_ids,
lora_request)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
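From the client's side, the chat endpoint now accepts a registered adapter name anywhere a model name is expected. A brief sketch, with the server address and adapter name assumed to match the test fixture above:

# Sketch: chat completion against a registered LoRA adapter.
# Assumes a local server started with --enable-lora --lora-modules zephyr-lora=<path>.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="zephyr-lora",  # adapter name, not the base model id
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_tokens=16,
    temperature=0.0,
)
print(chat.choices[0].message.content)
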
15 changes: 11 additions & 4 deletions vllm/entrypoints/openai/serving_completion.py
@@ -15,7 +15,7 @@
UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA

logger = init_logger(__name__)

@@ -249,8 +249,13 @@ async def consumer():

class OpenAIServingCompletion(OpenAIServing):

def __init__(self, engine: AsyncLLMEngine, served_model: str):
super().__init__(engine=engine, served_model=served_model)
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
lora_modules: Optional[List[LoRA]] = None):
super().__init__(engine=engine,
served_model=served_model,
lora_modules=lora_modules)

async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@@ -284,6 +289,7 @@ async def create_completion(self, request: CompletionRequest,
generators = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

for i, prompt in enumerate(prompts):
@@ -298,7 +304,8 @@
self.engine.generate(None,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids))
prompt_token_ids=input_ids,
lora_request=lora_request))
except ValueError as e:
return self.create_error_response(str(e))

41 changes: 40 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
@@ -1,4 +1,5 @@
import asyncio
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
@@ -9,15 +10,35 @@
ErrorResponse, LogProbs,
ModelCard, ModelList,
ModelPermission)
from vllm.lora.request import LoRARequest

logger = init_logger(__name__)


@dataclass
class LoRA:
name: str
local_path: str


class OpenAIServing:

def __init__(self, engine: AsyncLLMEngine, served_model: str):
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
lora_modules: Optional[List[LoRA]] = None):
self.engine = engine
self.served_model = served_model
if lora_modules is None:
self.lora_requests = []
else:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
) for i, lora in enumerate(lora_modules, start=1)
]

self.max_model_len = 0
self.tokenizer = None
Expand Down Expand Up @@ -50,6 +71,13 @@ async def show_available_models(self) -> ModelList:
root=self.served_model,
permission=[ModelPermission()])
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=self.served_model,
permission=[ModelPermission()])
for lora in self.lora_requests
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)

def _create_logprobs(
Expand Down Expand Up @@ -99,11 +127,22 @@ def create_error_response(
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model:
return
if request.model in [lora.lora_name for lora in self.lora_requests]:
return
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)

def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
if request.model == self.served_model:
return
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")

def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
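
As a worked example of the resolution logic above: each registered module gets a 1-based integer ID, the base-model name resolves to no adapter, and an unknown name raises. The LoRARequest below is a local stand-in with the same three fields; the real class is vllm.lora.request.LoRARequest.

# Worked example of the name -> LoRARequest mapping sketched above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRARequest:  # stand-in; the server uses vllm.lora.request.LoRARequest
    lora_name: str
    lora_int_id: int
    lora_local_path: str


SERVED_MODEL = "HuggingFaceH4/zephyr-7b-beta"
MODULES = [("zephyr-lora", "/path/a"), ("zephyr-lora2", "/path/b")]

# IDs are enumerated from 1, matching the PR (0 conventionally means "no adapter").
LORA_REQUESTS = [
    LoRARequest(name, i, path) for i, (name, path) in enumerate(MODULES, start=1)
]


def maybe_get_lora(requested_model: str) -> Optional[LoRARequest]:
    if requested_model == SERVED_MODEL:
        return None  # base model: generate without an adapter
    for lora in LORA_REQUESTS:
        if requested_model == lora.lora_name:
            return lora
    raise ValueError(f"The model `{requested_model}` does not exist.")


print(maybe_get_lora("zephyr-lora2").lora_int_id)  # 2
print(maybe_get_lora(SERVED_MODEL))                # None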