12 changes: 12 additions & 0 deletions fastchat/conversation.py
@@ -75,6 +75,8 @@ def get_prompt(self) -> str:
             ret = system_prompt + seps[0]
             for i, (role, message) in enumerate(self.messages):
                 if message:
+                    if type(message) is tuple:
+                        message, images = message
                     ret += role + ": " + message + seps[i % 2]
                 else:
                     ret += role + ":"
@@ -261,6 +263,16 @@ def get_prompt(self) -> str:
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
+    def get_images(self):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset :]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    for image in msg[1]:
+                        images.append(image)
+
+        return images
+
     def set_system_message(self, system_message: str):
         """Set the system message."""
         self.system_message = system_message
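
Taken together, these two changes let a user turn carry images alongside text: get_prompt() renders only the text half of a (text, images) tuple, while get_images() walks the user turns and collects any attached images. A minimal sketch of the new convention, assuming the stock "vicuna_v1.1" template (the image URL is made up):

# Sketch of the (text, images) message convention introduced above.
from fastchat.conversation import get_conv_template

conv = get_conv_template("vicuna_v1.1")
# A user turn may now be a (text, [images]) tuple instead of a plain string.
conv.append_message(conv.roles[0], ("What is in this image?", ["https://example.com/cat.png"]))
conv.append_message(conv.roles[1], None)

print(conv.get_prompt())  # only the text half appears in the rendered prompt
print(conv.get_images())  # ['https://example.com/cat.png']
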
15 changes: 15 additions & 0 deletions fastchat/model/model_adapter.py
@@ -2185,6 +2185,20 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("solar")
 
 
+class LlavaAdapter(BaseModelAdapter):
+    """The model adapter for liuhaotian/llava-v1.5 series of models"""
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        # TODO(chris): Implement huggingface-compatible load_model
+        pass
+
+    def match(self, model_path: str):
+        return "llava" in model_path.lower()
+
+    def get_default_conv_template(self, model_path: str) -> Conversation:
+        return get_conv_template("vicuna_v1.1")
+
+
 class YuanAdapter(BaseModelAdapter):
     """The model adapter for Yuan"""
 
@@ -2305,6 +2319,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(MetaMathAdapter)
 register_model_adapter(BagelAdapter)
 register_model_adapter(SolarAdapter)
+register_model_adapter(LlavaAdapter)
 register_model_adapter(YuanAdapter)
 
 # After all adapters, try the default base adapter.
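
Registration order matters here: get_model_adapter returns the first registered adapter whose match() accepts the path, so any model path containing "llava" should now resolve to LlavaAdapter and its vicuna_v1.1 conversation template. Since load_model is still a TODO stub, only template resolution is usable for now. A quick sketch:

# Sketch of adapter resolution for a llava-style model path.
from fastchat.model.model_adapter import get_model_adapter

adapter = get_model_adapter("liuhaotian/llava-v1.5-7b")
print(type(adapter).__name__)  # LlavaAdapter ("llava" is in the path)
print(adapter.get_default_conv_template("liuhaotian/llava-v1.5-7b").name)  # vicuna_v1.1
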
6 changes: 5 additions & 1 deletion fastchat/protocol/openai_api_protocol.py
@@ -57,7 +57,11 @@ class LogProbs(BaseModel):
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: Union[str, List[Dict[str, str]]]
+    messages: Union[
+        str,
+        List[Dict[str, str]],
+        List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]],
+    ]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     top_k: Optional[int] = -1
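
The third Union arm is the shape of OpenAI's vision-style content list: a message's "content" may be a list of typed parts instead of a bare string. A request the widened field now accepts (model name and URL are illustrative):

from fastchat.protocol.openai_api_protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model="llava-v1.5-7b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
)
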
22 changes: 21 additions & 1 deletion fastchat/serve/openai_api_server.py
@@ -286,13 +286,29 @@ async def get_gen_params(
 
     if isinstance(messages, str):
         prompt = messages
+        images = []
     else:
         for message in messages:
             msg_role = message["role"]
             if msg_role == "system":
                 conv.set_system_message(message["content"])
             elif msg_role == "user":
-                conv.append_message(conv.roles[0], message["content"])
+                if type(message["content"]) == list:
+                    image_list = [
+                        item["image_url"]["url"]
+                        for item in message["content"]
+                        if item["type"] == "image_url"
+                    ]
+                    text_list = [
+                        item["text"]
+                        for item in message["content"]
+                        if item["type"] == "text"
+                    ]
+
+                    text = "\n".join(text_list)
+                    conv.append_message(conv.roles[0], (text, image_list))
+                else:
+                    conv.append_message(conv.roles[0], message["content"])
             elif msg_role == "assistant":
                 conv.append_message(conv.roles[1], message["content"])
             else:
@@ -301,6 +317,7 @@
         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
+        images = conv.get_images()
 
     gen_params = {
         "model": model_name,
@@ -316,6 +333,9 @@
         "stop_token_ids": conv.stop_token_ids,
     }
 
+    if len(images) > 0:
+        gen_params["images"] = images
+
     if best_of is not None:
         gen_params.update({"best_of": best_of})
     if use_beam_search is not None:
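
End to end, a content-list user message is split into a newline-joined text string plus an image list, stored on the conversation as a tuple, and surfaces again as gen_params["images"]. A standalone sketch of that splitting step, on a made-up message:

content = [
    {"type": "text", "text": "Describe this image."},
    {"type": "image_url", "image_url": {"url": "https://example.com/dog.png"}},
]

image_list = [item["image_url"]["url"] for item in content if item["type"] == "image_url"]
text_list = [item["text"] for item in content if item["type"] == "text"]

# The (text, images) tuple is what gets appended as the user turn.
print(("\n".join(text_list), image_list))
# ('Describe this image.', ['https://example.com/dog.png'])
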