Commit 98b8f64

Gradio Web Server for Multimodal Models (#2960)
Co-authored-by: Lianmin Zheng <[email protected]>
1 parent b9d4d15 commit 98b8f64

15 files changed, +528 -108 lines changed

fastchat/conversation.py

Lines changed: 55 additions & 1 deletion
@@ -5,8 +5,10 @@
 If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
 """
 
+import base64
 import dataclasses
 from enum import auto, IntEnum
+from io import BytesIO
 from typing import List, Any, Dict, Union, Tuple
 
 
@@ -34,6 +36,9 @@ class SeparatorStyle(IntEnum):
     YUAN2 = auto()
 
 
+IMAGE_PLACEHOLDER_STR = "$$<image>$$"
+
+
 @dataclasses.dataclass
 class Conversation:
     """A class that manages prompt templates and keeps all conversation history."""
@@ -47,6 +52,7 @@ class Conversation:
     # The names of two roles
     roles: Tuple[str] = ("USER", "ASSISTANT")
     # All messages. Each item is (role, message).
+    # Each message is either a string or a tuple of (string, List[image_url]).
     messages: List[List[str]] = ()
     # The number of few shot examples
     offset: int = 0
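
A minimal sketch of the two message shapes this new comment describes; the role name and values below are illustrative placeholders, not taken from the diff:

# Illustrative only: a plain text turn vs. a multimodal turn.
text_turn = ["USER", "Hello!"]
image_turn = ["USER", ("What is in this picture?", ["<base64-image-string>"])]
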
@@ -77,6 +83,7 @@ def get_prompt(self) -> str:
                 if message:
                     if type(message) is tuple:
                         message, images = message
+                        message = IMAGE_PLACEHOLDER_STR * len(images) + message
                     ret += role + ": " + message + seps[i % 2]
                 else:
                     ret += role + ":"
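
A hedged sketch of how the placeholder shows up in the rendered prompt, assuming the registered vicuna_v1.1 template (any template using this separator branch behaves the same); the message text and image strings are made up:

from fastchat.conversation import IMAGE_PLACEHOLDER_STR, get_conv_template

conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], ("Describe the picture.", ["<b64-img-1>", "<b64-img-2>"]))
conv.append_message(conv.roles[1], None)

prompt = conv.get_prompt()
# One "$$<image>$$" is prepended per attached image, roughly:
# "... USER: $$<image>$$$$<image>$$Describe the picture. ASSISTANT:"
assert prompt.count(IMAGE_PLACEHOLDER_STR) == 2
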
@@ -289,11 +296,52 @@ def update_last_message(self, message: str):
         """
         self.messages[-1][1] = message
 
+    def convert_image_to_base64(self, image):
+        """Given an image, return the base64 encoded image string."""
+        from PIL import Image
+        import requests
+
+        # Load image if it has not been loaded in yet
+        if type(image) == str:
+            if image.startswith("http://") or image.startswith("https://"):
+                response = requests.get(image)
+                image = Image.open(BytesIO(response.content)).convert("RGB")
+            elif "base64" in image:
+                # OpenAI format is: data:image/jpeg;base64,{base64_encoded_image_str}
+                return image.split(",")[1]
+            else:
+                image = Image.open(image).convert("RGB")
+
+        max_hw, min_hw = max(image.size), min(image.size)
+        aspect_ratio = max_hw / min_hw
+        max_len, min_len = 2048, 2048
+        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+        longest_edge = int(shortest_edge * aspect_ratio)
+        W, H = image.size
+        if longest_edge != max(image.size):
+            if H > W:
+                H, W = longest_edge, shortest_edge
+            else:
+                H, W = shortest_edge, longest_edge
+            image = image.resize((W, H))
+
+        buffered = BytesIO()
+        image.save(buffered, format="PNG")
+        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+
+        return img_b64_str
+
     def to_gradio_chatbot(self):
         """Convert the conversation to gradio chatbot format."""
         ret = []
         for i, (role, msg) in enumerate(self.messages[self.offset :]):
             if i % 2 == 0:
+                if type(msg) is tuple:
+                    msg, image = msg
+                    img_b64_str = image[0]  # Only one image on gradio at one time
+                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = img_str + msg.replace("<image>\n", "").strip()
+
                 ret.append([msg, None])
             else:
                 ret[-1][-1] = msg
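
A usage sketch for the two methods above, again assuming the vicuna_v1.1 template; the file name is hypothetical, and PIL (plus requests for URLs) must be installed for the encoding step:

from fastchat.conversation import get_conv_template

conv = get_conv_template("vicuna_v1.1")

# Accepts a local path, an http(s) URL, or an OpenAI-style data URI.
# "cat.png" is a made-up example file.
img_b64_str = conv.convert_image_to_base64("cat.png")

conv.append_message(conv.roles[0], ("What animal is this?", [img_b64_str]))
conv.append_message(conv.roles[1], None)

history = conv.to_gradio_chatbot()
# [['<img src="data:image/jpeg;base64,..." alt="user upload image" />What animal is this?', None]]
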
@@ -314,6 +362,12 @@ def to_openai_api_messages(self):
                     ret.append({"role": "assistant", "content": msg})
         return ret
 
+    def extract_text_from_messages(self):
+        return [
+            (role, message[0]) if type(message) is tuple else (role, message)
+            for role, message in self.messages
+        ]
+
     def copy(self):
         return Conversation(
             name=self.name,
@@ -334,7 +388,7 @@ def dict(self):
             "template_name": self.name,
             "system_message": self.system_message,
             "roles": self.roles,
-            "messages": self.messages,
+            "messages": self.extract_text_from_messages(),
             "offset": self.offset,
         }
 
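
Finally, a hedged sketch of what the new extract_text_from_messages helper and the updated dict() yield for a mixed history; the template name and message contents are illustrative only:

from fastchat.conversation import get_conv_template

conv = get_conv_template("vicuna_v1.1")
conv.append_message(conv.roles[0], ("Describe this image.", ["<base64-image-string>"]))
conv.append_message(conv.roles[1], "It shows a cat on a sofa.")

print(conv.extract_text_from_messages())
# [('USER', 'Describe this image.'), ('ASSISTANT', 'It shows a cat on a sofa.')]
print(conv.dict()["messages"])
# Same text-only pairs: the base64 image payload stays out of the serialized state.
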